6.1.3.9. 数据校准

在量化训练(QAT)中,一个重要的步骤是确定量化参数 scale,一个合理的 scale 能够显著提升模型训练结果和加快模型的收敛速度。 Calibration 是通过用浮点模型在训练集上跑少数 batch 的数据(只跑 forward 过程,没有backward),统计这些数据的分布直方图,通过一定方法去计算出 min_value 和 max_value,然后可以用这些 min_value 和 max_value 去获取 scale。当 QAT 的训练精度上不去的时候,在 QAT 的开始之前使用 calibration 做量化参数的微调,获取 scale,可以为qat提供更好的量化初始化参数,提升收敛速度和精度。

6.1.3.9.1. 1. 如何定义 Calibration 模型

  • 默认不需要对现有模型做任何修改

    类似于定义量化模型时需要设置 QAT QConfig,Calibration 时也需要对模型设置 Calibration QConfig。不过,Calibration QConfig 的设置相对来说比较简单,HAT 框架已经实现对模型 Calibration QConfig 的默认设置,用户无需对模型做任何修改,即可使用 Calibration。

  • 自定义模型子模块 Calibration QConfig

    在上文的默认情况下,会为模型的所有 Module(继承自 nn.Module) 设置 Calibration QConfig。因此,Calibration 时也就会对所有 Module 的特征分布进行统计。如果有特殊需求,可以在模型内自定义实现 set_calibration_qconfig 方法:


class Classifier(nn.Module):
    def __init__(self,):
        ...
    
    def forward(self, x):
        ...
        
    # 自定义要做 Calibration 的模块
    def set_calibration_qconfig(self, ):
        
        # 比如可以设置 Loss 的 qconfig 为 None,就会不再对 Loss 做 Calibration,
        # 可以一定程度减少统计量,提升 Calibration 速度,降低显存占用
        if self.loss is not None:
            self.loss.qconfig = None

6.1.3.9.2. 2. 浮点模型做 Calibration

HAT 中集成了 Calibration 功能,由浮点模型做 Calibration 命令和正常训练相似,只需执行以下命令即可:

python3 tools/train.py --stage calibration ...

需要注意的是 config 文件中 calibration_trainer 中的一些配置:


# Note: The transforms of the dataset during calibration can be
# consistent with that during training or validation, or customized.
# Default used `val_batch_processor`.
calibration_data_loader = copy.deepcopy(data_loader)
calibration_data_loader.pop('sampler')  # Calibration do not support DDP or DP
calibration_batch_processor = copy.deepcopy(val_batch_processor)

calibration_trainer = dict(
    type="Calibrator",
    model=model,
    # 1. 设置 data_loader 和 batch_processor
    data_loader=calibration_data_loader,
    batch_processor=calibration_batch_processor,
    # 2. 设置 calibration 迭代的 batch 数目
    num_stages=30,
    ...   
)

1. 数据集的设置:

做 Calibration 的数据集(dataset)不能是测试集(可以是训练集或其他数据),但是做 Calibration 时用于数据增强的 transforms 可以和正常训练时的 transforms 保持一致,但是也可以设置成和 validation 的 transforms 一致,也可以自定义 transforms。(哪种实验效果最好,暂时没有定论,都可以尝试。)

2. Calibration 迭代的图片数目(可供参考):

  • classification: 图片张数一般可以 500~1500 张就可以取得不错的效果。

  • segmentation && detection: 图片张数可以 100~300 张左右。

    注:这些图片张数具体数目也不是固定的,上面的建议只是从已有的实验中总结的,可根据实际情况调整。

6.1.3.9.3. 3. 使用 Calibration 模型做 QAT 训练

qat_trainer = dict(
    type="distributed_data_parallel_trainer",
    model=model,
    model_convert_pipeline=dict(
        type="ModelConvertPipeline",
        qat_mode="fuse_bn",
        # (可选) 设置 QAT 训练时 scale 更新系数
        qconfig_params=dict(
            activation_qkwargs=dict(
                averaging_constant=0,
            ),
            weight_qkwargs=dict(
                averaging_constant=1,
            ),
        ),
        converters=[
            dict(type="Float2QAT"),
            dict(
                type="LoadCheckpoint",
                checkpoint_path=os.path.join(
                    ckpt_dir, "calibration-checkpoint-best.pth.tar"
                ),
            ),
        ],
    ),
)
  • QAT 时 averaging_constant 参数设置:

    量化时 scale 参数的更新规则是 scale = (1 - averaging_constant) * scale + averaging_constant * current_scale

    在已有的一些实验中(主要是图像分类任务实验)发现,做完 calibration 后,把 activation 的 scale 固定住,不进行更新,即设置 activation 的averaging_constant=0, 并设置 weight 的averaging_constant=1,效果可能会相对略好一些。

    注:但这种设置不是适用于所有任务,在 lidar 任务中,固定 scale,精度也可能会变差。可根据实际情况调整。

接下来只需要执行正常的 QAT 训练命令,即可启动 QAT 训练:

python3 tools/train.py --stage qat ...