10.4.6.1. Calibration (Experimental Support)
在 plugin 的量化训练中,一个重要的步骤是确定量化参数 scale
,一个合理的 scale
能够显著提升模型训练结果和加快模型的收敛速度。
一种常见的 scale
的计算方法为:
def compute_scale(data, quant_min, quant_max):
fmax = data.abs().max()
scale = fmax * 2 / (quant_max - quant_min)
return scale
当计算 feature map 的 scale
时,由于每次 forward 只能计算出当前 batch 的 fmax
,对于整个数据集来说,每次 forward 计算出来的 feature map 可能不准确。因此,引入了 calibration
方法。
Calibration 方法
Calibration 方法是在量化训练之前,使用浮点模型统计计算 scale
的方法。步骤为:
浮点模型 forward,collect 浮点模型的统计数据。
使用步骤 1 的统计数据,通过
Calibration
得到 feature map 的量化参数。使用步骤 2 得到的量化参数,初始化量化训练模型的量化参数。
在步骤 3 的基础上进行量化训练。
Plugin Calibration 使用方法
plugin 提供了默认的 Calibration 配置,用户可以通过设置 float.qconfig = get_default_calib_qconfig()
来使用 calibration 功能。
horizon.quantization.get_default_calib_qconfig()
plugin Calibration 的限制
只支持对 feature map 做 Calibration 。
不支持
train()
模式和eval()
模式行为不一致的 Module 。