caffe训练(6)生成solver.prototxt文件

使用python生成solver.prototxt文件

solver.prototxt文件中各个参数的具体含义可见博客:caffe总结(十)solver.prototxt参数含义
以分析的cifar10_quick_solver.prototxt文件为例,使用python程序,生成这个文件。

1.代码如下:

# -*- coding: UTF-8 -*-
import caffe                                                     #导入caffe包

def write_sovler():
    my_project_root = "D:/caffe-master/zzhld/"        #my-caffe-project目录
    sovler_string = caffe.proto.caffe_pb2.SolverParameter()                    #sovler存储
    solver_file = my_project_root + 'solver.prototxt'                        #sovler文件保存位置
    sovler_string.train_net = my_project_root + 'train.prototxt'            #train.prototxt位置指定
    sovler_string.test_net.append(my_project_root + 'test.prototxt')         #test.prototxt位置指定
    sovler_string.test_iter.append(100)                                        #测试迭代次数
    sovler_string.test_interval = 500                                        #每训练迭代test_interval次进行一次测试
    sovler_string.base_lr = 0.001                                            #基础学习率   
    sovler_string.momentum = 0.9                                            #动量
    sovler_string.weight_decay = 0.004                                        #权重衰减
    sovler_string.lr_policy = 'fixed'                                        #学习策略           
    sovler_string.display = 100                                                #每迭代display次显示结果
    sovler_string.max_iter = 4000                                            #最大迭代数
    sovler_string.snapshot = 4000                                             #保存临时模型的迭代数
    sovler_string.snapshot_format = 0                                        #临时模型的保存格式,0代表HDF5,1代表BINARYPROTO
    sovler_string.snapshot_prefix = 'examples/cifar10/cifar10_quick'        #模型前缀
    sovler_string.solver_mode = caffe.proto.caffe_pb2.SolverParameter.GPU    #优化模式

    with open(solver_file, 'w') as f:
        f.write(str(sovler_string))   

if __name__ == '__main__':
    write_sovler()
  • 特别注意的是:

    1. 上面代码首先需要更改路径,其余根据需要更改,不进行更改也可以运行出结果;

    2. 在编写路径时,我试验了几次必须要求斜杠是“/”,另外那个如果在windows中直接复制的话,会有问题。

2.运行结果:

在这里插入图片描述

训练模型

从第一篇笔记至此,我们已经了解到如何将jpg图片转换成Caffe使用的db(levelbd/lmdb)文件,如何计算数据均值,如何使用python生成solver.prototxt、train.prototxt、test.prototxt文件。接下来,就可以进行训练的最后一步,使用caffe提供的python接口训练生成模型。如果不进行可视化,只想得到一个最终的训练model,可以使用如下代码:

import caffe

my_project_root = "/home/Jack-Cui/caffe-master/my-caffe-project/"        #my-caffe-project目录
solver_file = my_project_root + 'solver.prototxt'                        #sovler文件保存位置
caffe.set_device(0)                                                      #选择GPU-0
caffe.set_mode_gpu()
solver = caffe.SGDSolver(solver_file)
solver.solve()

现在,如何训练生成模型的简单步骤已经讲完。接下来,以mnist实例,整合所学内容,训练生成model,并使用生成的model进行预测。


原文链接:https://blog.csdn.net/c406495762/article/details/70306728

发布了61 篇原创文章 · 获赞 15 · 访问量 963

猜你喜欢

转载自blog.csdn.net/weixin_42535423/article/details/103811695
今日推荐