4.2.3.1. 浮点模型准备¶
浮点模型由对数据进行操作的层或是模块组成。torch.nn模块中提供了构建浮点模型所需的所有网络块。PyTorch 中的所有网络模块都继承自 torch.nn.Module。一个网络模型本身是一个由其他各个小的网络模块构成的网络模块。通过这种内嵌的结构用户可以较容易地构建和管理复杂的网络架构。用户可以直接使用 PyTorch 提供的网络模块来构建浮点模型。由于量化是以模块为基础的,因此有必要在量化之前对模型定义进行一些修改,具体有以下几个方面:
4.2.3.1.1. 算子替换¶
为了浮点模型能够向量化模型进行转换,将需要对输出进行量化的函数形式(functional)的操作转化为模块(Module)。(例如使用 torch.nn.ReLU 来替换torch.nn.functional.relu)。具体需要替换的算子可以参阅当前文档 API REFERENCE 中 “支持的算子” 部分。
4.2.3.1.2. 插入量化和反量化节点¶
为了后续的量化训练和定点预测,需要在整个模型的输入节点前插入量化节点,输出节点后插入反量化节点,具体到实现上,量化模型整体以 QuantStub 开始,以 DeQuantStub 结束。但是如果最后一层的输出为 class_idx 等非量化数据(在 qat 模型中使用 Tensor 类型表示而没有使用 QTensor 的),则不需要 DeQuantStub。下方表格列出了作为模型输出层时不需要 DeQuantStub 的算子。
作为模型输出层时不需要 DeQuantStub 的算子 |
---|
torch.Tensor.argmax / torch.argmax |
horizon_plugin_pytorch.functional.argmax |
horizon_plugin_pytorch.functional.filter |
torch.max (返回值是 index 的部分无需反量化) |
4.2.3.1.3. 设置量化参数¶
通过对模型的 qconfig 属性赋值来指定模型的哪些部分需要被量化。例如,使用 model.conv1.qconfig = None 设置 model.conv 层不被量化。再比如使用 model1.linear1.qconfig = custom_qconfig 设置 model.linear1 会使用 custom_qconfig 而不使用全局的 qconfig。
4.2.3.1.4. 自定义浮点模型的例子¶
import torch
import torch.nn.quantized as nnq
import horizon_plugin_pytorch as horizon
from torch.quantization import QConfig, DeQuantStub
from torch import nn
from horizon_plugin_pytorch.quantization import (
fuse_known_modules
QuantStub
)
class ExampleNet(nn.Module):
def __init__(self):
super(ExampleNet, self).__init__()
self.quant = QuantStub()
self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
self.bn = nn.BatchNorm2d(num_features=1)
# add 操作必须使用 FloatFunctional
self.add = nn.quantized.FloatFunctional()
self.act = nn.ReLU()
self.out_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)
self.dequant = DeQuantStub()
def forward(self, x):
# 量化模型整体一般以 QuantStub 开始,以 DeQuantStub 结束
x = self.quant(x)
x = self.conv(x)
x = self.bn(x)
# 如果要将 add fuse 到 conv 中,add 的第一个输入必须来自于要 fuse 到的 conv
# 注意此处 add 的调用方式,使用 FloatFunctional 时必须使用它的具体方法,而不能直接使用 forward
x = self.add.add(x, y)
x = self.act(x)
x = self.out_conv(x)
x = self.dequant(x)
return x
def fuse_model(self):
from horizon_plugin_pytorch import quantization
torch.quantization.fuse_modules(
self,
["conv", "bn", "add", "act"],
inplace=True,
fuser_func=quantization.fuse_known_modules,
)
def set_qconfig(self):
# 这里可以不调用子模块的 set_qconfig 方法,没有 qconfig 的子模块会自动使用父模块的 qconfig
self.qconfig = horizon.quantization.get_default_qat_qconfig()
# 若网络最后输出层为 conv,可以单独设置为 out_qconfig 得到更高精度的输出
self.out_conv.qconfig = (
horizon.quantization.get_default_qat_out_qconfig()
)