问题提出:用pytorch训练VGG16分类,loss从0.69下降到0.24就开始小幅度震荡,不管如何调整batch_size和learning_rate都无法解决。
原因:没有加载预训练模型
那么问题来了,官方给出的是1000类的ImageNet预训练模型 https://download.pytorch.org/models/vgg16-397923af.pth,而我要做的是20类数据集的分类,如何使用这一预训练的权重。
def vgg16(pretrained=False, **kwargs):
"""VGG 16-layer model (configuration "D")"""
model = VGG(make_layers(cfg['D']), **kwargs)
if pretrained:
model.load_state_dict(torch.load('./vgg16-397923af.pth'))
model.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, you_class_num),
)
return model
其中VGG按照官方给出的构造方法构造class VGG即可。
先构造1000类的VGG模型,用于加载pth预训练模型,然后重新构造分类层,将最后一层全连接层设置为需要的类别数量即可。