Loading part of the pretrained weights into a reduced MobileNetV2

1. The problem

Edge devices have limited compute, so the model deployed on them should be small. I therefore wanted to use a slimmed-down version of MobileNetV2. At first I did not load any pretrained weights, because I had changed the network structure and the number of layers, but the model turned out to be hard to train. Then it occurred to me that I had only removed layers from the network, so perhaps I could still load the weights of the layers that remain, which might speed up training. (It turned out that training did work after doing this.)

2. The solution

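The pretrained MobileNetV2 has the standard architecture: a sequence of bottleneck stages, each of which stacks the same inverted-residual block n times. In my model each block is stacked only once, that is, n = 1 for every stage.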
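To see which pretrained blocks the reduced model can reuse, the sketch below walks the standard MobileNetV2 stage table (t, c, n, s) and maps each block index of the n = 1 model to the first block of the same stage in the full model. It assumes, as in common PyTorch ports of MobileNetV2, that `features[0]` is the stem convolution followed by one `features[i]` entry per inverted-residual block; the names `STANDARD_CFGS` and `first_block_index_map` are hypothetical. The first block of each stage is the natural counterpart because it has the same input channels, output channels and stride as the single block the reduced model keeps.

```python
# Standard MobileNetV2 bottleneck stages: (t = expansion, c = output channels,
# n = repeats, s = stride of the first block). Hypothetical constant name.
STANDARD_CFGS = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
]


def first_block_index_map(cfgs):
    """Map each features index of the n = 1 model to the features index of
    the first block of the same stage in the full model.

    Assumes features[0] is the stem conv and every following entry is one
    inverted-residual block, stage after stage.
    """
    mapping = {0: 0}  # the stem convolution is identical in both models
    full_idx = 1  # features index of the current stage's first block (full model)
    for stage, (_t, _c, n, _s) in enumerate(cfgs):
        mapping[stage + 1] = full_idx  # the reduced model keeps one block per stage
        full_idx += n  # skip all n blocks of this stage in the full model
    return mapping


for reduced_idx, full_idx in first_block_index_map(STANDARD_CFGS).items():
    print(f"reduced features.{reduced_idx} <- pretrained features.{full_idx}")
```

With the standard table this gives 0→0, 1→1, 2→2, 3→4, 4→7, 5→11, 6→14, 7→17, so only `features.0` to `features.2` keep the same index in both models.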

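Idea: load the pretrained MobileNetV2 weights into pretrained_dict, instantiate your own model and take its state_dict as model_dict, then compare the two state_dicts: remove from pretrained_dict every entry whose key does not exist in model_dict or whose tensor shape differs, and load what remains into your model.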
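The idea can be tried on a toy pair of networks before touching MobileNetV2. The sketch below, built around a hypothetical helper `load_matching_weights`, applies the same rule (keep a pretrained entry only if its key exists in the target model and the shapes agree) and reports which keys were loaded and which were skipped. The two `nn.Sequential` models are made up for illustration: the "full" one has three linear layers, the "reduced" one drops the middle layer.

```python
import torch
import torch.nn as nn


def load_matching_weights(model, pretrained_dict):
    """Copy every pretrained tensor whose name and shape match `model`.

    Returns (loaded_keys, skipped_keys), both sorted.
    """
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrained_dict.items()
               if k in model_dict and v.shape == model_dict[k].shape}
    skipped = sorted(set(pretrained_dict) - set(matched))
    model_dict.update(matched)      # overwrite matching entries
    model.load_state_dict(model_dict)  # all keys present, so strict loading works
    return sorted(matched), skipped


# Toy "full" network with three layers and a "reduced" one with two.
full = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 8), nn.Linear(8, 2))
reduced = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

loaded, skipped = load_matching_weights(reduced, full.state_dict())
print("loaded: ", loaded)
print("skipped:", skipped)
```

Here only layer `0` is copied. The reduced model's last layer plays the role of the full model's layer `2`, but it is named `1`, and its shape differs from the full layer `1`, so name matching skips it. Removing layers shifts the indices of everything after them, which is worth checking whenever a model is trimmed.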

import torch

def mobilenet_v2(pretrained=True):
    # MobileNetV2 is the author's modified model class (every stage stacked once, n = 1)
    model = MobileNetV2(width_mult=1)
    if pretrained:
        # try:
        #    from torch.hub import load_state_dict_from_url
        # except ImportError:
        #    from torch.utils.model_zoo import load_url as load_state_dict_from_url
        # state_dict = load_state_dict_from_url(
        #     'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True)
        # pretrained MobileNetV2 weights, used as a plain state_dict and loaded onto the CPU
        pretrained_dict = torch.load("mobilenetv2_1.0-f2a8633.pth.tar", map_location="cpu")
        model_dict = model.state_dict()  # parameters of our own (reduced) network
        # keep only entries whose name exists in our model and whose shape matches
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(pretrained_dict)  # overwrite the matching entries with the pretrained values
        model.load_state_dict(model_dict)  # load the merged parameters
    return model
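Matching by name alone can pair blocks that are not actually counterparts. Assuming the model's `features` is an `nn.Sequential` with the stem conv at index 0 followed by one entry per inverted-residual block (as in common PyTorch ports of MobileNetV2), removing repeats shifts every later index: the reduced `features.3` (first block of the 32-channel stage) shares its name with the pretrained `features.3` (second block of the 24-channel stage), and their expansion and depthwise convolutions have the same shapes, so the shape check lets those tensors through. One possible fix, sketched below, is to rename the pretrained keys first so that each reduced block reads from the first block of the same stage. The helper `remap_feature_keys` and the `INDEX_MAP` (derived from the standard MobileNetV2 stage repeats 1, 2, 3, 4, 3, 3, 1) are hypothetical.

```python
import re

# Hypothetical map: reduced features index -> pretrained features index of the
# first block of the same stage (standard MobileNetV2 repeats 1, 2, 3, 4, 3, 3, 1).
INDEX_MAP = {0: 0, 1: 1, 2: 2, 3: 4, 4: 7, 5: 11, 6: 14, 7: 17}

_FEATURE_KEY = re.compile(r"^features\.(\d+)\.(.*)$")


def remap_feature_keys(pretrained_dict, index_map):
    """Rename pretrained 'features.<old>.*' keys to 'features.<new>.*'.

    Feature blocks without a counterpart in the reduced model are dropped;
    keys outside `features` are kept unchanged.
    """
    old_to_new = {old: new for new, old in index_map.items()}
    remapped = {}
    for key, value in pretrained_dict.items():
        match = _FEATURE_KEY.match(key)
        if match is None:
            remapped[key] = value  # e.g. the final conv or the classifier
            continue
        old_idx = int(match.group(1))
        if old_idx in old_to_new:
            remapped[f"features.{old_to_new[old_idx]}.{match.group(2)}"] = value
    return remapped


# Toy stand-in for a pretrained state_dict (string values instead of tensors).
toy = {
    "features.0.0.weight": "stem",
    "features.3.conv.0.weight": "stage2-block2",
    "features.4.conv.0.weight": "stage3-block1",
    "classifier.weight": "fc",
}
print(remap_feature_keys(toy, INDEX_MAP))
```

In `mobilenet_v2`, the result of `remap_feature_keys(torch.load(...), INDEX_MAP)` would replace the raw `pretrained_dict` before the shape filter; keys outside `features`, such as the final conv and the classifier, pass through unchanged and are still subject to that filter.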

Source: blog.csdn.net/weixin_44855366/article/details/130553739