10.3.2. set_qconfig 书写规范和自定义 qconfig 介绍¶
10.3.2.1. set_qconfig 方法书写规范¶
在对要量化的模型进行定义时,需要实现模型 set_qconfig
方法对量化方式进行配置。
当前设置QConfig接口由 hat.utils.qconfig_manager
提供,set_qconfig
中调用 hat.utils.qconfig_manager
实现对模块Qconfig的设置,例如:
# 注: 这份代码示例只是展示 set_qconfig 方法的实现规则,不是完整的量化模型代码
class Head(nn.Module):
def __init__(self):
super(Head, self).__init__()
self.out_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
def forward(self):
...
def set_qconfig(self):
# 若网络最后输出层为 conv,可以单独设置为 out_qconfig 得到更高精度的输出
from hat.utils import qconfig_manager
self.out_conv.qconfig = qconfig_manager.get_default_qat_out_qconfig()
class Backbone(nn.Module):
def __init__(self):
super(Backbone, self).__init__()
self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
def forward(self):
...
# 当前 Backbone 中没有特殊 layer,也没有需要设置 QConfig=None 的 layer 时,
# 即都需要设置 default_qat_qconfig 时,可以不写 set_qconfig() 方法
# def set_qconfig(self):
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.backbone = Backbone()
self.head = Head()
self.loss = nn.CrossEntropyLoss()
def forward(self):
...
# 需要父模块实现 set_qconfig 方法
def set_qconfig(self):
from hat.utils import qconfig_manager
# 1. 首先指定父模块的 qconfig,
# 如果未对子模块设置 qconfig,子模块会自动使用父模块的 qconfig
self.qconfig = qconfig_manager.get_default_qat_qconfig()
# 2. 如果有某子模块有特殊 layer,实现了 set_qconfig 方法,调用
if self.head is not None:
if hasattr(self.head, "set_qconfig"):
self.head.set_qconfig()
# 3. 如果有子模块不需要设置 Qconfig,需要设置 Qconfig 为 None
if self.loss is not None:
self.loss.qconfig = None
10.3.2.2. 自定义 QAT QConfig 参数¶
HAT支持QAT训练时使用自定义QConfig,只需在config文件的 qat_solver
中配置 qconfig_params
参数即可:
qat_solver = dict(
trainer=qat_trainer,
quantize=True,
...
qconfig_params=dict(
dtype="qint8",
activation_fake_quant="fake_quant",
weight_fake_quant="fake_quant",
activation_qkwargs=dict(
averaging_constant=0,
),
weight_qkwargs=dict(
averaging_constant=1,
),
),
...
)
qconfig_params
主要有五个参数配置项: dtype、 activation_fake_quant、 weight_fake_quant、 activation_qkwargs、 weight_qkwargs。
dtype:量化比特类型,支持
"qint8"
, 缺省时使用默认值"qint8"
。activation_fake_quant:指定activation的量化器,支持
"fake_quant"
、"lsq"
、"pact"
, 缺省时使用默认值"fake_quant"
。weight_fake_quant:指定weight的量化器。支持及使用方式同
activation_fake_quant
。activation_qkwargs:指定activation量化器的参数。
activation_fake_quant
是"fake_quant"
时,activation_qkwargs
可设置参数:activation_qkwargs=dict( observer=MovingAverageMinMaxObserver, # 指定 observer,默认即可,一般不用设置 averaging_constant=0.01, # 设置 scale 的更新系数 )
activation_fake_quant
是"lsq"
时,activation_qkwargs
可设置参数:activation_qkwargs=dict( observer=MovingAverageMinMaxObserver, # 指定 observer,默认即可,一般不用设置 scale=1.0, # 指定初始 scale,默认即可,一般不用设置 zero_point=0.0, # 指定初始 zero_point,默认即可,一般不用设置 use_grad_scaling=False, # 定义scale和 zero_point 的梯度是否由常数归一化,默认 False,默认即可,一般不用设置 )
activation_fake_quant
是"pact"
时,activation_qkwargs
可设置参数:
activation_qkwargs=dict( observer=MovingAverageMinMaxObserver, # 指定 observer,默认即可,一般不用设置 alpha=6.0, # 指定 activation 的 clip 参数,默认 6.0,一般不用设置 )
weight_qkwargs:指 weight量化器的参数。除了
weight_qkwargs
的默认observer
是MovingAveragePerChannelMinMaxObserver
外, 其他参数配置及用法,同activation_qkwargs
。
注意
activation_qkwargs
和 weight_qkwargs
一般是不需要进行设置的,缺省使用默认配置即可。但当使用 calibration 后进行 QAT 训练时,可能需要修改 averaging_constant
。