pytorch 载入历史模型时更换gpu卡号,map_location设置

pytorch 在保存训练后模型的时候,会把训练过程中使用的设备号(例如gpu卡号cuda:0cpu)也一并保存下来。当pytorch重新载入历史模型时,模型默认根据训练时的设备卡号,把权值载入到相应的卡号上。

然而,有的时候测试过程和训练过程的设备情况是不一致的。
举个例子,A主机有四块GPU卡,然后我们用cuda:3 训练模型,并保存模型。
在测试时候,我们需要在客户的B主机跑模型,但是B主机只有一块gpu卡:cuda:0

如果按照默认方式载入模型的话,pytorch会报找不到gpu设备,或其他一些错误。

此时,载入的时候需要做一个变换,为torch.load指定gpu设备的映射方式:

#模型的保存和加载
import torch
import json
import logger_wrappers
import os
destination_device='cuda:0'  #目标设备
model_CKPT = torch.load(path,map_location={
    
    'cuda:0':destination_device,'cuda:1':destination_device,'cuda:2':destination_device,'cuda:3':destination_device})

map_location 是一个dict,里面的key表示存储模型的源设备号,value表示目的设备号。
{ cuda:3:cuda:0} 表示把历史模型里面原来放在cuda:3 的权重全部加载到cuda:0 设备上。
{ cuda:3:cpu} 表示把历史模型里面原来放在cuda:3 的权重全部加载到cpu 设备。

猜你喜欢

转载自blog.csdn.net/jmh1996/article/details/111041108