6.1.4.9. PwcNet光流预测模型训练¶
这篇教程主要是告诉大家如何利用 HAT 在光流数据集 FlyingChairs 上从头开始训练一个PwcNet 模型,包括浮点、量化和定点模型。
FlyingChairs 是光流预测中用的比较多的数据集,很多先进的光流预测研究都会优先基于这个数据集做好验证。开始训练模型之前,第一步是准备好数据集,这里我们下载官方的数据集FlyingChairs.zip 作为训练和验证集。同时需要下载相应的标签数据FlyingChairs_train_val.txt , 解压缩之后数据目录结构如下所示:
tmp_data |-- FlyingChairs |-- FlyingChairs_release |-- data |-- README.txt |-- FlyingChairs_train_val.txt |-- FlyingChairs.zip
6.1.4.9.1. 1. 训练流程¶
如果你只是想简单的把 PwcNet 的模型训练起来,那么可以首先阅读一下这一章的内容。 和其他任务一样,对于所有的训练,评测任务, HAT 统一采用tools + config 的形式来完成。在准备好原始数据集之后,可以通过下面的流程,方便地完成整个训练的流程。
6.1.4.9.1.1. 数据集准备¶
为了提升训练速度,我们对原始的数据集做了一个打包,将其转换为 LMDB 格式的数据集。只需要运行下面的脚本,就可以成功实现转换:
python3 tools/datasets/flyingchairs_packer.py --src-data-dir ${data-dir} --split-name train --pack-type lmdb --num-workers 10 python3 tools/datasets/flyingchairs_packer.py --src-data-dir ${data-dir} --split-name val --pack-type lmdb --num-workers 10
上面这两条命令分别对应转换训练数据集和验证数据集,打包完成之后,目录下的文件结构应该如下所示:
tmp_data |-- FlyingChairs |-- FlyingChairs_release |-- data |-- README.txt |-- FlyingChairs_train_val.txt |-- FlyingChairs.zip |-- train_lmdb |-- val_lmdb
train_lmdb 和 val_lmdb 就是打包之后的训练数据集和验证数据集,接下来就可以开始训练模型。
6.1.4.9.1.2. 模型训练¶
在网络开始训练之前,你可以使用以下命令先计算一下网络的计算量和参数数量:
python3 tools/calops.py --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py
下一步就可以开始训练。训练也可以通过下面的脚本来完成,在训练之前需要确认配置中数据集路径是否已经切换到已经打包好的数据集路径。
python3 tools/train.py --stage "float" --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py python3 tools/train.py --stage "qat" --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py python3 tools/train.py --stage "int_infer" --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py
由于 HAT 算法包使用了注册机制,使得每一个训练任务都可以按照这种train.py 加上 config 配置文件的形式启动。 train.py 是统一的训练脚本,与任务无关,我们需要训练什么样的任务、使用什么样的数据集以及训练相关的超参数设置都在指定的 config 配置文件里面。上面的命令中 –step 后面的参数可以是 “float” 、 “qat” 、 “int_infer”,分别可以完成浮点模型、量化模型的训练以及量化模型到定点模型的转化,其中量化模型的训练依赖于上一步浮点训练产出的浮点模型,定点模型的转化依赖于量化训练产生的量化模型,
6.1.4.9.1.3. 模型验证¶
在完成训练之后,可以得到训练完成的浮点、量化或定点模型。和训练方法类似,我们可以用相同方法来对训好的模型做指标验证,得到为 Float 、 QAT 和 Quantized 的指标,分别为浮点、量化和完全定点的指标。
python3 tools/predict.py --stage "float" --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py python3 tools/predict.py --stage "qat" --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py python3 tools/predict.py --stage "int_infer" --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py
和训练模型时类似,–step 后面的参数为 “float” 、 “qat” 、 “int_infer”时,分别可以完成对训练好的浮点模型、量化模型、定点模型的验证。
6.1.4.9.1.4. 模型推理¶
HAT 提供了 infer.py 脚本提供了对定点模型的推理结果进行可视化展示
python3 tools/infer.py --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py --dataset flyingchairs --input-size 384x512x6 --input-images ${img1-path},${img2-path} --input-format yuv --is-plot
6.1.4.9.1.5. 仿真上板精度验证¶
除了上述模型验证之外,我们还提供和上板完全一致的精度验证方法,可以通过下面的方式完成:
python3 tools/align_bpu_validation.py --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py --dataset flyingchairs
6.1.4.9.1.6. 定点模型检查和编译¶
在 HAT 中集成的量化训练工具链主要是为了地平线的芯片准备的,因此,对于量化模型的检查和编译是必须的。我们在 HAT 中提供了模型检查的接口,可以让用户定义好量化模型之后,先检查能否在 BPU 上正常运行:
python3 tools/model_checker.py --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py
在模型训练完成后,可以通过 compile_perf 脚本将量化模型编译成可以上板运行的 hbm 文件,同时该工具也能预估在 BPU 上 的运行性能:
python3 tools/compile_perf.py --config configs/opticalflow_pred/pwcnet/pwcnet_lg.py
以上就是从数据准备到生成量化可部署模型的全过程。
6.1.4.9.2. 1. 训练细节¶
在这个说明中,我们对模型训练需要注意的一些事项进行说明,主要为 config 的一些相关设置。
6.1.4.9.2.1. 模型构建¶
PwcNet 的网络结构可以参考 论文 和 社区TensorFlow版本 ,这里不做详细介绍。我们通过在 config 配置文件中定义 model 这样的一个 dict 型变量,就可以方便的实现对模型的定义和修改。
from torch import nn loss_weights = [0.005, 0.01, 0.02, 0.08, 0.32] out_channels = [16, 32, 64, 96, 128, 196] flow_pred_lvl = 2 pyr_lvls = 6 use_bn = True bn_kwargs = {} use_res = True use_dense = True model = dict( type="PwcnetTask", backbone=dict( type="PwcNet", out_channels=out_channels, use_bn=use_bn, bn_kwargs=bn_kwargs, pyr_lvls=pyr_lvls, flow_pred_lvl=flow_pred_lvl, act_type=nn.ReLU(), ), head=dict( type="PwcnetHead", in_channels=out_channels, bn_kwargs=bn_kwargs, use_bn=use_bn, md=4, use_res=use_res, use_dense=use_dense, pyr_lvls=pyr_lvls, flow_pred_lvl=flow_pred_lvl, act_type=nn.ReLU(), ), loss=dict(type="LnNormLoss", norm_order=2, power=1, reduction="mean"), loss_weights=loss_weights, )
模型除了 backbone 之外,还有 head`和 `losses 模块,在PwcNet中, backbone`主要是提取两张图像的特征, `head 主要是由特征来得到预测的光流图。 losses 部分采样论文中的LnNormLoss来作为训练的 loss, loss_weights`是特征层对应的 `loss 和权重。
6.1.4.9.2.2. 数据增强¶
跟 model 的定义一样,数据增强的流程是通过在 config 配置文件中定义 data_loader 和 val_data_loader 这两个 dict 来实现的,分别对应着训练集和验证集的处理流程。以 data_loader 为例,数据增强使用了 RandomCrop、 RandomFlip、SegRandomAffine 和 FlowRandomAffineScale。
data_loader = dict( type=torch.utils.data.DataLoader, dataset=dict( type="FlyingChairs", data_path="./tmp_data/FlyingChairs/train_lmdb/", transforms=[ dict( type="RandomCrop", size=(256, 448), ), dict( type="RandomFlip", px=0.5, py=0.5, ), dict( type="ToTensor", to_yuv=False, ), dict( type="SegRandomAffine", degrees=0, translate=(0.05, 0.05), scale=(1.0, 1.0), interpolation=InterpolationMode.BILINEAR, label_fill_value=0, translate_p=0.5, scale_p=0.0, ), dict( type="FlowRandomAffineScale", scale_p=0.5, scale_r=0.05, ), ], to_rgb=True, ), sampler=dict(type=torch.utils.data.DistributedSampler), batch_size=batch_size_per_gpu, pin_memory=True, shuffle=True, num_workers=4, collate_fn=hat.data.collates.collate_2d, )
因为最终跑在 BPU 上的模型使用的是 YUV444 的图像输入,而一般的训练图像输入都采用 RGB 的形式,所以 HAT 提供 BgrToYuv444 的数据增强来将 RGB 转到 YUV444 的格式。为了优化训练过程,HAT 使用了 batch_processor,可将一些增强处理放在 batch_processor 中优化训练:
def loss_collector(outputs: dict): return outputs["losses"] batch_processor = dict( type="MultiBatchProcessor", need_grad_update=True, batch_transforms=[ dict(type="BgrToYuv444", rgb_input=True), dict( type="TorchVisionAdapter", interface="Normalize", mean=128.0, std=128.0, ), dict( type="Scale", scales=tuple(1 / np.array(train_scales)), mode="bilinear", ), ], loss_collector=loss_collector, )
其中 loss_collector 是一个获取当前批量数据的 loss 的函数。
验证集的数据转换相对简单很多,如下所示:
val_data_loader = dict( type=torch.utils.data.DataLoader, dataset=dict( type="FlyingChairs", data_path="./tmp_data/FlyingChairs/val_lmdb/", transforms=[ dict( type="ToTensor", to_yuv=False, ), ], to_rgb=True, ), batch_size=batch_size_per_gpu, shuffle=False, num_workers=data_num_workers, pin_memory=True, collate_fn=hat.data.collates.collate_2d, )val_batch_processor = dict( type="MultiBatchProcessor", need_grad_update=False, batch_transforms=[ dict(type="BgrToYuv444", rgb_input=True), dict( type="TorchVisionAdapter", interface="Normalize", mean=128.0, std=128.0, ), ], loss_collector=None, )
6.1.4.9.2.3. 训练策略¶
在 FlyingChairs 数据集上训练浮点模型使用 Cosine 的学习策略配合 Warmup,以及对 weight 的参数施加 L2 norm。 configs/opticalflow_pred/pwcnet/pwcnet_lg.py 文件中的 float_trainer,qat_trainer, int_trainer 分别对应浮点、量化、定点模型的训练策略。下面以 float_trainer 训练策略示例:
float_trainer = dict( type="distributed_data_parallel_trainer", model=model, data_loader=data_loader, optimizer=dict( type=torch.optim.Adam, params={"weight": dict(weight_decay=4e-4)}, lr=lr, ), batch_processor=batch_processor, stop_by="epoch", num_epochs=max_epoch, device=None, callbacks=[ stat_callback, loss_metirc_show_update, dict( type="CosLrUpdater", warmup_by="epoch", warmup_len=10, step_log_interval=1000, ), val_callback, ckpt_callback, ], train_metrics=[ dict(type="LossShow"), dict(type="EndPointError"), ], val_metrics=[ dict(type="EndPointError"), ], sync_bn=True, )
6.1.4.9.2.4. 量化训练¶
关于量化训练中的关键步骤,比如准备浮点模型、算子替换、插入量化和反量化节点、设置量化参数以及算子的融合等,请阅读 《Horizon Plugin PyTorch》 手册中的浮点模型准备 和 算子融合 两节中的内容。这里主要讲一下 HAT的光流预测中如何定义和使用量化模型。
在模型准备的好情况下,包括量化已有的一些模块完成之后, HAT 在训练脚本中统一使用下面的脚本将浮点模型映射到定点模型上来。
model.fuse_model() model.set_qconfig() horizon.quantization.prepare_qat(model, inplace=True)
量化训练的整体策略可以直接沿用浮点训练的策略,但学习率和训练长度需要适当调整。因为有浮点预训练模型,所以量化训练的学习率 Lr 可以很小,一般可以从 0.001 或 0.0001 开始,并可以搭配 StepLrUpdater 做 1-2 次scale=0.1 的 Lr 调整;同时训练的长度不用很长。此外 weight decay 也会对训练结果有一定影响。
PwcNet`示例模型的量化训练策略可见 `configs/opticalflow_pred/pwcnet/pwcnet_lg.py 文件。
6.1.4.9.2.5. 模型检查编译和仿真上板精度验证¶
对于 HAT 来说,量化模型的意义在于可以在 BPU 上直接运行。因此,对于量化模型的检查和编译是必须的。前文提到的 compile_perf 脚本也可以让用户定义好量化模型之后,先检查能否在 BPU 上正常运行,并可通过align_bpu_validation 脚本获取模型上板精度。用法同前文。