4.2.3.5. 基于 FX 的量化

FX 是 torch 处理计算图的一套机制,由于计算图的存在,基于 FX 的量化相比 eager mode 存在以下优势: - 可以实现 fuse pattern 的自动化匹配,用户不再需要手动执行 fuse 过程 - 为其他基于计算图的优化提供了可能性

总体上看,FX 量化的流程与 eager mode 高度一致,一般仅需将接口替换即可: - prepare_calibration -> prepare_calibration_fx - prepare_qat -> prepare_qat_fx - convert -> convert_fx

其中 prepare_calibration_fxprepare_qat_fx 接口中都集成了自动化的 fuse 流程。除此之外,我们也提供了单独的 fuse_fx 接口供 debug 或研究使用

4.2.3.5.1. 限制

FX 采用符号执行的方式对模型中的操作进行记录,存在诸多限制,具体见官方文档相关章节 limitations-of-symbolic-tracing

用户模型中若用到了 FX 不支持的操作如控制流等,可以通过 wrap 的方式将其作为一个函数或者方法包装为一个整体,FX 将不再关注它们的内部逻辑,而是将对它们的调用原样保留

我们对 torch.fx.wrap 进行了扩展以支持更多的包装形式,具体说明见 utils.fx_helper.wrap 的接口文档

下面举例对 wrap 的使用进行说明

[3]:
from torch import nn
import torch
from torch.nn import functional as F
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.quantize_fx import QuantizationTracer
from torch.quantization import DeQuantStub

from horizon_plugin_pytorch.quantization.fx.graph_module import GraphModuleWithAttr



class FxWrapExampleNet(nn.Module):
    def __init__(self):
        super(FxWrapExampleNet, self).__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()

    def forward(self, input):
        # 模型主体,需要进行 fuse、量化等操作
        x = self.quant(input)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)

        # 后处理,不需要量化,且包含条件分支
        if self.training:
            print("Run softmax")
            return F.softmax(x, dim=1)
        else:
            print("Run argmax")
            return torch.argmax(x, dim=1)

model = FxWrapExampleNet()
tracer = QuantizationTracer([], [])
graph = tracer.trace(model)
graph.print_tabular()

graph_model = GraphModuleWithAttr(model, graph)
print(graph_model.code)

data = torch.rand(1, 3, 64, 64)
ret = graph_model(data)
Run softmax
opcode         name     target                                args        kwargs
-------------  -------  ------------------------------------  ----------  -------------------------------------------
placeholder    input_1  input                                 ()          {}
call_module    conv     conv                                  (input_1,)  {}
call_module    bn       bn                                    (conv,)     {}
call_module    relu     relu                                  (bn,)       {}
call_function  softmax  <function softmax at 0x7f0ac35be040>  (relu,)     {'dim': 1, '_stacklevel': 3, 'dtype': None}
output         output   output                                (softmax,)  {}



def forward(self, input):
    input_1 = input
    conv = self.conv(input_1);  input_1 = None
    bn = self.bn(conv);  conv = None
    relu = self.relu(bn);  bn = None
    softmax = torch.nn.functional.softmax(relu, dim = 1, _stacklevel = 3, dtype = None);  relu = None
    return softmax

可以看到,trace 后的模型中后处理 torch.argmax 以及 print 语句都被丢弃了,如果需要将后处理原样保留,可以将它包装起来

[4]:
from horizon_plugin_pytorch.utils.fx_helper import wrap as fx_wrap

class FxWrapExampleNet(nn.Module):
    def __init__(self):
        super(FxWrapExampleNet, self).__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()

    # 将后处理包装为一个 method
    @fx_wrap
    def _post_process(self, model_output):
        if self.training:
            print("Run softmax")
            return F.softmax(model_output, dim=1)
        else:
            print("Run argmax")
            return torch.argmax(model_output, dim=1)

    def forward(self, input):
        x = self.quant(input)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)

        return self._post_process(x)

model = FxWrapExampleNet()
tracer = QuantizationTracer([], [])
graph = tracer.trace(model)
graph.print_tabular()

graph_model = GraphModuleWithAttr(model, graph)
print(graph_model.code)

ret = graph_model(data)
ret = graph_model.eval()(data)
opcode       name           target         args              kwargs
-----------  -------------  -------------  ----------------  --------
placeholder  input_1        input          ()                {}
call_module  conv           conv           (input_1,)        {}
call_module  bn             bn             (conv,)           {}
call_module  relu           relu           (bn,)             {}
get_attr     _self          _self          ()                {}
call_method  _post_process  _post_process  (_self, relu)     {}
output       output         output         (_post_process,)  {}



def forward(self, input):
    input_1 = input
    conv = self.conv(input_1);  input_1 = None
    bn = self.bn(conv);  conv = None
    relu = self.relu(bn);  bn = None
    _self = self._self
    _post_process = _self._post_process(relu);  _self = relu = None
    return _post_process

Run softmax
Run argmax

可以看到,包装起来的后处理在 trace 后的模型中作为一个整体被调用,内部逻辑将被原样保留, print 语句也可以正常打印内容