Pytorch2.0中compiled_model=torch.compile(model) 的正确添加位置

今天pytorch官网更新了pytorch2.0稳定版,迫不及待的我直接更新了,确实像官方所说,只需加入model=torch.compile(model)一行代码即可加速,加入的位置如下。

cpu训练:

 model=UNet(deep_supervision=True)
 model=torch.compile(model)

单卡训练:

 model=UNet(deep_supervision=True)
 model.to(Device)
 model=torch.compile(model)

多卡训练:

 model=UNet(deep_supervision=True)
 model.to(Device)
 model=nn.parallel.DistributedDataParallel(
         model,
         device_ids=[local_rank],
         output_device=local_rank,
         broadcast_buffers=False,
     )
 model=torch.compile(model)

注意 model = torch.compile(model) 这句话的位置对了就可以了,其他的不用改!!

  1. 多卡训练官方教程:https://pytorch.org/docs/stable/notes/ddp.html#distributed-data-parallel

  1. torch.compile官方教程:https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html?utm_source=whats_new_tutorials&utm_medium=torch_compile

猜你喜欢

转载自blog.csdn.net/CarryEKAIruiui/article/details/129635139