10.4.6.1. Calibration (Experimental Support)

在 plugin 的量化训练中,一个重要的步骤是确定量化参数 scale ,一个合理的 scale 能够显著提升模型训练结果和加快模型的收敛速度。 一种常见的 scale 的计算方法为:

def compute_scale(data, quant_min, quant_max):
    fmax = data.abs().max()
    scale = fmax * 2 / (quant_max - quant_min)
    return scale

当计算 feature map 的 scale 时,由于每次 forward 只能计算出当前 batch 的 fmax,对于整个数据集来说,每次 forward 计算出来的 feature map 可能不准确。因此,引入了 calibration 方法。

Calibration 方法

Calibration 方法是在量化训练之前,使用浮点模型统计计算 scale 的方法。步骤为:

  1. 浮点模型 forward,collect 浮点模型的统计数据。

  2. 使用步骤 1 的统计数据,通过 Calibration 得到 feature map 的量化参数。

  3. 使用步骤 2 得到的量化参数,初始化量化训练模型的量化参数。

  4. 在步骤 3 的基础上进行量化训练。

Plugin Calibration 使用方法

plugin 提供了默认的 Calibration 配置,用户可以通过设置 float.qconfig = get_default_calib_qconfig() 来使用 calibration 功能。

horizon.quantization.get_default_calib_qconfig()

plugin Calibration 的限制

  1. 只支持对 feature map 做 Calibration 。

  2. 不支持 train() 模式和 eval() 模式行为不一致的 Module 。