模型的加载和保存

pytorch三种模型的加载保存操作

方法1 : PATH表示保存模型的路径和文件名

torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
class Model(nn.Module): def __init__(self, n_input_features): super(Model, self).__init__() self.linear = nn.Linear(n_input_features, 1) def forward(self, x): y_pred = torch.sigmoid(self.linear(x)) return y_pred model = Model(n_input_features=6) # train your medel... # save model FILE = "model.pth" torch.save(model, FILE) # load model model = torch.load(FILE) # 防止模型参数发生变化 model.eval() for param in model.parameters(): print(param) 

方法二:

保存模型时使用模型的state_dict()方法,加载模型前先实例化一个模型,然后调用load_state_dict()方法

torch.save(model.state_dict(), PATH)
# model must be created again with parameters
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
class Model(nn.Module): def __init__(self, n_input_features): super(Model, self).__init__() self.linear = nn.Linear(n_input_features, 1) def forward(self, x): y_pred = torch.sigmoid(self.linear(x)) return y_pred model = Model(n_input_features=6) # train your medel... for param in model.parameters(): print(param) # save model FILE = "model.pth" torch.save(model.state_dict(), FILE) loaded_model = Model(n_input_features=6) loaded_model.load_state_dict(torch.load(FILE)) # 防止模型参数发生变化 loaded_model.eval() for param in loaded_model.parameters(): print(param)

方法三:

定义一个字典,保存多个参数到模型

class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your medel...

# print(model.state_dict())


learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# print(optimizer.state_dict())


checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(), "optim_state": optimizer.state_dict() } # 保存三种数据到模型 torch.save(checkpoint, "checkpoint.pth") # 加载模型 loaded_checkpoint = torch.load("checkpoint.pth") # 载入epcho数据 epoch = loaded_checkpoint['epoch'] print(epoch) # 定义模型和优化器 model = Model(n_input_features=6) optimizer = torch.optim.SGD(model.parameters(), lr=0) # 将保存的模型数据载入到模型和优化器中 model.load_state_dict(checkpoint["model_state"]) optimizer.load_state_dict(checkpoint["optim_state"])

 推荐:什么是顺时针,你看到的是顺时针还是逆时针

猜你喜欢

转载自www.cnblogs.com/1994july/p/13192735.html