4.2.3.1. 浮点模型准备

浮点模型由对数据进行操作的层或是模块组成。torch.nn模块中提供了构建浮点模型所需的所有网络块。PyTorch 中的所有网络模块都继承自 torch.nn.Module。一个网络模型本身是一个由其他各个小的网络模块构成的网络模块。通过这种内嵌的结构用户可以较容易地构建和管理复杂的网络架构。用户可以直接使用 PyTorch 提供的网络模块来构建浮点模型。由于量化是以模块为基础的,因此有必要在量化之前对模型定义进行一些修改,具体有以下几个方面:

4.2.3.1.1. 算子替换

为了浮点模型能够向量化模型进行转换,将需要对输出进行量化的函数形式(functional)的操作转化为模块(Module)。(例如使用 torch.nn.ReLU 来替换torch.nn.functional.relu)。具体需要替换的算子可以参阅当前文档 API REFERENCE 中 “支持的算子” 部分。

4.2.3.1.2. 插入量化和反量化节点

为了后续的量化训练和定点预测,需要在整个模型的输入节点前插入量化节点,输出节点后插入反量化节点,具体到实现上,量化模型整体以 QuantStub 开始,以 DeQuantStub 结束。但是如果最后一层的输出为 class_idx 等非量化数据(在 qat 模型中使用 Tensor 类型表示而没有使用 QTensor 的),则不需要 DeQuantStub。下方表格列出了作为模型输出层时不需要 DeQuantStub 的算子。

作为模型输出层时不需要 DeQuantStub 的算子
torch.Tensor.argmax / torch.argmax
horizon_plugin_pytorch.functional.argmax
horizon_plugin_pytorch.functional.filter
torch.max (返回值是 index 的部分无需反量化)

4.2.3.1.3. 设置量化参数

通过对模型的 qconfig 属性赋值来指定模型的哪些部分需要被量化。例如,使用 model.conv1.qconfig = None 设置 model.conv 层不被量化。再比如使用 model1.linear1.qconfig = custom_qconfig 设置 model.linear1 会使用 custom_qconfig 而不使用全局的 qconfig。

4.2.3.1.4. 自定义浮点模型的例子

import torch
import torch.nn.quantized as nnq
import horizon_plugin_pytorch as horizon
from torch.quantization import QConfig, DeQuantStub
from torch import nn
from horizon_plugin_pytorch.quantization import (
    fuse_known_modules
    QuantStub 
)


class ExampleNet(nn.Module):
    def __init__(self):
        super(ExampleNet, self).__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
        self.bn = nn.BatchNorm2d(num_features=1)
        # add 操作必须使用 FloatFunctional
        self.add = nn.quantized.FloatFunctional()
        self.act = nn.ReLU()
        self.out_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
        self.dequant = DeQuantStub()

    def forward(self, x):
        # 量化模型整体一般以 QuantStub 开始,以 DeQuantStub 结束
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        # 如果要将 add fuse 到 conv 中,add 的第一个输入必须来自于要 fuse 到的 conv
        # 注意此处 add 的调用方式,使用 FloatFunctional 时必须使用它的具体方法,而不能直接使用 forward
        x = self.add.add(x, y)
        x = self.act(x)
        x = self.out_conv(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        from horizon_plugin_pytorch import quantization

        torch.quantization.fuse_modules(
            self,
            ["conv", "bn", "add", "act"],
            inplace=True,
            fuser_func=quantization.fuse_known_modules,
        )

    def set_qconfig(self):
        # 这里可以不调用子模块的 set_qconfig 方法,没有 qconfig 的子模块会自动使用父模块的 qconfig
        self.qconfig = horizon.quantization.get_default_qat_qconfig()
        # 若网络最后输出层为 conv,可以单独设置为 out_qconfig 得到更高精度的输出
        self.out_conv.qconfig = (
            horizon.quantization.get_default_qat_out_qconfig()
        )