Pytorch模型转Caffe

1. 支持的转换算子

github上实现的PytorchToCaffe的代码,支持转换的算子如下(参见:pytorch_to_caffe.py):

F.conv2d=Rp(F.conv2d,_conv2d)
F.linear=Rp(F.linear,_linear)
F.relu=Rp(F.relu,_relu)
F.leaky_relu=Rp(F.leaky_relu,_leaky_relu)
F.max_pool2d=Rp(F.max_pool2d,_max_pool2d)
F.avg_pool2d=Rp(F.avg_pool2d,_avg_pool2d)
F.adaptive_avg_pool2d = Rp(F.adaptive_avg_pool2d,_adaptive_avg_pool2d)
F.dropout=Rp(F.dropout,_dropout)
F.threshold=Rp(F.threshold,_threshold)
F.prelu=Rp(F.prelu,_prelu)
F.batch_norm=Rp(F.batch_norm,_batch_norm)
F.instance_norm=Rp(F.instance_norm,_instance_norm)
F.softmax=Rp(F.softmax,_softmax)
F.conv_transpose2d=Rp(F.conv_transpose2d,_conv_transpose2d)
F.interpolate = Rp(F.interpolate,_interpolate)
F.sigmoid = Rp(F.sigmoid,_sigmoid)
F.tanh = Rp(F.tanh,_tanh)
F.tanh = Rp(F.tanh,_tanh)
F.hardtanh = Rp(F.hardtanh,_hardtanh)
# F.l2norm = Rp(F.l2norm,_l2Norm)

torch.split=Rp(torch.split,_split)
torch.max=Rp(torch.max,_max)
torch.cat=Rp(torch.cat,_cat)
torch.div=Rp(torch.div,_div)
  • 作者重写了caffe的算子,来替换orch.nn算子。其中RP表示替换的意思(Replace)
  • 主要支持转Caffe的算子包括:F.conv2d,F.linear,F.relu,F.leaky_relu,F.max_pool2d,F.avg_pool2d,F.adaptive_avg_pool2d,F.dropout,F.threshold,F.prelu,F.batch_norm,F.instance_norm,F.softmax,F.conv_transpose2d,F.interpolate
  • F.upsampleF.interpolate算子不支持,经过测试上采样操作建议使用F.conv_transpose2d转置卷积替换。其中F.interpolate算子在转换caffe模型时,容易提示upsample_h参数不存在的错误(虽然作者代码中显示支持F.interpolate)。

2. pytoch转Caffe

  • (1) : github上下载PytorchToCaffe的脚本。
    在这里插入图片描述
  • (2): 将Caffe文件夹和pytorch_to_caffe.py文件放到项目根目录
  • (3): 对项目中不支持转caffe的算子,如upsampleF.interpolate,使用F.conv_transpose2d替换。
  • (4): 替换后重新训练pytorch模型,获得训练好的model.pt文件
  • (5): 在项目跟目录上创建convertCaffe.py,利用训练好的.pt文件,转caffe的.prototxt.caffemodel模型文件。convertCaffe.py的代码实现如下:
import sys
sys.path.insert(0,'.')
import torch
from torch.autograd import Variable
from torchvision.models import resnet
import pytorch_to_caffe
from nets.deeplabv3_plus import DeepLab

if __name__=='__main__':
    name = 'deeplab'
    model = DeepLab(8, backbone="mobilenet", downsample_factor=16, pretrained=False)
    #model.load_state_dict(torch.load('logs/best_epoch_weights.pth', map_location='cpu'))
    checkpoint = torch.load("logs/best_epoch_weights.pth")
    model.load_state_dict(checkpoint,False)
    model.eval()
    input=torch.ones([1,3,224,224])
     #input=torch.ones([1,3,224,224])
    pytorch_to_caffe.trans_net(model,input,name)
    pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))
    pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))

在这里插入图片描述
转换成功会提示Transform Completed

猜你喜欢

转载自blog.csdn.net/weixin_38346042/article/details/129692515