10.4.6.2. Calibration v2(Experimental Support)
Horizon Plugin Pytorch 于 1.2.1 版本后支持了新的 calibration 用法,与原有 calibration 相比,新的 calibration 支持更多的 calibration 方法,用法更灵活,推荐您优先尝试新版 calibration 用法。原有 calibration 用法依然兼容,但在之后的版本中会逐渐弃用。
使用流程
calibration 与 QAT 的整体流程如下图所示:
下面分别介绍各个步骤:
将浮点模型转化 QAT 模型。参考 plugin快速上手章节中的 设置BPU架构 、 算子融合 和 浮点模型转为量化模型 小节。使用
prepare_qat方法转化浮点模型前,需要为模型设置qconfig。model.qconfig = horizon.quantization.get_default_qconfig()
get_default_qconfig可以为weight和activation设置不同的fake_quant和observer。目前,支持的fake quant方法有 “fake_quant”、”lsq” 和 “pact”,支持的observer有 “min_max”、 “fixed_scale”、”clip”、”percentile” 和 “clip_std”。如无特殊需求,activation_fake_quant和weight_fake_quant推荐使用默认的 “fake_quant” 方法,weight_observer使用默认的 “min_max”。如果为 QAT 阶段设置 qconfig ,activation_observer推荐使用默认的 “min_max”,如果为 calibration 阶段设置 qconfig ,activation_observer推荐使用 “percentile”。 calibration 可选observer有 “min_max”、 “percentile” 和 “clip_std”, 特殊用法和调试技巧见 calibration 经验总结。def get_default_qconfig( activation_fake_quant: Optional[str] = "fake_quant", weight_fake_quant: Optional[str] = "fake_quant", activation_observer: Optional[str] = "min_max", weight_observer: Optional[str] = "min_max", activation_qkwargs: Optional[Dict] = None, weight_qkwargs: Optional[Dict] = None, ):
设置
fake quantize状态为CALIBRATION。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.CALIBRATION)
fake quantize一共有三种状态,分别需要在QAT、calibration、validation前将模型的fake quantize设置为对应的状态。在 calibration 状态下,仅观测各算子输入输出的统计量。在 QAT 状态下,除观测统计量外还会进行伪量化操作。而在 validation 状态下,不会观测统计量,仅进行伪量化操作。class FakeQuantState(Enum): QAT = "qat" CALIBRATION = "calibration" VALIDATION = "validation"
calibration 。把准备好的校准数据喂给模型,模型在 forward 过程中由 observer 观测相关统计量。
设置
fake quantize状态为VALIDATION。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.VALIDATION)
验证
calibration效果。如果效果满意,则进入步骤 7 ,不满意则调整calibration qconfig中的参数继续 calibration 。从浮点模型开始重新按照步骤 2 的流程构建 QAT 模型,需要注意
qconfig设置与 calibration 阶段的区别。加载 calibration 得到的参数。
horizon.quantization.load_observer_params(calibration_model, qat_model)
设置
fake quantize状态为QAT。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.QAT)
QAT 训练。
设置
fake quantize状态为VALIDATION,并验证 QAT 模型精度。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.VALIDATION)
使用限制
不支持 train() 模式和 eval() 模式行为不一致的Module。