7.5.2. RGB888 数据部署¶
7.5.2.1. 场景¶
BPU 中图像金字塔的输出图像是 centered YUV444 的格式,其数据范围是 [-128, 127],但在训练阶段中,您的训练数据集有可能是 RGB 格式的,因此您需要对训练集的图片格式进行处理,避免出现训练的模型只能接受 RGB 的数据输入而无法正常上板推理的情况。通常,我们推荐您在训练时,在图像预处理阶段将 RGB 格式的图片转为 YUV 格式,与推理时 BPU 的数据流对齐。
由于编译器目前不支持颜色空间转换,用户可以手动插入颜色空间转换节点,从而绕过编译器的限制。
7.5.2.2. YUV 格式简介¶
YUV 一般用来描述模拟电视系统的颜色空间,在 BT.601 中 YUV 主要有两种制式:YUV studio swing(Y:16~235,UV:16~240)和 YUV full swing(YUV:0~255)。
BPU 支持的 YUV 格式是 full swing,因此在调用我们的工具中 YUV 的相关函数时,应确保指定了 full 作为 swing 格式。
7.5.2.3. 在训练时对 RGB 输入进行预处理¶
在训练时,您可以使用 horizon.functional.rgb2centered_yuv
或 horizon.functional.bgr2centered_yuv
将 RGB 图像转换为 BPU 所支持的 YUV 格式。以 rgb2centered_yuv
为例,该函数的定义如下:
def rgb2centered_yuv(input: Tensor, swing: str = "studio") -> Tensor:
"""Convert color space.
Convert images from RGB format to centered YUV444 BT.601
Args:
input: input image in RGB format, ranging 0~255
swing: "studio" for YUV studio swing (Y: -112~107,
U, V: -112~112)
"full" for YUV full swing (Y, U, V: -128~127).
default is "studio"
Returns:
output: centered YUV image
"""
函数输入为 RGB 图像,输出为 centered YUV 图像。其中,centered YUV 是指减去了 128 的偏置的 YUV 图像,这是 BPU 图像金字塔输出的标准图像格式。对于 full swing 而言,其范围应为 -128~127。您可以通过 swing
参数控制 full 和 studio 的取向。为了和 BPU 数据流格式对齐,请您将 swing
设为 “full”。
7.5.2.4. 在推理时对 YUV 输入进行实时转换¶
在任何情况下,我们都推荐您使用上述介绍的方案,即在训练时就将 RGB 图像转成 YUV 格式,这样可以避免在推理时引入额外的性能开销和精度损失。但如果您已经使用了 RGB 图像训练了模型,我们也提供了补救措施,通过在推理的时候在模型输入处插入颜色空间转换算子,将输入的 YUV 图像实时转换为 RGB 格式,从而支持 RGB 模型的上板部署,避免您重新训练模型给您带来时间成本和资源上的损失。由于该算子随模型运行在 BPU 上,底层采用定点运算实现,因而不可避免地会引入一定的精度损失,因此仅作为补救方案,请您尽可能按照我们所推荐的方式对数据进行处理。
7.5.2.4.1. 算子定义¶
您可以在推理模型的开头(QuantStub 的后面)插入 horizon.functional.centered_yuv2rgb
或 horizon.functional.centered_yuv2bgr
算子实现该功能。以 centered_yuv2rgb
为例,其定义为:
def centered_yuv2rgb(
input: QTensor,
swing: str = "studio",
mean: Union[List[float], Tensor] = (128.0,),
std: Union[List[float], Tensor] = (128.0,),
q_scale: Union[float, Tensor] = 1.0 / 128.0,
) -> QTensor:
swing
为 YUV 的格式,可选项为 “full” 和 “studio”。为了和 BPU 的 YUV 数据格式对齐,请您将 swing
设为 “full”。
mean
, std
均为您在训练时 RGB 图像所使用的归一化均值、标准差,支持 list 和 torch.Tensor 两种输入类型,支持单通道或三通道的归一化参数。如您的归一化均值为 [128, 0, -128] 时,您可以传入一个 [128., 0., -128.] 的 list 或 torch.tensor([128., 0., -128.])。
q_scale
为您在量化感知训练阶段所用的 QuantStub 的 scale 数值。支持 float 和 torch.Tensor 两种数据类型。
该算子完成了以下操作:
根据给定的
swing
所对应的转换公式将输入图像转换成 RGB 格式使用给定的
mean
和std
对 RGB 图像进行归一化使用给定的
q_scale
对 RGB 图像进行量化
由于该算子已经包括了对 RGB 图像的量化操作,因此在插入这个算子后您需要手动地将模型 QuantStub 的 scale 参数更改为 1。
插入该算子后的部署模型如下图所示:
注意
该算子为部署专用算子,请勿在训练阶段使用该算子。
7.5.2.4.2. 使用方法¶
在您使用 RGB 图像完成量化感知训练后,您需要:
获取量化感知训练时模型 QuantStub 所使用的 scale 值,以及 RGB 图像所使用的归一化参数;
调用
convert_fx
接口将 qat 模型转换为 quantized 模型;在模型的 QuantStub 后面插入
centered_yuv2rgb
算子,算子需要传入步骤 1 中所获取的参数;将 QuantStub 的
scale
参数修改成 1。
示例:
import torch
from horizon_plugin_pytorch.quantization import (
QuantStub,
prepare_qat_fx,
convert_fx,
)
from horizon_plugin_pytorch.functional import centered_yuv2rgb
from horizon_plugin_pytorch.quantization.qconfig import (
default_qat_8bit_fake_quant_qconfig,
)
from horizon_plugin_pytorch import March, set_march
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.quant = QuantStub()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.bn = torch.nn.BatchNorm2d(3)
self.relu = torch.nn.ReLU()
def forward(self, input):
x = self.quant(input)
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
def set_qconfig(self):
self.qconfig = default_qat_8bit_fake_quant_qconfig
data = torch.rand(1, 3, 28, 28)
net = Net()
set_march(March.XXX)
net.set_qconfig()
qat_net = prepare_qat_fx(net)
qat_net(data)
quantized_net = convert_fx(qat_net)
traced = quantized_net
print("Before centered_yuv2rgb")
traced.graph.print_tabular()
# Replace QuantStub nodes with centered_yuv2rgb
patterns = ["quant"]
for n in traced.graph.nodes:
if any(n.target == pattern for pattern in patterns):
with traced.graph.inserting_after(n):
new_node = traced.graph.call_function(centered_yuv2rgb, (n,), {"swing": "full"})
n.replace_all_uses_with(new_node)
new_node.args = (n,)
traced.quant.scale.fill_(1.0)
traced.recompile()
print("\nAfter centered_yuv2rgb")
traced.graph.print_tabular()
对比前后 Graph 可以看到修改后的图中插入了颜色空间转换节点:
Before centered_yuv2rgb
opcode name target args kwargs
----------- ------- -------- ---------- --------
placeholder input_1 input () {}
call_module quant quant (input_1,) {}
call_module conv conv (quant,) {}
output output output (conv,) {}
After centered_yuv2rgb
opcode name target args kwargs
------------- ---------------- --------------------------------------------- ------------------- -----------------
placeholder input_1 input () {}
call_module quant quant (input_1,) {}
call_function centered_yuv2rgb <function centered_yuv2rgb at 0x7fa1c2b48040> (quant,) {'swing': 'full'}
call_module conv conv (centered_yuv2rgb,) {}
output output output (conv,) {}