pytorch 在保存训练后模型的时候,会把训练过程中使用的设备号(例如gpu卡号cuda:0
,cpu
)也一并保存下来。当pytorch重新载入历史模型时,模型默认根据训练时的设备卡号,把权值载入到相应的卡号上。
然而,有的时候测试过程和训练过程的设备情况是不一致的。
举个例子,A主机有四块GPU卡,然后我们用cuda:3
训练模型,并保存模型。
在测试时候,我们需要在客户的B主机跑模型,但是B主机只有一块gpu卡:cuda:0
。
如果按照默认方式载入模型的话,pytorch会报找不到gpu设备,或其他一些错误。
此时,载入的时候需要做一个变换,为torch.load指定gpu设备的映射方式:
#模型的保存和加载
import torch
import json
import logger_wrappers
import os
destination_device='cuda:0' #目标设备
model_CKPT = torch.load(path,map_location={
'cuda:0':destination_device,'cuda:1':destination_device,'cuda:2':destination_device,'cuda:3':destination_device})
map_location
是一个dict,里面的key表示存储模型的源设备号,value表示目的设备号。
{
cuda:3
:cuda:0
} 表示把历史模型里面原来放在cuda:3
的权重全部加载到cuda:0
设备上。
{
cuda:3
:cpu
} 表示把历史模型里面原来放在cuda:3
的权重全部加载到cpu
设备。