7.4.7. 量化部署 PT 模型的跨设备 Inference 说明¶
量化部署的 pt 模型要求 trace 时使用的 device 和后续 infer 时使用的 device 一致。
若用户试图直接通过 to(device)
操作修改 pt 模型的 device,可能会出现模型 forward 报错的问题,torch 官方对此进行了解释,见 TorchScript-Frequently Asked Questions — PyTorch documentation。
下面举例说明:
import torch
class Net(torch.nn.Module):
def forward(self, x: torch.Tensor):
y = torch.ones(x.shape, device=x.device)
z = torch.zeros_like(x)
return y + z
script_mod = torch.jit.trace(
Net(), torch.rand(2, 3, 3, 3, device=torch.device("cpu"))
)
script_mod.to(torch.device("cuda"))
print(script_mod.graph)
# graph(%self : __torch__.Net,
# %x : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu)):
# %4 : int = prim::Constant[value=0]()
# %5 : int = aten::size(%x, %4)
# %6 : Long(device=cpu) = prim::NumToTensor(%5)
# %16 : int = aten::Int(%6)
# %7 : int = prim::Constant[value=1]()
# %8 : int = aten::size(%x, %7)
# %9 : Long(device=cpu) = prim::NumToTensor(%8)
# %17 : int = aten::Int(%9)
# %10 : int = prim::Constant[value=2]()
# %11 : int = aten::size(%x, %10)
# %12 : Long(device=cpu) = prim::NumToTensor(%11)
# %18 : int = aten::Int(%12)
# %13 : int = prim::Constant[value=3]()
# %14 : int = aten::size(%x, %13)
# %15 : Long(device=cpu) = prim::NumToTensor(%14)
# %19 : int = aten::Int(%15)
# %20 : int[] = prim::ListConstruct(%16, %17, %18, %19)
# %21 : NoneType = prim::Constant()
# %22 : NoneType = prim::Constant()
# %23 : Device = prim::Constant[value="cpu"]()
# %24 : bool = prim::Constant[value=0]()
# %y : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::ones(%20, %21, %22, %23, %24)
# %26 : int = prim::Constant[value=6]()
# %27 : int = prim::Constant[value=0]()
# %28 : Device = prim::Constant[value="cpu"]()
# %29 : bool = prim::Constant[value=0]()
# %30 : NoneType = prim::Constant()
# %z : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::zeros_like(%x, %26, %27, %28, %29, %30)
# %32 : int = prim::Constant[value=1]()
# %33 : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::add(%y, %z, %32)
# return (%33)
可以看到,在调用 to(torch.device("cuda"))
后,模型的 graph 中记录的 aten::ones
和 aten::zeros_like
的 device 参数仍为 prim::Constant[value="cpu"]()
,因此在模型 forward 时,它们的输出仍为 cpu Tensor。这是因为 to(device)
只能移动模型中的 buffer(weight、bias 等),无法修改 ScriptModule
的 graph。
torch 官方对以上限制给出的解决方案是,在 trace 前就确定好 pt 模型将要在哪个 device 上执行,并在对应的 device 上 trace 即可。
针对以上限制,训练工具建议根据具体场景选择以下解决方案:
7.4.7.1. PT 模型执行使用的 device 和 trace 不一致¶
对于可以确定 pt 模型将仅在 GPU 上执行,只需要修改卡号的情况,我们首先推荐使用 cuda:0
,即零号卡进行 trace。在使用模型时,用户可以通过 torch.cuda.set_device
接口,将物理上的任意卡映射为逻辑上的“零卡”,此时使用 cuda:0
trace 出的模型实际将在指定的物理卡上运行。
若 trace 时使用的 device 和执行时使用的 device 存在 CPU、GPU 的不一致,用户可以使用 horizon_plugin_pytorch.jit.to_device
接口实现 pt 模型的 device 迁移。此接口会寻找模型 graph 中的 device 参数,并将它们替换为需要的值。效果如下:
from horizon_plugin_pytorch.jit import to_device
script_mod = to_device(script_mod, torch.device("cuda"))
print(script_mod.graph)
# graph(%self : __torch__.Net,
# %x.1 : Tensor):
# %38 : bool = prim::Constant[value=0]()
# %60 : Device = prim::Constant[value="cuda"]()
# %34 : NoneType = prim::Constant()
# %3 : int = prim::Constant[value=0]()
# %10 : int = prim::Constant[value=1]()
# %17 : int = prim::Constant[value=2]()
# %24 : int = prim::Constant[value=3]()
# %41 : int = prim::Constant[value=6]()
# %4 : int = aten::size(%x.1, %3)
# %5 : Tensor = prim::NumToTensor(%4)
# %8 : int = aten::Int(%5)
# %11 : int = aten::size(%x.1, %10)
# %12 : Tensor = prim::NumToTensor(%11)
# %15 : int = aten::Int(%12)
# %18 : int = aten::size(%x.1, %17)
# %19 : Tensor = prim::NumToTensor(%18)
# %22 : int = aten::Int(%19)
# %25 : int = aten::size(%x.1, %24)
# %26 : Tensor = prim::NumToTensor(%25)
# %32 : int = aten::Int(%26)
# %33 : int[] = prim::ListConstruct(%8, %15, %22, %32)
# %y.1 : Tensor = aten::ones(%33, %34, %34, %60, %38)
# %z.1 : Tensor = aten::zeros_like(%x.1, %41, %3, %60, %38, %34)
# %50 : Tensor = aten::add(%y.1, %z.1, %10)
# return (%50)
7.4.7.2. 多卡并行推理¶
在此场景下,用户需要通过 trace 或 to_device
的方式取得 cuda:0
上的 pt 模型,并且为每块卡单独开启一个进程,通过 torch.cuda.set_device
的方式为每个进程设置不同的默认卡。一个简单的示例如下:
import os
import torch
import signal
import torch.distributed as dist
import torch.multiprocessing as mp
from horizon_plugin_pytorch.jit import to_device
model_path = "path_to_pt_model_file"
def main_func(rank, world_size, device_ids):
torch.cuda.set_device(device_ids[rank])
dist.init_process_group("nccl", rank=rank, world_size=world_size)
model = to_device(torch.jit.load(model_path), torch.device("cuda"))
# 数据加载,模型 forward,精度计算等内容此处省略
def launch(device_ids):
try:
world_size = len(device_ids)
mp.spawn(
main_func,
args=(world_size, device_ids),
nprocs=world_size,
join=True,
)
# 当按下 Ctrl+c 时,关闭所有子进程
except KeyboardInterrupt:
os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
launch([0, 1, 2, 3])
上述操作对 pt 模型的处理和 torch.nn.parallel.DistributedDataParallel
的做法一致,数据加载和模型精度计算相关内容请参考 Getting Started with Distributed Data Parallel — PyTorch Tutorials。