4.7.5.2. VargConvNet分类模型训练
这篇教程主要是告诉大家如何利用 HAT 在 ImageNet 上训练一个state-of-art的浮点模型。 ImageNet 是图像分类里用的最多的数据集,很多最先进的图像分类研究都会优先基于这个数据集做好验证。虽然有很多方法在社区或者其他途径里可以获取到 state-of-art 的分类模型,但从头训一个 state-of-art 的分类模型仍然不是一个简单的任务,本篇教程将会重点讲叙从数据集准备开始如何在 ImageNet 上训练出一个 state-of-art 的模型,包括浮点、量化和定点三种模式。
其中 ImageNet 数据集可以从这里下载:http://image-net.org/。下载之后的数据集格式为:
ILSVRC2012_img_train.tar ILSVRC2012_img_val.tar ILSVRC2012_devkit_t12.tar.gz
这里我们使用 VargConvNet 的例子来详细介绍整个分类的流程。
1. 训练流程
如果你只是想简单的使用 HAT 的接口来进行简单的实验,那么首先阅读一下这一小节的内容是个不错的选择。对于所有的训练和评测的任务, HAT 统一采用 tools + config 的形式来完成。在准备好原始数据集之后,通过以下的流程,我们可以方便的完成整个训练的流程。
数据集准备
首先是数据集打包,打包数据集与原始数据集在处理速度上有明显的优势,这里我们选择与 PyTorch 一脉相承的 LMDB 的打包方法,当然由于 HAT 在处理 dataset 上的灵活性,其他形式的数据集打包和读取形式,如 MXRecord 也是可以独立支持的。
在 tools/datasets 目录下提供了 cityscapes 、 imagenet 、 voc、 mscoco 这些常见数据集的打包脚本。例如 imagenet2lmdb 的脚本,可以利用torchvision 提供的默认公开数据集处理方法直接将原始的公开 ImageNet 数据集转成 Numpy 或者 Tensor 的格式,最后将得到的数据统一用 msgpack的方法压缩到 LMDB 的文件中。
这个过程可以很方便通过下面的脚本完成数据集打包:
python3 tools/datasets/imagenet_packer.py --src-data-dir ${src-data-dir} --target-data-dir ${target-data-dir} --split-name train --num-workers 10 --pack-type lmdb python3 tools/datasets/imagenet_packer.py --src-data-dir ${src-data-dir} --target-data-dir ${target-data-dir} --split-name val --num-workers 10 --pack-type lmdb
在完成数据集打包之后,可以得到含有 ImageNet 的 LMDB 数据集。下一步就可以开始训练。
模型训练
在网络训练开始之前,你可以使用以下命令先测试一下网络的计算量和参数数量:
python3 tools/calops.py --config configs/classification/vargconvnet.py --input-shape "1,3,224,224"
准备打包好数据集之后,便可以开始训练模型。只需要运行下面的命令就可以启动训练:
python3 tools/train.py --step "float" --config configs/classification/vargconvnet.py
如果想要验证已经训练好的模型精度,运行下面的命令即可:
python3 tools/train.py --step float --config configs/classification/vargconvnet.py --val-ckpt float-checkpoint-best.pth.tar --val-only
如果想要导出onnx模型, 运行下面的命令即可:
python3 tools/export_onnx.py --config configs/classification/vargconvnet.py --ckpt float-checkpoint-best.pth.tar --onnx-name vargconvnet.onnx
由于HAT算法包使用了一种巧妙的注册机制,使得每一个训练任务都可以按照这种train.py加上config配置文件的形式启动。train.py是统一的训练脚本,与任务无关,我们需要训练什么样的任务、使用什么样的数据集以及训练相关的超参数设置都在指定的config配置文件里面。config文件 里面提供了模型构建、数据读取等关键的dict。