10.4.6.2. Calibration v2(Experimental Support)
Horizon Plugin Pytorch 于 1.2.1 版本后支持了新的 calibration
用法,与原有 calibration 相比,新的 calibration 支持更多的 calibration 方法,用法更灵活,推荐您优先尝试新版 calibration 用法。原有 calibration 用法依然兼容,但在之后的版本中会逐渐弃用。
使用流程
calibration 与 QAT 的整体流程如下图所示:
下面分别介绍各个步骤:
将浮点模型转化 QAT 模型。参考 plugin快速上手章节中的 设置BPU架构 、 算子融合 和 浮点模型转为量化模型 小节。使用
prepare_qat
方法转化浮点模型前,需要为模型设置qconfig
。model.qconfig = horizon.quantization.get_default_qconfig()
get_default_qconfig
可以为weight
和activation
设置不同的fake_quant
和observer
。目前,支持的fake quant
方法有 “fake_quant”、”lsq” 和 “pact”,支持的observer
有 “min_max”、 “fixed_scale”、”clip”、”percentile” 和 “clip_std”。如无特殊需求,activation_fake_quant
和weight_fake_quant
推荐使用默认的 “fake_quant” 方法,weight_observer
使用默认的 “min_max”。如果为 QAT 阶段设置 qconfig ,activation_observer
推荐使用默认的 “min_max”,如果为 calibration 阶段设置 qconfig ,activation_observer
推荐使用 “percentile”。 calibration 可选observer
有 “min_max”、 “percentile” 和 “clip_std”, 特殊用法和调试技巧见 calibration 经验总结。def get_default_qconfig( activation_fake_quant: Optional[str] = "fake_quant", weight_fake_quant: Optional[str] = "fake_quant", activation_observer: Optional[str] = "min_max", weight_observer: Optional[str] = "min_max", activation_qkwargs: Optional[Dict] = None, weight_qkwargs: Optional[Dict] = None, ):
设置
fake quantize
状态为CALIBRATION
。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.CALIBRATION)
fake quantize
一共有三种状态,分别需要在QAT
、calibration
、validation
前将模型的fake quantize
设置为对应的状态。在 calibration 状态下,仅观测各算子输入输出的统计量。在 QAT 状态下,除观测统计量外还会进行伪量化操作。而在 validation 状态下,不会观测统计量,仅进行伪量化操作。class FakeQuantState(Enum): QAT = "qat" CALIBRATION = "calibration" VALIDATION = "validation"
calibration 。把准备好的校准数据喂给模型,模型在 forward 过程中由 observer 观测相关统计量。
设置
fake quantize
状态为VALIDATION
。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.VALIDATION)
验证
calibration
效果。如果效果满意,则进入步骤 7 ,不满意则调整calibration qconfig
中的参数继续 calibration 。从浮点模型开始重新按照步骤 2 的流程构建 QAT 模型,需要注意
qconfig
设置与 calibration 阶段的区别。加载 calibration 得到的参数。
horizon.quantization.load_observer_params(calibration_model, qat_model)
设置
fake quantize
状态为QAT
。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.QAT)
QAT 训练。
设置
fake quantize
状态为VALIDATION
,并验证 QAT 模型精度。horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.VALIDATION)
使用限制
不支持 train()
模式和 eval()
模式行为不一致的Module。