10.4.17. StereoNet双目深度估计模型训练¶
这篇教程主要是告诉大家如何利用HAT在数据集 SceneFlow
上从头开始训练一个 StereoNet
模型,包括浮点、量化和定点模型。
10.4.17.1. 数据集准备¶
在开始训练模型之前,第一步是需要准备好数据集,可以在 SceneFlow 数据集 下载,
同时需要准备训练数据和验证数据集对应的文件列表,可以从 此处 下载 SceneFlow_finalpass_train.txt
和 SceneFlow_finalpass_test.txt
。
下载后,解压并按照如下方式组织文件夹结构:
data
|-- SceneFlow
|-- Driving
|-- disparity
|-- frames_finalpass
|-- FlyingThings3D
|-- disparity
|-- frames_finalpass
|-- Monkaa
|-- disparity
|-- frames_finalpass
|-- SceneFlow_finalpass_test.txt
|-- SceneFlow_finalpass_train.txt
为了提升训练的速度,我们对数据信息文件做了一个打包,将其转换成lmdb格式的数据集。只需要运行下面的脚本,就可以成功实现转换:
python3 tools/datasets/sceneflow_packer.py --src-data-dir ${data-dir} --split-name train --pack-type lmdb --num-workers 10 --target-data-dir ${target-data-dir}
python3 tools/datasets/sceneflow_packer.py --src-data-dir ${data-dir} --split-name test --pack-type lmdb --num-workers 10 --target-data-dir ${target-data-dir}
上面这两条命令分别对应转换训练数据集和验证数据集,打包完成之后, ${target-data-dir}
目录下的文件结构应该如下所示:
${target-data-dir}
|-- train_lmdb
|-- test_lmdb
train_lmdb
和 test_lmdb
就是打包之后的训练数据集和验证数据集,接下来就可以开始训练模型。
10.4.17.1.1. 模型训练¶
在网络开始训练之前,你可以使用以下命令先计算一下网络的计算量和参数数量:
python3 tools/calops.py --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
下一步就可以开始训练。训练也可以通过下面的脚本来完成,在训练之前需要确认配置中数据集路径是否已经切换到已经打包好的数据集路径。
python3 tools/train.py --stage "float" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
python3 tools/train.py --stage "calibration" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
python3 tools/train.py --stage "qat" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
python3 tools/train.py --stage "int_infer" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
由于HAT算法包使用了注册机制,使得每一个训练任务都可以按照这种 train.py
加上 config
配置文件的形式启动。
train.py
是统一的训练脚本,与任务无关,我们需要训练什么样的任务、使用什么样的数据集以及训练相关的超参数设置都在指定的 config
配置文件里面。
上面的命令中 --stage
后面的参数可以是 "float"
、 "calibration"
、 "qat"
、 "int_infer"
,
分别可以完成浮点模型、量化模型的训练以及量化模型到定点模型的转化,
其中量化模型的训练依赖于上一步浮点训练产出的浮点模型,定点模型的转化依赖于量化训练产生的量化模型。
10.4.17.1.2. 模型验证¶
在完成训练之后,可以得到训练完成的浮点、量化或定点模型。和训练方法类似,
我们可以用相同方法来对训好的模型做指标验证,得到为 Float
、 Calibration
、 QAT
和 Quantized
的指标,分别为浮点、量化和完全定点的指标。
python3 tools/predict.py --stage "float" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
python3 tools/predict.py --stage "calibration" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
python3 tools/predict.py --stage "qat" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
python3 tools/predict.py --stage "int_infer" --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
和训练模型时类似, --stage
后面的参数为 "float"
、 "calibration"
、 "qat"
、 "int_infer"
时,分别可以完成对训练好的浮点模型、量化模型、定点模型的验证。
10.4.17.1.3. 模型推理¶
HAT
提供了 infer.py
脚本提供了对定点模型的推理结果进行可视化展示:
python3 tools/infer.py --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py --model-inputs imgl:${img1-path},imgr:${img2-path},baseline:${baseline},f:${f} --save-path ${save_path}
10.4.17.1.4. 仿真上板精度验证¶
除了上述模型验证之外,我们还提供和上板完全一致的精度验证方法,可以通过下面的方式完成:
python3 tools/align_bpu_validation.py --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
10.4.17.1.5. 定点模型检查和编译¶
在HAT中集成的量化训练工具链主要是为了地平线的计算平台准备的,因此,对于量化模型的检查和编译是必须的。
我们在HAT中提供了模型检查的接口,可以让用户定义好量化模型之后,先检查能否在 BPU
上正常运行:
python3 tools/model_checker.py --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
在模型训练完成后,可以通过 compile_perf
脚本将量化模型编译成可以上板运行的 hbm
文件,同时该工具也能预估在 BPU
上的运行性能:
python3 tools/compile_perf.py --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
以上就是从数据准备到生成量化可部署模型的全过程。
10.4.17.1.6. ONNX模型导出¶
如果想要导出onnx模型, 运行下面的命令即可:
python3 tools/export_onnx.py --config configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
10.4.17.2. 训练细节¶
在这个说明中,我们对模型训练需要注意的一些事项进行说明,主要为 config
的一些相关设置。
10.4.17.2.1. 模型构建¶
StereoNet
的网络结构可以参考 论文 ,这里不做详细介绍。
我们通过在 config
配置文件中定义 model
这样的一个 dict
型变量,就可以方便的实现对模型的定义和修改。
from torch import nn
loss_weights = [0.3, 0.3, 0.5, 0.5, 1.0]
maxdisp = 192
use_bn = True
bias = False
bn_kwargs = {}
refine_levels = 4
out_channels = [32, 32, 64, 128, 128, 16]
model = dict(
type="StereoNet",
backbone=dict(
type="StereoNetNeck",
out_channels=out_channels,
use_bn=use_bn,
bias=bias,
bn_kwargs=bn_kwargs,
act_type=nn.ReLU(),
),
head=dict(
type="StereoNetHead",
maxdisp=maxdisp,
bn_kwargs=bn_kwargs,
refine_levels=refine_levels,
),
post_process=dict(
type="StereoNetPostProcess",
maxdisp=maxdisp,
),
loss=dict(type="SmoothL1Loss"),
loss_weights=loss_weights,
)
模型除了 backbone
之外,还有 head
、 post_process
、 losses
模块,
在 StereoNet
中, backbone
主要是提取图像的特征, head
主要是由特征来得到预测的视差值。 post_process
主要是后处理部分, losses
模块采用论文中的 SmoothL1Loss
作为训练的 loss
, loss_weights
是对应的 loss
的权重。
10.4.17.2.2. 数据增强¶
跟 model
的定义一样,数据增强的流程是通过在 config
配置文件中定义 data_loader
和 val_data_loader
这两个 dict
来实现的,
分别对应着训练集和验证集的处理流程。以 data_loader
为例,数据增强使用了 RandomCrop
、 ToTensor
和 Normalize
来增加训练数据的多样性,增强模型的泛化能力。
因为最终跑在 BPU
上的模型使用的是 YUV444
的图像输入,而一般的训练图像输入都采用 RGB
的形式,
所以HAT提供 BgrToYuv444
的数据增强来将 RGB
转到 YUV444
的格式。
data_loader = dict(
type=torch.utils.data.DataLoader,
dataset=dict(
type="SceneFlow",
data_path="./tmp_data/SceneFlow/train_lmdb",
transforms=[
dict(
type="RandomCrop",
size=(256, 512),
),
dict(
type="ToTensor",
to_yuv=False,
use_yuv_v2=False,
),
dict(type="BgrToYuv444", rgb_input=True),
dict(
type="TorchVisionAdapter",
interface="Normalize",
mean=128.0,
std=128.0,
),
],
),
sampler=dict(type=torch.utils.data.DistributedSampler),
batch_size=train_batch_size_per_gpu,
pin_memory=True,
shuffle=False,
num_workers=data_num_workers,
collate_fn=collate_2d,
)
batch_processor
中传入一个 loss_collector
函数,用于获取当前批量数据的 loss
,如下所示:
def loss_collector(outputs: dict):
return outputs["losses"]
batch_processor = dict(
type="MultiBatchProcessor",
need_grad_update=True,
loss_collector=loss_collector,
)
验证集的数据转换相对简单很多,如下所示:
val_data_loader = dict(
type=torch.utils.data.DataLoader,
dataset=dict(
type="SceneFlow",
data_path="./tmp_data/SceneFlow/test_lmdb",
transforms=[
dict(
type="ToTensor",
to_yuv=False,
use_yuv_v2=False,
),
dict(type="BgrToYuv444", rgb_input=True),
dict(
type="TorchVisionAdapter",
interface="Normalize",
mean=128.0,
std=128.0,
),
],
),
sampler=dict(type=torch.utils.data.DistributedSampler),
batch_size=test_batch_size_per_gpu,
pin_memory=True,
shuffle=False,
num_workers=data_num_workers,
collate_fn=collate_2d,
)
val_batch_processor = dict(
type="MultiBatchProcessor",
need_grad_update=False,
loss_collector=None,
)
10.4.17.2.3. 训练策略¶
在 SceneFlow
数据集上训练浮点模型使用 Cosine
的学习策略配合 Warmup
,
以及对 weight
的参数施加 L2 norm。
configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
文件中的 float_trainer
, calibration_trainer
,
qat_trainer
, int_trainer
分别对应浮点、量化、定点模型的训练策略。
下面为 float_trainer
训练策略示例:
float_trainer = dict(
type="distributed_data_parallel_trainer",
model=model,
data_loader=data_loader,
optimizer=dict(
type=torch.optim.Adam,
params={"weight": dict(weight_decay=4e-5)},
lr=base_lr,
),
batch_processor=train_batch_processor,
stop_by="epoch",
num_epochs=num_epochs,
device=None,
sync_bn=True,
callbacks=[
stat_callback,
loss_show_callback,
dict(
type="CosLrUpdater",
warmup_by="epoch",
warmup_len=10,
step_log_interval=1000,
),
val_callback,
ckpt_callback,
],
train_metrics=[
dict(type="LossShow"),
dict(
type="EndPointError",
use_mask=True,
),
],
val_metrics=[
dict(
type="EndPointError",
use_mask=True,
),
],
)
10.4.17.2.4. 量化训练¶
关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等,
请阅读 量化感知训练 章节的内容。
这里主要讲一下 HAT
的双目深度估计中如何定义和使用量化模型。
在模型准备的好情况下,包括量化已有的一些模块完成之后,HAT在训练脚本中统一使用下面的脚本将浮点模型映射到定点模型上来。
model.fuse_model()
model.set_qconfig()
horizon.quantization.prepare_qat(model, inplace=True)
量化训练的整体策略可以直接沿用浮点训练的策略,但学习率和训练长度需要适当调整。
因为有浮点预训练模型,所以量化训练的学习率 Lr
可以很小,一般可以从0.001或0.0001开始,并可以搭配 StepLrUpdater
做1-2次 scale=0.1
的 Lr
调整;
同时训练的长度不用很长。此外 weight decay
也会对训练结果有一定影响。
StereoNet
示例模型的量化训练策略可见 configs/disparity_pred/stereonet/stereonet_stereonetneck_sceneflow.py
文件。
10.4.17.2.5. 模型检查编译和仿真上板精度验证¶
对于HAT来说,量化模型的意义在于可以在 BPU
上直接运行。
因此,对于量化模型的检查和编译是必须的。前文提到的 compile_perf
脚本也可以让用户定义好量化模型之后,先检查能否在 BPU
上正常运行,
并可通过 align_bpu_validation
脚本获取模型上板精度。用法同前文。