How to freeze a model so that model.train() does not change any of its modules

How to permanently freeze parameters

# model is any torch.nn.Module (for example, self.llm_model inside a wrapper class)
for param in model.parameters():
    param.requires_grad = False

The code above freezes the model's parameters, but it does not guarantee that the whole model is frozen.

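Calling model.train() switches every submodule back to training mode, so layers such as Dropout and BatchNorm become active again: Dropout randomly zeroes activations, and BatchNorm keeps updating its running statistics even though its parameters have requires_grad = False.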
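A minimal sketch of the problem, assuming a toy nn.BatchNorm1d layer (the layer and the input sizes are illustrative and do not come from the original post). Every parameter has requires_grad = False, yet a forward pass in training mode still changes the running mean, while the same pass in eval mode does not:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy layer, used only to illustrate the point.
bn = nn.BatchNorm1d(4)
for param in bn.parameters():
    param.requires_grad = False  # weight and bias no longer receive gradients

before = bn.running_mean.clone()

bn.train()                    # training mode: running statistics are updated
x = torch.randn(8, 4) + 3.0   # batch whose mean is far from zero
bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)

bn.eval()                     # eval mode: running statistics are left alone
snapshot = bn.running_mean.clone()
bn(x)
unchanged_in_eval = torch.equal(snapshot, bn.running_mean)

print(changed_in_train, unchanged_in_eval)  # True True
```

This is why requires_grad = False alone is not enough: the running statistics are buffers, not parameters, and they are updated by the forward pass rather than by the optimizer.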

Override model.train so that calling model.train() no longer changes any module of the model:

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self

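import types

model = your_model  # any torch.nn.Module
freeze_model = True
if freeze_model:
    # Stop gradients from flowing into every parameter.
    for _, param in model.named_parameters():
        param.requires_grad = False
    # Switch Dropout/BatchNorm-style layers to inference behavior.
    model.eval()
    # Bind disabled_train to this instance so that both a direct
    # model.train() call and a parent's child.train(mode) call are no-ops.
    model.train = types.MethodType(disabled_train, model)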
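One way to check the effect is sketched below, assuming a hypothetical toy model made of a frozen backbone and a trainable head; the freeze helper and the layer sizes are illustrative names, not part of the original post. After the parent's train() call, the head follows the parent into training mode while the frozen backbone stays in eval mode:

```python
import types
import torch.nn as nn

def disabled_train(self, mode=True):
    """No-op replacement for nn.Module.train."""
    return self

def freeze(module):
    """Hypothetical helper: freeze parameters and lock the module in eval mode."""
    for param in module.parameters():
        param.requires_grad = False
    module.eval()
    # Bind to the instance so module.train() and module.train(mode) both work.
    module.train = types.MethodType(disabled_train, module)
    return module

# Hypothetical toy model: a frozen backbone followed by a trainable head.
backbone = freeze(nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5)))
head = nn.Linear(4, 2)
model = nn.Sequential(backbone, head)

model.train()  # the parent enters training mode and calls train() on its children

print(model.training)        # True: the parent is in training mode
print(head.training)         # True: the trainable head follows the parent
print(backbone.training)     # False: the frozen backbone stays in eval mode
print(backbone[1].training)  # False: so does the Dropout inside it
```

The override lives on the instance, so other modules of the same class are unaffected, and restoring normal behavior only requires deleting the instance attribute (del backbone.train).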

Reposted from blog.csdn.net/Friedrichor/article/details/132918826