4.2.3.5. 基于 FX 的量化¶
FX 是 torch 处理计算图的一套机制,由于计算图的存在,基于 FX 的量化相比 eager mode 存在以下优势: - 可以实现 fuse pattern 的自动化匹配,用户不再需要手动执行 fuse 过程 - 为其他基于计算图的优化提供了可能性
总体上看,FX 量化的流程与 eager mode 高度一致,一般仅需将接口替换即可: - prepare_calibration
-> prepare_calibration_fx
- prepare_qat
-> prepare_qat_fx
- convert
-> convert_fx
其中 prepare_calibration_fx
和 prepare_qat_fx
接口中都集成了自动化的 fuse 流程。除此之外,我们也提供了单独的 fuse_fx
接口供 debug 或研究使用
4.2.3.5.1. 限制¶
FX 采用符号执行的方式对模型中的操作进行记录,存在诸多限制,具体见官方文档相关章节 limitations-of-symbolic-tracing
用户模型中若用到了 FX 不支持的操作如控制流等,可以通过 wrap 的方式将其作为一个函数或者方法包装为一个整体,FX 将不再关注它们的内部逻辑,而是将对它们的调用原样保留
我们对 torch.fx.wrap
进行了扩展以支持更多的包装形式,具体说明见 utils.fx_helper.wrap
的接口文档
下面举例对 wrap 的使用进行说明
[3]:
from torch import nn
import torch
from torch.nn import functional as F
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.quantize_fx import QuantizationTracer
from torch.quantization import DeQuantStub
from horizon_plugin_pytorch.quantization.fx.graph_module import GraphModuleWithAttr
class FxWrapExampleNet(nn.Module):
def __init__(self):
super(FxWrapExampleNet, self).__init__()
self.quant = QuantStub()
self.conv = nn.Conv2d(3, 3, 1)
self.bn = nn.BatchNorm2d(3)
self.relu = nn.ReLU()
self.dequant = DeQuantStub()
def forward(self, input):
# 模型主体,需要进行 fuse、量化等操作
x = self.quant(input)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.dequant(x)
# 后处理,不需要量化,且包含条件分支
if self.training:
print("Run softmax")
return F.softmax(x, dim=1)
else:
print("Run argmax")
return torch.argmax(x, dim=1)
model = FxWrapExampleNet()
tracer = QuantizationTracer([], [])
graph = tracer.trace(model)
graph.print_tabular()
graph_model = GraphModuleWithAttr(model, graph)
print(graph_model.code)
data = torch.rand(1, 3, 64, 64)
ret = graph_model(data)
Run softmax
opcode name target args kwargs
------------- ------- ------------------------------------ ---------- -------------------------------------------
placeholder input_1 input () {}
call_module conv conv (input_1,) {}
call_module bn bn (conv,) {}
call_module relu relu (bn,) {}
call_function softmax <function softmax at 0x7f0ac35be040> (relu,) {'dim': 1, '_stacklevel': 3, 'dtype': None}
output output output (softmax,) {}
def forward(self, input):
input_1 = input
conv = self.conv(input_1); input_1 = None
bn = self.bn(conv); conv = None
relu = self.relu(bn); bn = None
softmax = torch.nn.functional.softmax(relu, dim = 1, _stacklevel = 3, dtype = None); relu = None
return softmax
可以看到,trace 后的模型中后处理 torch.argmax
以及 print 语句都被丢弃了,如果需要将后处理原样保留,可以将它包装起来
[4]:
from horizon_plugin_pytorch.utils.fx_helper import wrap as fx_wrap
class FxWrapExampleNet(nn.Module):
def __init__(self):
super(FxWrapExampleNet, self).__init__()
self.quant = QuantStub()
self.conv = nn.Conv2d(3, 3, 1)
self.bn = nn.BatchNorm2d(3)
self.relu = nn.ReLU()
self.dequant = DeQuantStub()
# 将后处理包装为一个 method
@fx_wrap
def _post_process(self, model_output):
if self.training:
print("Run softmax")
return F.softmax(model_output, dim=1)
else:
print("Run argmax")
return torch.argmax(model_output, dim=1)
def forward(self, input):
x = self.quant(input)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.dequant(x)
return self._post_process(x)
model = FxWrapExampleNet()
tracer = QuantizationTracer([], [])
graph = tracer.trace(model)
graph.print_tabular()
graph_model = GraphModuleWithAttr(model, graph)
print(graph_model.code)
ret = graph_model(data)
ret = graph_model.eval()(data)
opcode name target args kwargs
----------- ------------- ------------- ---------------- --------
placeholder input_1 input () {}
call_module conv conv (input_1,) {}
call_module bn bn (conv,) {}
call_module relu relu (bn,) {}
get_attr _self _self () {}
call_method _post_process _post_process (_self, relu) {}
output output output (_post_process,) {}
def forward(self, input):
input_1 = input
conv = self.conv(input_1); input_1 = None
bn = self.bn(conv); conv = None
relu = self.relu(bn); bn = None
_self = self._self
_post_process = _self._post_process(relu); _self = relu = None
return _post_process
Run softmax
Run argmax
可以看到,包装起来的后处理在 trace 后的模型中作为一个整体被调用,内部逻辑将被原样保留, print 语句也可以正常打印内容