10.3.6. config 文件介绍¶
使用HAT算法包训练模型通常只需使用一条命令就可以了,即:
python3 tools/train.py --stage TRAINING_STEPS --config /PATH/TO/CONFIG
其中 /PATH/TO/CONFIG 就是模型训练对应的 config 文件,它负责定义了模型结构、数据集加载、以及整套的训练流程。
本章节通过介绍config文件中一些固定的全局的关键字,以及它们是如何配置的,让用户对config文件的内容以及作用有个大致的了解。
10.3.6.1. 全局关键字¶
training_stage:模型训练的各个阶段,包括
float、qat和int_infer。device_ids:模型训练使用的
gpu列表。cudnn_benchmark:是否打开cudnn benchmark。通常默认为
True。seed:是否设置随机数种子。通常默认为
None。log_rank_zero_only:简化多卡训练时的日志打印,只在第0卡上输出日志。通常默认为
True。model:参与
training过程中的模型结构。type表示模型的类型,如Classifier、Segmentor、RetinaNet等等,分别对应分类、分割、检测中的某一类模型。它会在使用过程中被build成具体的类,余下的参数都是用于初始化这个类。deploy_model:参与
deploy过程的模型结构,主要用于模型编译。和model相比,大多数情况下只需要把损失函数以及后处理部分设置为None即可。deploy_inputs:
deploy过程的模拟输入。不用关心具体的数值,只要保证格式满足输入要求即可。data_loader:训练阶段的数据集加载流程。它的
type是一个具体的类torch.utils.data.DataLoader,余下的参数都是用于初始化这个类。相关参数的含义也可以参考pytorch官网提供的接口文档。这里dataset表示读取某个具体的数据集,例如ImageNet、MSCOCO、VOC等等,它的transforms表示在数据读取过程中添加的数据增强操作。val_data_loader:验证模型性能阶段的数据集加载流程。和
data_loader不同的地方在于data_path不同,以及去掉了transforms的过程和sample的过程。batch_processor:模型在训练过程中每个迭代
stage进行的操作,包括前向计算、梯度回传、参数更新等等。如果包含batch_transforms参数,表示一些数据增强的操作是在gpu上进行的,这可以大大加快训练速度。val_batch_processor:模型在验证过程中每个迭代
stage进行的操作,只包含前向计算。metric_updater:模型训练过程中更新指标的方法,这个指标是用来验证训练的模型性能是否在提升。它通常是和
float_trainer下面的train_metrics配合着使用。train_metrics是具体的指标形式,metric_updater只是提供一种更新方法。val_metric_updater:训练出来的模型在验证性能的过程中更新指标的方法,这个指标用来验证最终训练出来的模型性能到底如何。它通常是和
float_trainer下面的val_metrics配合着使用,和metric_updater同理。float_trainer:浮点模型训练流程的配置。
type类型为distributed_data_parallel_trainer表示支持分布式训练,其余的参数分别定义了模型、数据集加载、优化器、训练epoch长度等等。其中callbacks表示训练过程中进行的一系列操作,比如模型保存、学习率更新、精度验证等等。 直接被tools/train.py文件调用的变量。qat_trainer:
qat模型训练流程配置。参数的含义和float_trainer基本一致。直接被tools/train.py文件调用的变量。int_infer_trainer:不包含训练流程,只是为了验证定点模型的精度。直接被
tools/train.py文件调用的变量。compile_cfg:编译相关的配置。
out_dir表示编译生成的hbm文件(部署模型)的输出路径。
之所以称这些变量为全局关键字,是因为几乎每个 config 文件中都定义了以上这些变量,且对应的功能基本一致。因此通过对这篇文档的学习,用户可以大致理解任意一个 config 文件实现的功能。
10.3.6.2. 如何配置¶
这里主要介绍数据类型为 dict 的全局关键字的配置。
数据类型为 dict 的全局关键字可以为两种:
包含
type的,例如model、data_loader、float_trainer等。不包含
type的,例如compile_cfg。
它们的区别在于包含 type 的全局关键字本质可以看作是一个 class,它的type值可以是一个 string 变量,也可以是一个具体的
class,如果是 string, 在程序运行中同样会被 build 成一个相应的 class。这个dict中除掉type之外的其它keys的值都用于初始化这个 class。和全局关键字属性类似,这些keys的值可以是一个数值,也可以是一个包含type变量的dict,例如 data_loader 中的 dataset 属性,以及这个 dataset 下面的 transforms 属性。
对于没有 type 变量的全局关键字来说,它就是一个普通类型的 dict 变量,代码在运行过程中会通过其 keys 获取对应的 values。
提示
所有已经提供的config配置,可以保证正常运行和复现精度。如果因为环境配置和训练时间等原因,需要修改配置的话,那么相对的训练策略可能也需要更改。直接修改config中的个别配置有时候并不能得到想要的结果。