pytorch模型的储存和载入

pytorch保存和载入模型
1.相关函数

torch.save

torch.save(obj, f, pickle_module=pickle, pickle_protocol=2)

torch.load

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

map_location 选择加载到CPU或GPU中

# 保存在 CPU, 加载到 GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) 

# 保存在 GPU, 加载到 CPU

model.load_state_dict(torch.load(PATH, map_location='cpu'))

model.load_state_dict()
 

model.load_state_dict(state_dict, strict=True)

2.直接保存和加载

保存和加载整个模型 (已经训练完,无需继续训练) 占内存

# 保存
torch.save(model, PATH)
# 加载
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

3.使用state_dict保存加载(推荐)

使用state_dict只保留了权重参数,因此在加载时需要先初始化模型

否则会出现 pytorch AttributeError 报错

保存和加载 state_dict (已经训练完,无需继续训练)

保存

torch.save(model.state_dict(), PATH)

加载

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()   #一定要初始化  不然会报错

一般保存为.pt.pth 格式的文件。

1.load_state_dict()函数需要一个 dict 类型的输入,而不是保存模型的 PATH。所以这样 model.load_state_dict(PATH)是错误的,而应该model.load_state_dict(torch.load(PATH))
2.如果你想保存验证机上表现最好的模型,那么这样best_model_state=model.state_dict()是错误的。因为这属于浅复制,也就是说此时这个 best_model_state 会随着后续的训练过程而不断被更新,最后保存的其实是个 overfit 的模型。所以正确的做法应该是best_model_state=deepcopy(model.state_dict())。

 保存和加载 state_dict (没有训练完,还会继续训练)

保存

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...你自己的参数
            }, PATH)

加载

model = XIAOHU(*args, **kwargs)
optimizer = adam(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
...
model.eval()
# - or -
model.train()

猜你喜欢

转载自blog.csdn.net/qq_37925923/article/details/126919333