6.1.3.9. 数据校准¶
在量化训练(QAT)中,一个重要的步骤是确定量化参数 scale,一个合理的 scale 能够显著提升模型训练结果和加快模型的收敛速度。 Calibration 是通过用浮点模型在训练集上跑少数 batch 的数据(只跑 forward 过程,没有backward),统计这些数据的分布直方图,通过一定方法去计算出 min_value 和 max_value,然后可以用这些 min_value 和 max_value 去获取 scale。当 QAT 的训练精度上不去的时候,在 QAT 的开始之前使用 calibration 做量化参数的微调,获取 scale,可以为qat提供更好的量化初始化参数,提升收敛速度和精度。
6.1.3.9.1. 1. 如何定义 Calibration 模型¶
默认不需要对现有模型做任何修改
类似于定义量化模型时需要设置 QAT QConfig,Calibration 时也需要对模型设置 Calibration QConfig。不过,Calibration QConfig 的设置相对来说比较简单,HAT 框架已经实现对模型 Calibration QConfig 的默认设置,用户无需对模型做任何修改,即可使用 Calibration。
自定义模型子模块 Calibration QConfig
在上文的默认情况下,会为模型的所有 Module(继承自 nn.Module) 设置 Calibration QConfig。因此,Calibration 时也就会对所有 Module 的特征分布进行统计。如果有特殊需求,可以在模型内自定义实现
set_calibration_qconfig
方法:
class Classifier(nn.Module):
def __init__(self,):
...
def forward(self, x):
...
# 自定义要做 Calibration 的模块
def set_calibration_qconfig(self, ):
# 比如可以设置 Loss 的 qconfig 为 None,就会不再对 Loss 做 Calibration,
# 可以一定程度减少统计量,提升 Calibration 速度,降低显存占用
if self.loss is not None:
self.loss.qconfig = None
6.1.3.9.2. 2. 浮点模型做 Calibration¶
HAT 中集成了 Calibration 功能,由浮点模型做 Calibration 命令和正常训练相似,只需执行以下命令即可:
python3 tools/train.py --stage calibration ...
需要注意的是 config
文件中 calibration_trainer
中的一些配置:
# Note: The transforms of the dataset during calibration can be
# consistent with that during training or validation, or customized.
# Default used `val_batch_processor`.
calibration_data_loader = copy.deepcopy(data_loader)
calibration_data_loader.pop('sampler') # Calibration do not support DDP or DP
calibration_batch_processor = copy.deepcopy(val_batch_processor)
calibration_trainer = dict(
type="Calibrator",
model=model,
# 1. 设置 data_loader 和 batch_processor
data_loader=calibration_data_loader,
batch_processor=calibration_batch_processor,
# 2. 设置 calibration 迭代的 batch 数目
num_stages=30,
...
)
1. 数据集的设置:
做 Calibration 的数据集(dataset)不能是测试集(可以是训练集或其他数据),但是做 Calibration 时用于数据增强的 transforms 可以和正常训练时的 transforms 保持一致,但是也可以设置成和 validation 的 transforms 一致,也可以自定义 transforms。(哪种实验效果最好,暂时没有定论,都可以尝试。)
2. Calibration 迭代的图片数目(可供参考):
classification: 图片张数一般可以 500~1500 张就可以取得不错的效果。
segmentation && detection: 图片张数可以 100~300 张左右。
注:这些图片张数具体数目也不是固定的,上面的建议只是从已有的实验中总结的,可根据实际情况调整。
6.1.3.9.3. 3. 使用 Calibration 模型做 QAT 训练¶
qat_trainer = dict(
type="distributed_data_parallel_trainer",
model=model,
model_convert_pipeline=dict(
type="ModelConvertPipeline",
qat_mode="fuse_bn",
# (可选) 设置 QAT 训练时 scale 更新系数
qconfig_params=dict(
activation_qkwargs=dict(
averaging_constant=0,
),
weight_qkwargs=dict(
averaging_constant=1,
),
),
converters=[
dict(type="Float2QAT"),
dict(
type="LoadCheckpoint",
checkpoint_path=os.path.join(
ckpt_dir, "calibration-checkpoint-best.pth.tar"
),
),
],
),
)
QAT 时 averaging_constant 参数设置:
量化时 scale 参数的更新规则是
scale = (1 - averaging_constant) * scale + averaging_constant * current_scale
。在已有的一些实验中(主要是图像分类任务实验)发现,做完 calibration 后,把 activation 的 scale 固定住,不进行更新,即设置 activation 的
averaging_constant=0
, 并设置 weight 的averaging_constant=1
,效果可能会相对略好一些。注:但这种设置不是适用于所有任务,在 lidar 任务中,固定 scale,精度也可能会变差。可根据实际情况调整。
接下来只需要执行正常的 QAT 训练命令,即可启动 QAT 训练:
python3 tools/train.py --stage qat ...