caffe分类学习及问题

1.分类实现过程

参考网站:https://blog.csdn.net/gaohuazhao/article/details/69568267

特别详细的一个教程……

2.训练自己的数据集遇到的问题

1)用 Python 生成 train.txt 和 val.txt 后,下一步用 create_lmdb.sh 生成 LMDB 训练文件时报错。

详细报错:io.cpp:80]Could not open or find file $caffe-master/examples/……*.jpg

错误原因:

①train.txt 和 val.txt 中图片名与标签之间只能用一个空格分隔,我之前用的是 4 个空格;

参考网站:https://blog.csdn.net/dcxhun3/article/details/51966921

②data/train 和 data/val 下面不能再有子文件夹(txt 中记录的图片路径会直接拼接在 create_lmdb.sh 所指定的图片根目录之后,拼接结果必须能找到对应文件),我之前在下面还建了 2 个子文件夹;

参考网站:https://groups.google.com/forum/#!topic/caffe-users/jO1s7g6Q84w

用 Python 生成 train.txt 和 val.txt 的参考代码(图片存放于同一个文件夹,按 训练:测试 = 8:2 随机划分):

# -*- coding: UTF-8 -*- 
import os 
import random
import shutil

def writeSpecificClassData(textFile, objList, objClass):
    """Write an image list in Caffe's "<path> <label>" text format.

    One line is emitted per sample: the image path, a single space, then
    ``1`` when the second ``'_'``-separated field of the class label equals
    ``'normal'`` and ``0`` otherwise. The sample count is printed at the end.

    Args:
        textFile: writable text file object; it is not closed here.
        objList: image paths, one per sample.
        objClass: class labels aligned with *objList* (in this script they
            are built as ``<className>_<subFolder>``).

    Raises:
        Exception: if *objList* and *objClass* differ in length.
        IndexError: if a label contains no ``'_'``.
    """
    if len(objClass) != len(objList):
        raise Exception('the length of objList isn\'t same as objClass, please inspect your code.')
    for imagePath, classLabel in zip(objList, objClass):
        # The path is written first; the label is derived afterwards.
        textFile.write(imagePath + ' ')
        binaryLabel = 1 if classLabel.split('_')[1] == 'normal' else 0
        textFile.write('%d\n' % binaryLabel)
    print(len(objList))

def genSpecificClassData(outputPath, sourceRoot='$sliceDefect/',
                         copyDir='$/caffe-master/examples/damper/data/sliceDefect/',
                         valRatio=0.2):
    """Collect 'damper' images, copy them for Caffe, and write train/val lists.

    Walks *sourceRoot*, treating each directory as ``.../<rootName>/<className>/<subFolder>``.
    Every ``.jpg``/``.JPG`` file in a directory whose *className* contains
    ``'damper'`` is:

    - copied into *copyDir* with a lower-case ``.jpg`` extension;
    - recorded with the path ``<rootName>/<stem>.jpg``;
    - given the class label ``<className>_<subFolder>``.

    The samples are then split at random into disjoint training and validation
    lists. Those lists are written to ``train.txt`` and ``val.txt`` in
    *outputPath* via ``writeSpecificClassData``.

    Args:
        outputPath: directory receiving ``train.txt`` and ``val.txt``;
            created if it does not exist.
        sourceRoot: directory tree to scan (the default is the tutorial's
            placeholder path; point it at your data).
        copyDir: directory the selected images are copied into; created on
            the first match.
        valRatio: fraction of samples placed in the validation list
            (default 0.2, i.e. train:val = 8:2).

    Raises:
        ValueError: if *valRatio* is outside [0, 1].
    """
    if not 0 <= valRatio <= 1:
        raise ValueError('valRatio must be between 0 and 1, got %r' % (valRatio,))
    if not os.path.exists(outputPath):
        os.makedirs(outputPath)

    objList = []
    objClass = []
    for root, dirs, files in os.walk(sourceRoot):
        # Sorting in place fixes the traversal order, so the collected list
        # does not depend on the file system's directory ordering.
        dirs.sort()
        # Normalise separators so the parsing below also works on Windows.
        parts = root.replace(os.sep, '/').split('/')
        if len(parts) < 3:
            # Too shallow to contain <rootName>/<className>/<subFolder>.
            continue
        rootName, className, subFolder = parts[-3:]
        if 'damper' not in className:  # damper is my specific class
            continue
        for f in sorted(files):
            stem, ext = os.path.splitext(f)
            if ext.lower() != '.jpg':
                continue
            if not os.path.exists(copyDir):
                os.makedirs(copyDir)
            # NOTE(review): files with the same name in different class folders
            # overwrite each other in copyDir — confirm names are unique.
            shutil.copy(os.path.join(root, f), os.path.join(copyDir, stem + '.jpg'))
            # NOTE(review): the listed path uses rootName, which only matches
            # the copy location when rootName is the last folder of copyDir.
            objList.append(rootName + '/' + stem + '.jpg')
            objClass.append(className + '_' + subFolder)

    # Pick round(valRatio * N) distinct samples at random for validation; the
    # rest go to training. The lists are disjoint and keep the walk order.
    # (The previous scheme drew val samples from neighbours that were also in
    # the training list, wrapped negative indices, and gave ~22% val.)
    valCount = int(round(len(objList) * valRatio))
    valIndices = set(random.sample(range(len(objList)), valCount))
    trainList, trainLabel, valList, valLabel = [], [], [], []
    for index, (path, label) in enumerate(zip(objList, objClass)):
        if index in valIndices:
            valList.append(path)
            valLabel.append(label)
        else:
            trainList.append(path)
            trainLabel.append(label)

    print(len(objList), len(trainList), len(valList))
    # Context managers guarantee the lists are flushed and the files closed.
    with open(os.path.join(outputPath, 'train.txt'), 'w') as trainFile:
        writeSpecificClassData(trainFile, trainList, trainLabel)
    with open(os.path.join(outputPath, 'val.txt'), 'w') as valFile:
        writeSpecificClassData(valFile, valList, valLabel)

if __name__ == '__main__':
    # Script entry point: write train.txt / val.txt under ./outputPath/
    # (relative to the current working directory).
    outputPath = './outputPath/'
    genSpecificClassData(outputPath=outputPath)

猜你喜欢

转载自blog.csdn.net/lantuxin/article/details/80012111