10.3.3. 如何开启 AMP¶
AMP全称为Automatic Mixed Precision,即自动混合精度。AMP开启后,pytorch可以自动地在模型执行时将一些算子(如卷积和全连接)使用 float16
进行计算,以达到提升计算速度、减少显存占用的效果。详见 pytorch官方文档。
HAT中已经为AMP做好相关的工作,用户只需要在定义config文件中的 batch_processor
字段时将 enable_amp
参数设置为 True
即可。
注解
在模型验证时为得到准确的指标,一般是不需开启AMP的,在定义 val_batch_processor
字段时请将 enable_amp
参数设置为 False
。
# configs/example.py
# 使用 BasicBatchProcessor
batch_processor = dict(
type='BasicBatchProcessor',
need_grad_update=...,
batch_transforms=...,
enable_amp=True,
)
# 使用 MultiBatchProcessor
batch_processor = dict(
type="MultiBatchProcessor",
need_grad_update=...,
batch_transforms=...,
loss_collector=...,
enable_amp=True,
)