4.2.3.4. Heterogeneous Model QAT¶
4.2.3.4.1. Differences Between Heterogeneous and Non-Heterogeneous Models¶
A heterogeneous model is one that is deployed partly on the BPU and partly on the CPU, whereas a non-heterogeneous model is deployed entirely on the BPU. Typically, the following two kinds of models become heterogeneous at deployment:

- Models containing operators that the BPU does not support.
- Models in which the user pins certain operators to the CPU because their quantization error is too large.
The differences in how horizon_plugin_pytorch supports heterogeneous versus non-heterogeneous models are as follows:
|  | Heterogeneous | Non-heterogeneous |
| --- | --- | --- |
| Operators | Interfaces with horizon_nn; the supported operator set is that of horizon_nn, and the model may contain CPU operators. | Interfaces directly with the compiler; the supported operator set is that of horizon_plugin_pytorch, and the model must not contain CPU operators. |
| Interface | prepare_calibration_fx: enable hybrid mode and set hybrid_dict as needed. | See the documentation on non-heterogeneous usage. |
| Workflow | (workflow diagram) | (workflow diagram) |
4.2.3.4.2. Key Interface Parameters¶
The heterogeneous interfaces are used in essentially the same way as the non-heterogeneous ones; only a small number of parameters such as hybrid have been added. See the API documentation for detailed parameter descriptions; the key parameters are described here.
4.2.3.4.2.1. horizon_plugin_pytorch.quantization.prepare_qat_fx¶
Enable the hybrid parameter. If no BPU operators need to fall back to the CPU, hybrid_dict can be left unset.
def prepare_qat_fx(
model: Union[torch.nn.Module, GraphModule],
qconfig_dict: Dict[str, Any] = None,
prepare_custom_config_dict: Dict[str, Any] = None,
optimize_graph: bool = False,
hybrid: bool = False,
hybrid_dict: Dict[str, List] = None,
) -> ObservedGraphModule:
"""Prepare QAT模型
`model`: torch.nn.Module或GraphModule(使用fuse_fx后的模型)
`qconfig_dict`: 定义Qconfig。如果除了qconfig_dict以外,还使用了eager mode在module内定义qconfig的方式,则module内定义的qconfig优先生效。qconfig_dict的配置格式如下:
qconfig_dict = {
# 可选,全局配置
"": qconfig,
# 可选,按module类型配置
"module_type": [(torch.nn.Conv2d, qconfig), ...],
# 可选,按module名配置
"module_name": [("foo.bar", qconfig),...],
# 优先级: global < module_type < module_name < module.qconfig
# 非module类型的算子的qconfig默认与其父module的qconfig保持一致,如果需要单独设置,请将这部分单独封装成module。
}
`prepare_custom_config_dict`: 自定义配置字典
prepare_custom_config_dict = {
# 暂时只支持preserved_attributes。一般而言会自动保留所有属性,这个选项只是以防万一,几乎不会用到。
"preserved_attributes": ["preserved_attr"],
}
`optimize_graph`: 保持cat输入输出scale一致,目前只有在Bernoulli架构下有效。
`hybrid`: 是否使用异构模式。在以下情况下必须打开异构模式:
1. 模型包含BPU不支持的算子或用户希望指定部分BPU算子退回CPU。
2. 用户希望QAT模型与horizon_nn对接进行定点化。
`hybrid_dict`: 定义用户主动指定的CPU算子。
hybrid_dict = {
# 可选,按module类型配置
"module_type": [torch.nn.Conv2d, ...],
# 可选,按module名配置
"module_name": ["foo.bar", ...],
# 优先级: module_type < module_name
# 与qconfig_dict类似,如果想要非module类型的算子运行在CPU上,需要将这部分单独封装成module。
}
"""
4.2.3.4.2.2. horizon_plugin_pytorch.quantization.prepare_calibration_fx¶
Used exactly like prepare_qat_fx, except that the qconfig must be a calibration qconfig.
def prepare_calibration_fx(
model,
qconfig_dict: Dict[str, Any] = None,
prepare_custom_config_dict: Dict[str, Any] = None,
optimize_graph: bool = False,
hybrid: bool = False,
hybrid_dict: Dict[str, List] = None,
) -> ObservedGraphModule:
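For instance, a brief sketch assuming `float_model` is the prepared float model:

from horizon_plugin_pytorch.quantization import (
    get_default_calib_qconfig,
    prepare_calibration_fx,
)

calib_model = prepare_calibration_fx(
    float_model,
    {"": get_default_calib_qconfig()},  # calibration qconfig, not the QAT one
    hybrid=True,
)
calib_model.eval()  # calibration forward passes run in eval mode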
4.2.3.4.2.3. horizon_plugin_pytorch.quantization.convert_fx¶
The convert interface is used the same way in heterogeneous mode as in non-heterogeneous mode, but the fixed-point model obtained by converting a heterogeneous model is only for accuracy evaluation, not for producing the final deployed model.
def convert_fx(
graph_module: GraphModule,
convert_custom_config_dict: Dict[str, Any] = None,
_remove_qconfig: bool = True,
) -> QuantizedGraphModule:
"""转换QAT模型,仅用于评测定点模型。
`graph_module`: 经过prepare->(calibration)->train之后的模型
`convert_custom_config_dict`: 自定义配置字典
convert_custom_config_dict = {
# 暂时只支持preserved_attributes。一般而言会自动保留所有属性,这个选项只是以防万一,几乎不会用到。
"preserved_attributes": ["preserved_attr"],
}
`_remove_qconfig`: convert之后是否删除qconfig,一般不会用到
"""
4.2.3.4.2.4. horizon_plugin_pytorch.utils.onnx_helper.export_to_onnx¶
In non-heterogeneous mode this interface is only used for visualization; in heterogeneous mode it is additionally used to export onnx for hb_mapper.
def export_to_onnx(
model,
args,
f,
export_params=True,
verbose=False,
training=TrainingMode.EVAL,
input_names=None,
output_names=None,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
opset_version=11,
do_constant_folding=True,
example_outputs=None,
strip_doc_string=True,
dynamic_axes=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
enable_onnx_checker=False,
):
"""此接口与torch.onnx.export基本一致,隐藏了无需修改的参数,需要的注意参数有:
`model`: 需要export的模型
`args`: 模型输入,用于trace模型
`f`: 保存的onnx文件名或文件描述符
`operator_export_type`: 算子导出类型
1. 对于非异构模型,onnx仅用于可视化,不需要保证实际可用,使用默认值OperatorExportTypes.ONNX_FALLTHROUGH
2. 对于异构模型,onnx需要保证实际可用,使用None确保导出的为标准onnx算子。
`opset_version`: 只能为11,plugin在opset 11中注册了特定的映射规则。
注意:如果使用公版torch.onnx.export,需要确保上述参数设置正确,
并且import horizon_plugin_pytorch.utils._register_onnx_ops
以向opset 11中注册特定的映射规则。
"""
4.2.3.4.3. Workflow¶
1. Adapt the float model.
   - Insert QuantStub and DeQuantStub, exactly as in the non-heterogeneous flow.
     - If the first op is a CPU op, no QuantStub needs to be inserted.
     - If the last op is a CPU op, the DeQuantStub can be omitted.
   - Non-module operations that need a separate qconfig, or that should run on the CPU, must be wrapped in a module; see _SeluModule in the example.
2. Set the march.
3. Set the qconfig. Setting qconfig inside a module, as in non-heterogeneous mode, still works; in addition, qconfig can be passed via the qconfig_dict parameter of prepare_qat_fx (see the parameter descriptions above).
   - Every BPU op must have a qconfig, and if its input op is not a QuantStub, that input op must also have an activation qconfig.
   - A qconfig has no effect on a CPU op itself, but if the CPU op is followed by a BPU op, it must still have a qconfig.
   - Recommended approach: set the global qconfig to horizon_plugin_pytorch.quantization.get_default_qat_qconfig() and adjust from there; usually only int16 ops and high-precision-output ops need a separate qconfig.
4. Set hybrid_dict. Optional (see the parameter descriptions above); if no operators are explicitly pinned to the CPU, hybrid_dict can be left unset.
5. Call prepare_calibration_fx. Optional: if the task is simple and QAT alone reaches the accuracy target, you can skip ahead to step 7. Calibration is generally beneficial, and at worst harmless, to QAT accuracy. Printing the calibration model shows CalibFakeQuantize inserted wherever quantization statistics need to be collected; conv4 in the example looks like this:

   (conv4): Conv2d(
     3, 3, kernel_size=(1, 1), stride=(1, 1)
     (weight_fake_quant): CalibFakeQuantize(
       (activation_post_process): NoopObserver()
     )
     (activation_post_process): CalibFakeQuantize(
       (activation_post_process): CalibObserver(CalibObserver() calib_bin_edges=tensor([]) calib_hist=tensor([]))
     )
   )

6. Calibration. Run the model forward several times in eval mode.
7. Call prepare_qat_fx. Printing the QAT model shows FakeQuantize inserted wherever fake quantization is needed; conv4 in the example looks like this:

   (conv4): Conv2d(
     3, 3, kernel_size=(1, 1), stride=(1, 1)
     (weight_fake_quant): FakeQuantize(
       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0025, 0.0037, 0.0029]), zero_point=tensor([0, 0, 0])
       (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([-0.2484, -0.4718, -0.3689]), max_val=tensor([ 0.3239, -0.0056, 0.3312]))
     )
     (activation_post_process): None
   )

   To verify the model is correct, you can skip steps 8 and 9 right after prepare_qat_fx, export onnx first as in step 10 to inspect the model structure, and continue with step 8 once the structure checks out.
8. Training.
9. Call convert_fx. Optional; skip it if there is no need to evaluate the fixed-point model's accuracy.
10. Call export_to_onnx. torch.onnx.export can also be used, provided the notes in the export_to_onnx description are followed.
11. Convert the onnx model with hb_mapper.
4.2.3.4.4. Example¶
import numpy as np
import torch
from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.nn import qat
from horizon_plugin_pytorch.quantization import (
get_default_calib_qconfig,
get_default_qat_qconfig,
get_default_qat_out_qconfig,
prepare_calibration_fx,
prepare_qat_fx,
convert_fx,
)
from torch import nn
from torch.quantization import DeQuantStub, QuantStub
from horizon_plugin_pytorch.utils.onnx_helper import export_to_onnx
class _ConvBlock(nn.Module):
def __init__(self, channels=3):
super().__init__()
self.conv = nn.Conv2d(channels, channels, 1)
self.prelu = torch.nn.PReLU()
def forward(self, input):
x = self.conv(input)
x = self.prelu(x)
return torch.nn.functional.selu(x)
# Wrap functional selu as a module so it can be configured separately
class _SeluModule(nn.Module):
def forward(self, input):
return torch.nn.functional.selu(input)
class HybridModel(nn.Module):
def __init__(self, channels=3):
super().__init__()
# Insert QuantStub
self.quant = QuantStub()
self.conv0 = nn.Conv2d(channels, channels, 1)
self.prelu = torch.nn.PReLU()
self.conv1 = _ConvBlock(channels)
self.conv2 = nn.Conv2d(channels, channels, 1)
self.conv3 = nn.Conv2d(channels, channels, 1)
self.conv4 = nn.Conv2d(channels, channels, 1)
self.selu = _SeluModule()
# Insert DeQuantStub
self.dequant = DeQuantStub()
self.identity = torch.nn.Identity()
def forward(self, input):
x = self.quant(input)
x = self.conv0(x)
x = self.identity(x)
x = self.prelu(x)
x = torch.nn.functional.selu(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.identity(x)
x = self.conv4(x)
x = self.selu(x)
return self.dequant(x)
# Set march
set_march(March.BAYES)
data_shape = [1, 3, 224, 224]
data = torch.rand(size=data_shape)
model = HybridModel()
# Do not run float-model inference after prepare_qat_fx, which modifies the float model in place
float_res = model(data)
calibration_model = prepare_calibration_fx(
model,
{
        # Calibration fake quants only collect statistics; those unused in the QAT
        # stage are removed automatically, so high-precision-output ops need no
        # special handling here
"": get_default_calib_qconfig(),
},
hybrid=True,
hybrid_dict={
"module_name": ["conv1.conv", "conv3"],
"module_type": [_SeluModule],
},
)
# The calibration stage must keep the original model unchanged (run in eval mode)
calibration_model.eval()
for i in range(5):
calibration_model(torch.rand(size=data_shape))
qat_model = prepare_qat_fx(
calibration_model,
{
"": get_default_qat_qconfig(),
        # selu is a CPU op, so conv4 is effectively the BPU model's output;
        # give it a high-precision output qconfig
"module_name": [("conv4", get_default_qat_out_qconfig())]
},
hybrid=True,
hybrid_dict={
"module_name": ["conv1.conv", "conv3"],
"module_type": [_SeluModule],
},
)
# Do not run QAT-model inference after convert_fx, which modifies the QAT model in place
qat_res = qat_model(data)
# qat training start
# ......
# qat training end
# Export qat.onnx
export_to_onnx(
qat_model,
data,
"qat.onnx",
enable_onnx_checker=True,
operator_export_type=None,
)
# Evaluate the fixed-point model
quantize_model = convert_fx(qat_model)
quantize_res = quantize_model(data)
The printed calibration model:
HybridModel(
(quant): QuantStub(
(activation_post_process): CalibFakeQuantize(
(activation_post_process): CalibObserver(CalibObserver() calib_bin_edges=tensor([]) calib_hist=tensor([]))
)
)
(conv0): Conv2d(
3, 3, kernel_size=(1, 1), stride=(1, 1)
(weight_fake_quant): CalibFakeQuantize(
(activation_post_process): NoopObserver()
)
(activation_post_process): CalibFakeQuantize(
(activation_post_process): CalibObserver(CalibObserver() calib_bin_edges=tensor([]) calib_hist=tensor([]))
)
)
(prelu): PReLU(num_parameters=1)
(conv1): _ConvBlock(
(conv): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
(prelu): PReLU(num_parameters=1)
)
(conv2): Conv2d(
3, 3, kernel_size=(1, 1), stride=(1, 1)
(weight_fake_quant): CalibFakeQuantize(
(activation_post_process): NoopObserver()
)
(activation_post_process): CalibFakeQuantize(
(activation_post_process): CalibObserver(CalibObserver() calib_bin_edges=tensor([]) calib_hist=tensor([]))
)
)
(conv3): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
(conv4): Conv2d(
3, 3, kernel_size=(1, 1), stride=(1, 1)
(weight_fake_quant): CalibFakeQuantize(
(activation_post_process): NoopObserver()
)
(activation_post_process): CalibFakeQuantize(
(activation_post_process): CalibObserver(CalibObserver() calib_bin_edges=tensor([]) calib_hist=tensor([]))
)
)
(selu): _SeluModule()
(dequant): DeQuantStub()
(identity): Identity()
(prelu_input_dequant): DeQuantStub()
(selu_1_activation_post_process): CalibFakeQuantize(
(activation_post_process): CalibObserver(CalibObserver() calib_bin_edges=tensor([]) calib_hist=tensor([]))
)
(conv3_activation_post_process): CalibFakeQuantize(
(activation_post_process): CalibObserver(CalibObserver() calib_bin_edges=tensor([]) calib_hist=tensor([]))
)
(conv3_input_dequant): DeQuantStub()
(selu_2_input_dequant): DeQuantStub()
)
def forward(self, input):
input_1 = input
quant = self.quant(input_1); input_1 = None
conv0 = self.conv0(quant); quant = None
identity = self.identity(conv0); conv0 = None
prelu_input_dequant_0 = self.prelu_input_dequant(identity); identity = None
prelu = self.prelu(prelu_input_dequant_0); prelu_input_dequant_0 = None
selu = torch.nn.functional.selu(prelu, inplace = False); prelu = None
conv1_conv = self.conv1.conv(selu); selu = None
conv1_prelu = self.conv1.prelu(conv1_conv); conv1_conv = None
selu_1 = torch.nn.functional.selu(conv1_prelu, inplace = False); conv1_prelu = None
selu_1_activation_post_process = self.selu_1_activation_post_process(selu_1); selu_1 = None
conv2 = self.conv2(selu_1_activation_post_process); selu_1_activation_post_process = None
conv3_input_dequant_0 = self.conv3_input_dequant(conv2); conv2 = None
conv3 = self.conv3(conv3_input_dequant_0); conv3_input_dequant_0 = None
conv3_activation_post_process = self.conv3_activation_post_process(conv3); conv3 = None
identity_1 = self.identity(conv3_activation_post_process); conv3_activation_post_process = None
conv4 = self.conv4(identity_1); identity_1 = None
selu_2_input_dequant_0 = self.selu_2_input_dequant(conv4); conv4 = None
selu_2 = torch.nn.functional.selu(selu_2_input_dequant_0, inplace = False); selu_2_input_dequant_0 = None
dequant = self.dequant(selu_2); selu_2 = None
return dequant
The printed QAT model:
HybridModel(
(quant): QuantStub(
(activation_post_process): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0078]), zero_point=tensor([0])
(activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.9995]), max_val=tensor([0.9995]))
)
)
(conv0): Conv2d(
3, 3, kernel_size=(1, 1), stride=(1, 1)
(weight_fake_quant): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0038, 0.0041, 0.0016]), zero_point=tensor([0, 0, 0])
(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([-0.4881, -0.4944, 0.0787]), max_val=tensor([-0.1213, 0.5284, 0.1981]))
)
(activation_post_process): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0064]), zero_point=tensor([0])
(activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.8159]), max_val=tensor([0.8159]))
)
)
(prelu): PReLU(num_parameters=1)
(conv1): _ConvBlock(
(conv): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
(prelu): PReLU(num_parameters=1)
)
(conv2): Conv2d(
3, 3, kernel_size=(1, 1), stride=(1, 1)
(weight_fake_quant): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0040, 0.0044, 0.0040]), zero_point=tensor([0, 0, 0])
(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([-0.5044, -0.4553, -0.5157]), max_val=tensor([0.1172, 0.5595, 0.4104]))
)
(activation_post_process): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0059]), zero_point=tensor([0])
(activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.7511]), max_val=tensor([0.7511]))
)
)
(conv3): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
(conv4): Conv2d(
3, 3, kernel_size=(1, 1), stride=(1, 1)
(weight_fake_quant): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([0.0025, 0.0037, 0.0029]), zero_point=tensor([0, 0, 0])
(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([-0.2484, -0.4718, -0.3689]), max_val=tensor([ 0.3239, -0.0056, 0.3312]))
)
(activation_post_process): None
)
(selu): _SeluModule()
(dequant): DeQuantStub()
(identity): Identity()
(prelu_input_dequant): DeQuantStub()
(selu_1_activation_post_process): _WrappedCalibFakeQuantize(
(activation_post_process): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0042]), zero_point=tensor([0])
(activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.5301]), max_val=tensor([0.5301]))
)
)
(conv3_activation_post_process): _WrappedCalibFakeQuantize(
(activation_post_process): FakeQuantize(
fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.0072]), zero_point=tensor([0])
(activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.9156]), max_val=tensor([0.9156]))
)
)
(conv3_input_dequant): DeQuantStub()
(selu_2_input_dequant): DeQuantStub()
)
def forward(self, input):
input_1 = input
quant = self.quant(input_1); input_1 = None
conv0 = self.conv0(quant); quant = None
identity = self.identity(conv0); conv0 = None
prelu_input_dequant_0 = self.prelu_input_dequant(identity); identity = None
prelu = self.prelu(prelu_input_dequant_0); prelu_input_dequant_0 = None
selu = torch.nn.functional.selu(prelu, inplace = False); prelu = None
conv1_conv = self.conv1.conv(selu); selu = None
conv1_prelu = self.conv1.prelu(conv1_conv); conv1_conv = None
selu_1 = torch.nn.functional.selu(conv1_prelu, inplace = False); conv1_prelu = None
selu_1_activation_post_process = self.selu_1_activation_post_process(selu_1); selu_1 = None
conv2 = self.conv2(selu_1_activation_post_process); selu_1_activation_post_process = None
conv3_input_dequant_0 = self.conv3_input_dequant(conv2); conv2 = None
conv3 = self.conv3(conv3_input_dequant_0); conv3_input_dequant_0 = None
conv3_activation_post_process = self.conv3_activation_post_process(conv3); conv3 = None
identity_1 = self.identity(conv3_activation_post_process); conv3_activation_post_process = None
conv4 = self.conv4(identity_1); identity_1 = None
selu_2_input_dequant_0 = self.selu_2_input_dequant(conv4); conv4 = None
selu_2 = torch.nn.functional.selu(selu_2_input_dequant_0, inplace = False); selu_2_input_dequant_0 = None
dequant = self.dequant(selu_2); selu_2 = None
return dequant
The exported onnx is shown in the figure; the parts circled in red are CPU operators.