torch.save(model.state_dict(), 'best_model.pth')

torch.save(model.state_dict(), 'best_model.pth') saves the model's parameter dictionary to a file. Here, model.state_dict() returns the parameter dictionary of the current model, and torch.save() writes that dictionary to a file named "best_model.pth".

This is a common way to save a model: the parameters are serialized to a binary file so that they can be loaded and restored when needed. The reason for saving the parameter dictionary rather than the entire model object is that the model's structure (its class definition and forward computation) does not need to be stored. Only the tensors that define the model's state are saved: mainly the learnable weights and biases, plus persistent buffers such as BatchNorm running statistics. As a consequence, the model's code must be available again when the parameters are loaded.
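A quick way to see what a state_dict actually contains is to compare its keys with named_parameters(). The sketch below uses a hypothetical nn.Sequential model with a BatchNorm1d layer (an assumption chosen for illustration): besides the weights and biases, the state_dict also carries BatchNorm's running statistics, which are buffers rather than learnable parameters but are needed to reproduce the model's behavior in evaluation mode.

```python
import torch.nn as nn

# Hypothetical model: a linear layer followed by batch normalization
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# Learnable parameters only (weights and biases)
param_names = [name for name, _ in model.named_parameters()]

# Every entry that torch.save(model.state_dict(), ...) would write
state_names = list(model.state_dict().keys())

# Entries in the state_dict that are not learnable parameters (buffers)
buffer_names = [name for name in state_names if name not in param_names]

print(param_names)
print(buffer_names)
```

The first list holds the four weight and bias tensors; the second holds the BatchNorm buffers (running_mean, running_var, num_batches_tracked), which are saved and restored along with the parameters.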

To load saved model parameters, the following code can be used:

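import torch

model = MyModel()  # create a model instance with the same architecture as the one whose parameters were saved
model.load_state_dict(torch.load('best_model.pth'))  # load the parameter dictionary into the model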

With this approach, the saved parameters are loaded into a model instance that has the same structure as the original model, restoring its parameters. This makes it possible to save the best model during training and later load it to make predictions or to continue training.
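Saving the best model during training usually means overwriting the checkpoint only when a validation metric improves. The following is a minimal sketch of that pattern, assuming a hypothetical toy regression task (synthetic data, a single nn.Linear layer, plain SGD); the data, model, epoch count, and learning rate are illustrative assumptions, not the original author's setup.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical regression data: y = 3x + noise, split into train/validation
x = torch.randn(200, 1)
y = 3 * x + 0.1 * torch.randn(200, 1)
x_train, y_train = x[:150], y[:150]
x_val, y_val = x[150:], y[150:]

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# Temporary directory used only for this demo
path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
best_val_loss = float("inf")

for epoch in range(20):
    # One full-batch training step per epoch
    model.train()
    optimizer.zero_grad()
    loss_fn(model(x_train), y_train).backward()
    optimizer.step()

    # Evaluate on the validation split without tracking gradients
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()

    # Overwrite the checkpoint only when the validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), path)

# Restore the best weights into a freshly constructed model
best_model = nn.Linear(1, 1)
best_model.load_state_dict(torch.load(path))
best_model.eval()  # switch to evaluation mode before making predictions
```

Because only the state_dict is written, each save is cheap, and the final file always holds the weights from the epoch with the lowest validation loss.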

model.state_dict()

model.state_dict() is a method that returns the model's parameter dictionary (its state_dict). This is a Python dictionary that maps each of the model's learnable parameters (weights and biases), together with any registered buffers such as BatchNorm running statistics, to its tensor value.

In deep learning, a model's parameters usually consist of the weights and biases of its individual layers. In the dictionary returned by model.state_dict(), each key is the name of a parameter (for example, fc1.weight for the weight of a layer stored as self.fc1) and each value is the corresponding tensor.
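The keys follow the attribute path from the top-level module down to each tensor, with the indices of nn.Sequential children appearing as path segments. The sketch below uses a hypothetical Net class (the names encoder and head are assumptions for illustration) and prints each key with the shape of its tensor.

```python
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical model with a nested submodule
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(self.encoder(x))


# Map every state_dict key to the shape of its tensor
shapes = {name: tuple(t.shape) for name, t in Net().state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
```

The ReLU layer has no parameters, so it contributes no keys; the Linear inside the Sequential appears as encoder.0, and nn.Linear stores its weight with shape (out_features, in_features).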

The main purpose of model.state_dict() is saving and loading model parameters. Calling model.state_dict() returns the current model's parameter dictionary, which can then be saved to a file with torch.save(). Later, torch.load() reads the dictionary back from the file, and model.load_state_dict() loads it into the model, restoring its parameters.
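The full round trip (state_dict, file, torch.load, load_state_dict) can be checked end to end by comparing the outputs of the original and restored models. This is a sketch under stated assumptions: MyModel here is a hypothetical two-layer stand-in for the MyModel class used in the text, and the file goes to a temporary directory. The map_location="cpu" argument of torch.load lets a checkpoint saved on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyModel(nn.Module):
    # Hypothetical stand-in for the MyModel class referenced in the text
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
original = MyModel()
path = os.path.join(tempfile.mkdtemp(), "model.pth")

# Save: model -> state_dict (dict of tensors) -> file
torch.save(original.state_dict(), path)

# Load: file -> dict of tensors -> a freshly constructed model
loaded_state = torch.load(path, map_location="cpu")
restored = MyModel()  # starts with its own random initialization
restored.load_state_dict(loaded_state)

# Both models should now produce identical outputs for the same input
x = torch.randn(5, 4)
with torch.no_grad():
    same_output = torch.equal(original(x), restored(x))
print(same_output)
```

The restored model begins with different random weights, so identical outputs after load_state_dict confirm that every parameter was carried over from the file.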

The following sample code shows how to use model.state_dict() to save and load a model's parameters:

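import torch

# Save the model's parameters
torch.save(model.state_dict(), 'model.pth')

# Load the model's parameters
model = MyModel()  # create a fresh model instance with the same architecture
model.load_state_dict(torch.load('model.pth'))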

This way, the model's parameters are loaded back from the saved file and applied to the new model instance. This is useful in scenarios such as saving and loading models during training, or reusing a trained model across different sessions.
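To continue training in a later session, the model's state_dict alone is often not enough: optimizers such as SGD with momentum keep their own state. A common approach is to save a dictionary that bundles both state_dicts with bookkeeping values. The sketch below assumes a hypothetical nn.Linear model and a single dummy training step; the key names "epoch", "model_state_dict", and "optimizer_state_dict" are just a convention, not something torch.save requires.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One dummy training step so the optimizer has momentum buffers to save
x, y = torch.randn(16, 3), torch.randn(16, 1)
optimizer.zero_grad()
nn.functional.mse_loss(model(x), y).backward()
optimizer.step()

# Bundle everything needed to resume into one dictionary (key names are a convention)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(
    {
        "epoch": 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    path,
)

# Later, possibly in a new session: rebuild the objects, then restore their state
resumed_model = nn.Linear(3, 1)
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(path)
resumed_model.load_state_dict(checkpoint["model_state_dict"])
resumed_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1
resumed_model.train()  # switch back to training mode before continuing
```

The optimizer must be constructed over the new model's parameters before its state is loaded, so that the restored momentum buffers attach to the right tensors.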


Origin blog.csdn.net/AdamCY888/article/details/131354334