7.5.1. FX Quantization 原理介绍¶
阅读此文档前,建议先阅读 torch.fx — PyTorch documentation,以对 torch 的 FX 机制有初步的了解。
FX 采用符号执行的方式,可以在 nn.Module
或 function 的层面对模型建图,从而实现自动化的 fuse 以及其他基于图的优化。
7.5.1.1. 量化流程¶
7.5.1.1.1. Fuse(可选)¶
FX 可以感知计算图,所以可以实现自动化的算子融合,用户不再需要手动指定需要融合的算子,直接调用接口即可。
fused_model = horizon.quantization.fuse_fx(model)
注意
fuse_fx
没有inplace
参数,因为内部需要对模型做 symbolic trace 生成一个GraphModule
,所以无法做到 inplace 的修改fused_model
和model
会共享几乎所有属性(包括子模块、算子等),因此在 fuse 之后请不要对model
做任何修改,否则可能影响到fused_model
用户不必显式调用
fuse_fx
接口,因为后续的prepare_qat_fx
接口内部集成了 fuse 的过程
7.5.1.1.2. Prepare¶
用户在调用 prepare_qat_fx
接口之前必须根据目标硬件平台设置全局的 march。接口内部会先执行 fuse 过程(即使模型已经 fuse 过了),再将模型中符合条件的算子替换为 horizon.nn.qat
中的实现。
用户可以根据需要选择合适的 qconfig(Calibtaion 或 QAT,注意两种 qconfig 不能混用)
和
fuse_fx
类似,此接口不支持inplace
参数,且在prepare_qat_fx
之后请不要对输入的模型做任何修改
horizon.march.set_march(March.XXX)
qat_model = horizon.quantization.prepare_qat_fx(
model,
{
"": horizon.qconfig.default_calib_8bit_fake_quant_qconfig,
"module_name": {
"<module_name>": custom_qconfig,
},
},)
7.5.1.1.3. Convert¶
和
fuse_fx
类似,此接口不支持inplace
参数,且在convert_fx
之后请不要对输入的模型做任何修改
quantized_model = horizon.quantization.convert_fx(qat_model)
7.5.1.1.4. Eager Mode 兼容性¶
大部分情况下,FX 量化的接口可以直接替换 eager mode 量化的接口(prepare_qat
-> prepare_qat_fx
, convert
-> convert_fx
),但是不能和 eager mode 的接口混用。部分模型在以下情况下需要对代码结构做一定的修改。
FX 不支持的操作:torch 的 symbolic trace 支持的操作是有限的,例如不支持将非静态变量作为判断条件、默认不支持 torch 以外的 pkg(如 numpy)等,且未执行到的条件分支将被丢弃
不想被 FX 处理的操作:如果模型的前后处理中使用了 torch 的 op,FX 在 trace 时会将他们视为模型的一部分,产生不符合预期的行为(例如将 torch 的某些 function 调用替换为 FloatFunctional)。
以上两种情况,都可以采用 wrap 的方法来避免,下面以 RetinaNet 为例进行说明。
from horizon_plugin_pytorch.fx.fx_helper import wrap as fx_wrap
class RetinaNet(nn.Module):
def __init__(
self,
backbone: nn.Module,
neck: Optional[nn.Module] = None,
head: Optional[nn.Module] = None,
anchors: Optional[nn.Module] = None,
targets: Optional[nn.Module] = None,
post_process: Optional[nn.Module] = None,
loss_cls: Optional[nn.Module] = None,
loss_reg: Optional[nn.Module] = None,
):
super(RetinaNet, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.anchors = anchors
self.targets = targets
self.post_process = post_process
self.loss_cls = loss_cls
self.loss_reg = loss_reg
def rearrange_head_out(self, inputs: List[torch.Tensor], num: int):
outputs = []
for t in inputs:
outputs.append(t.permute(0, 2, 3, 1).reshape(t.shape[0], -1, num))
return torch.cat(outputs, dim=1)
def forward(self, data: Dict):
feat = self.backbone(data["img"])
feat = self.neck(feat) if self.neck else feat
cls_scores, bbox_preds = self.head(feat)
if self.post_process is None:
return cls_scores, bbox_preds
# 将不需要建图的操作封装为一个 method 即可,FX 将不再关注 method 内部的逻辑,
# 仅将它原样保留(method 中调用的 module 仍可被设置 qconfig,被
# prepare_qat_fx 和 convert_fx 替换)
return self._post_process( data, feat, cls_scores, bbox_preds)
@fx_wrap() # fx_wrap 支持直接装饰 class method
def _post_process(self, data, feat, cls_scores, bbox_preds)
anchors = self.anchors(feat)
# 对 self.training 的判断必须封装起来,否则在 symbolic trace 之后,此判断
# 逻辑会被丢掉
if self.training:
cls_scores = self.rearrange_head_out(
cls_scores, self.head.num_classes
)
bbox_preds = self.rearrange_head_out(bbox_preds, 4)
gt_labels = [
torch.cat(
[data["gt_bboxes"][i], data["gt_classes"][i][:, None] + 1],
dim=-1,
)
for i in range(len(data["gt_classes"]))
]
gt_labels = [gt_label.float() for gt_label in gt_labels]
_, labels = self.targets(anchors, gt_labels)
avg_factor = labels["reg_label_mask"].sum()
if avg_factor == 0:
avg_factor += 1
cls_loss = self.loss_cls(
pred=cls_scores.sigmoid(),
target=labels["cls_label"],
weight=labels["cls_label_mask"],
avg_factor=avg_factor,
)
reg_loss = self.loss_reg(
pred=bbox_preds,
target=labels["reg_label"],
weight=labels["reg_label_mask"],
avg_factor=avg_factor,
)
return {
"cls_loss": cls_loss,
"reg_loss": reg_loss,
}
else:
preds = self.post_process(
anchors,
cls_scores,
bbox_preds,
[torch.tensor(shape) for shape in data["resized_shape"]],
)
assert (
"pred_bboxes" not in data.keys()
), "pred_bboxes has been in data.keys()"
data["pred_bboxes"] = preds
return data