pytorch模型加载caffe模型的权重

一、将caffe模型的权重转成dict格式

caffe库的编译可以参考我之前写的一篇博客:ImportError: dynamic module does not define module export function (PyInit__caffe)问题解决记录_chen_zn95的博客-CSDN博客

安装好后使用以下脚本便可将caffe模型的参数名和参数保存成dict, 

import pickle as pkl
import caffe


MODEL_FILE = 'xxx.prototxt'
PRETRAIN_FILE = 'xxx.caffemodel'


if __name__ == '__main__':
    net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)
    name_weights = {}
    for param_name in net.params.keys():
        name_weights[param_name] = {}
        layer_params = net.params[param_name]
        if len(layer_params) == 1:
            weight = layer_params[0].data
            name_weights[param_name]['weight'] = weight
            print('%s:\n\t%s (weight)' % (param_name, weight.shape))
        elif len(layer_params) == 2:
            # weight
            weight = layer_params[0].data
            name_weights[param_name]['weight'] = weight
            # bias
            bias = layer_params[1].data
            name_weights[param_name]['bias'] = bias
            print('%s:\n\t%s (weight)' % (param_name, weight.shape))
            print('\t%s (bias)' % str(bias.shape))
        elif len(layer_params) == 3:
            # BN: running_mean, running_var, scale factor
            running_mean = layer_params[0].data  # running_mean
            name_weights[param_name]['running_mean'] = running_mean / layer_params[2].data
            running_var = layer_params[1].data  # running_var
            name_weights[param_name]['running_var'] = running_var/layer_params[2].data
            print('%s:\n\t%s (running_var)' % (param_name, running_var.shape),)
            print('\t%s (running_mean)' % str(running_mean.shape))
        else:
            raise RuntimeError("error\n")
    
    # save weight
    with open('weights.pkl', 'wb') as f:
        pkl.dump(name_weights, f, protocol=2)

二、pytorch模型加载dict格式的权重

这里有两个思路,一是根据权重名来匹配,二是根据权重的shape来匹配,但第二个方法有个问题,就是如果网络中有两个以上shape一样的层的话,那么根据权重的shape来匹配就会出错。下面分别介绍一下以上两个思路,

1、根据权重名匹配

这个方法比较繁琐,要求pytorch模型的参数名要与caffe模型的保持一致,如果不一致,则需要自己写个dict进行映射。具体操作如下,

import pickle as pkl
import torch
import copy


model = xxx
model1 = copy.deepcopy(model)

state_dict = {}
with open("weights.pkl", "rb") as wp:  # weights.pkl: 步骤一中生成的dict
    name_weights = pkl.load(wp, encoding='iso-8859-1')
    for key, value in name_weights.items():
        for k, v in value.items():
            state_dict[key + "." + k] = torch.from_numpy(v)
model1.load_state_dict(state_dict, strict=True)

另一种实现是直接对pytorch模型的参数赋值,代码如下,

import pickle as pkl
import torch
import copy


model = xxx
model2 = copy.deepcopy(model)

with open("weights.pkl", "rb") as wp:
    name_weights = pkl.load(wp, encoding='iso-8859-1')
    for name, param in model2.named_parameters():
        for key, value in name_weights.items():
            if name.split(".")[0] == key:
                for k, v in value.items():
                    if name.split(".")[1] == k:
                        param.data = torch.from_numpy(v)

2、根据权重shape匹配

import pickle as pkl
import torch
import copy


model = LightCNN_ir_eye()
model3 = copy.deepcopy(model)

with open("weights.pkl", "rb") as wp:
    name_weights = pkl.load(wp, encoding='iso-8859-1')
    for name, param in model3.named_parameters():
        for key, value in name_weights.items():
            for k, v in value.items():
                v = torch.from_numpy(v)
                if param.data.shape == v.shape:
                    if name == key + "." + k:  # 防止多个权重shape一致导致的错误
                        param.data = v

3、检查以上模型初始化方法是否正确

import cv2
import numpy as np
import torch


img = cv2.imread("xxx.jpg")
img = cv2.resize(img, (width, height))
img = np.tranpose(img, (2,0,1))
img = np.expand_dims(img, axis=0)

out1 = model1(torch.from_numpy(img).float())
out2 = model2(torch.from_numpy(img).float())
out3 = model3(torch.from_numpy(img).float())

print(out1)
print(out2)
print(out3)
for i in range(len(out1)):
    print(out1[i] == out2[i])
    print(out1[i] == out3[i])

猜你喜欢

转载自blog.csdn.net/qq_38964360/article/details/132118426