数据来源:CelebA人脸数据集(官网链接)是香港中文大学的开放数据,包含10,177个名人身份的202,599张人脸图片,并且都做好了特征标记,这对人脸相关的训练是非常好用的数据集。共计40个特征,具体是哪些特征,可以去官网查询。话不多说,直接开始流程。
整个流程可以分为大致以下几个步骤:
1.图片预处理
2.构建网络
3.训练
4.测试
5.优化
一、图片加载:因为源数据没有经过处理,我们要重写torch.utils.data.DataLoader()处理图片,然后才能将图片用于加载。代码如下:
def default_loader(path):
    """Open the image at `path` and convert it to RGB; print and return None on failure."""
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except (IOError, OSError):
        # Narrowed from a bare `except`, which would also have swallowed
        # KeyboardInterrupt and programming errors.
        print("Can not open {0}".format(path))


class myDataset(Data.Dataset):
    """Dataset of (image, 40-attribute label vector) pairs from a CelebA list file.

    Each valid line of `img_txt` holds one file name followed by 40 flags in
    {-1, 1}; flags are remapped to class indices {0, 1}.

    Fixes vs. the original:
    - subclasses Dataset, not DataLoader (a DataLoader wraps a Dataset and has
      an incompatible constructor);
    - `img_txt` is a required argument (the original default `img_txt=img_txt`
      referenced an undefined global at definition time);
    - the list file is closed via a `with` block.
    """

    def __init__(self, img_dir, img_txt, transform=None, loader=default_loader):
        img_list = []
        img_labels = []
        with open(img_txt, 'r') as fp:
            for line in fp.readlines():
                parts = line.split()
                if len(parts) != 41:  # 1 file name + 40 attribute flags
                    continue
                img_list.append(parts[0])
                # Map CelebA's -1/1 attribute flags to class indices 0/1.
                img_label_single = []
                for value in parts[1:]:
                    if value == '-1':
                        img_label_single.append(0)
                    if value == '1':
                        img_label_single.append(1)
                img_labels.append(img_label_single)
        self.imgs = [os.path.join(img_dir, file) for file in img_list]
        self.labels = img_labels
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
        img = self.loader(img_path)
        if self.transform is not None:
            try:
                img = self.transform(img)
            except Exception:
                # Best-effort: report the bad image but still return it untransformed.
                print('Cannot transform image: {}'.format(img_path))
        return img, label
图片增强、归一化处理和加载:
# Preprocessing / augmentation pipeline: resize the short side to 40 px,
# take a centered 32x32 crop, randomly mirror, convert to a tensor and
# normalize each channel to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Resize(40),
    transforms.CenterCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])
# Training set
# (Fix: in the article the comment and the code were fused onto one line,
# which commented out the training-set construction entirely.)
train_dataset = myDataset(img_dir=img_root, img_txt=train_txt, transform=transform)
train_dataloader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Test set
test_dataset = myDataset(img_dir=img_root, img_txt=test_txt, transform=transform)
test_dataloader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
构建网络:我使用的网络结构是每种属性使用3层卷积加上3层fc。网络结构比较简单,导致准确率不会有太高的表现,如果有兴趣可以做下优化,文末有优化的思路供大家讨论。好了,先上代码:
def make_conv():
    """3 conv layers: (3, 32, 32) input -> (64, 4, 4) feature map."""
    return nn.Sequential(
        nn.Conv2d(3, 16, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),  # 32 -> 16
        nn.Conv2d(16, 32, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),  # 16 -> 8
        nn.Conv2d(32, 64, 3, 1, 1),
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.MaxPool2d(2),  # 8 -> 4
    )


def make_fc():
    """3 fully connected layers: 64*4*4 features -> 2 logits (attribute off/on)."""
    return nn.Sequential(
        nn.Linear(64 * 4 * 4, 128),
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.Linear(128, 64),
        nn.ReLU(),
        # Dropout before the output layer limits overfitting to some extent;
        # other positions are worth experimenting with.
        nn.Dropout(0.5),
        nn.Linear(64, 2),
    )


class face_attr(nn.Module):
    """One independent conv+fc branch per attribute.

    forward(x) returns a list of `num_attrs` tensors of shape (batch, 2) —
    the same contract as the original hand-written attr0..attr39 version,
    but with the 80 copy-pasted attribute fields replaced by two ModuleLists.
    `num_attrs` defaults to 40 (CelebA), so existing `face_attr()` calls
    are unaffected.
    """

    def __init__(self, num_attrs=40):
        super(face_attr, self).__init__()
        # ModuleList registers every branch's parameters with this module,
        # so the optimizer sees all of them via module.parameters().
        self.convs = nn.ModuleList([make_conv() for _ in range(num_attrs)])
        self.fcs = nn.ModuleList([make_fc() for _ in range(num_attrs)])

    def forward(self, x):
        out_list = []
        for conv, fc in zip(self.convs, self.fcs):
            feat = conv(x)
            feat = feat.view(feat.size(0), -1)  # flatten to (batch, 64*4*4)
            out_list.append(fc(feat))
        return out_list
接下来就可以开始训练网络了,定义优化器的时候可以设置一下weight_decay=1e-8,也可以在一定程度上防止过拟合。
module = face_attr()
# print(module)
# Small weight_decay adds L2 regularization, which also helps against overfitting.
optimizer = optim.Adam(module.parameters(), lr=0.001, weight_decay=1e-8)
# CrossEntropyLoss is stateless, so a single instance serves all 40 heads
# (the original built a list of 40 identical instances).
loss_func = nn.CrossEntropyLoss()

for Epoch in range(50):
    all_correct_num = 0
    for ii, (img, label) in enumerate(train_dataloader):
        # `Variable` wrapping removed: tensors are autograd-ready since PyTorch 0.4.
        output = module(img)  # list of 40 (batch, 2) logit tensors
        optimizer.zero_grad()
        total_loss = 0
        for i in range(40):
            # Sum the per-attribute losses and backpropagate once per batch
            # instead of calling backward() 40 times.
            total_loss = total_loss + loss_func(output[i], label[:, i])
            _, predict = torch.max(output[i], 1)
            # .item() replaces the removed Tensor.data[0] accessor.
            all_correct_num += (predict == label[:, i]).sum().item()
        total_loss.backward()
        optimizer.step()
    Accuracy = all_correct_num * 1.0 / (len(train_dataset) * 40.0)
    print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch, all_correct_num, Accuracy))
    # Save a checkpoint after every epoch.
    torch.save(module, 'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
测试网络:和训练类似,只是不用优化和做反向传播。
# Load the model checkpoint saved during training.
module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
module.eval()  # eval mode: disables Dropout
all_correct_num = 0
# no_grad: inference only — skip building autograd graphs (faster, less memory;
# the original accumulated graphs it never used).
with torch.no_grad():
    for ii, (img, label) in enumerate(test_dataloader):
        output = module(img)
        for i in range(40):
            _, predict = torch.max(output[i], 1)
            # .item() replaces the removed Tensor.data[0] accessor.
            all_correct_num += (predict == label[:, i]).sum().item()
Accuracy = all_correct_num * 1.0 / (len(test_dataset) * 40.0)
print('all_correct_num={0},Accuracy={1}'.format(all_correct_num, Accuracy))
总结:我因为是笔记本电脑,没有GPU,所以只给了大约5000个数据用于训练(即使是5000个数据我的电脑也跑了2天才跑完50个epoch),1000个数据用于测试。测试的准确率在90%左右。有条件的同学可以做些优化,下面提供一些可以优化的方面:
1.图片增强,我因为电脑不给力,无法处理较大的数据,所以将原始图片缩放到40*40,然后截取了32*32作为输入,如果有GPU等条件,可以考虑128*128输入
2.最开始没加Dropout()的时候出现了过拟合的情况,当然这也和训练集较小有关系,建议训练集给到60000个样本。上述提到的Dropout()多尝试几个位置,我是放在了最后一层输出之前。期待大家尝试之后分享一下结果。
3.我现在每种属性都是使用相同的网络,但是这种网络有可能不是对于每种属性都是最优的选择,可以针对每一种属性单独设计网络。例如:某个属性使用全fc层就可以达到很高的准确率,某个属性或许需要4层卷积+2层fc才能达到很好的效果,这种情况只有靠多尝试,算法可能并不能给出哪种才是最适合的网络模型。建议输出每种属性的准确率,然后针对准确率较低的属性做相应的网络优化。
4.多准备几块GPU做训练和测试吧,没GPU真不给力。希望有小伙伴在这个网络的基础上能达到更高的准确率。
附上整体代码如下:
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 17 11:54:36 2018

CelebA 40-attribute face classifier: one independent 3-conv + 3-fc branch
per attribute; each branch is a 2-class (attribute absent/present) head.

@author: sky-hole
"""
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import torchvision.transforms as transforms
from PIL import Image

img_root = 'W:/pic_data/face/CelebA/Img/img_align_celeba'
train_txt = 'W:/pic_data/face/CelebA/Img/train10000.txt'
batch_size = 2
NUM_ATTRS = 40  # CelebA labels 40 binary attributes per image


def default_loader(path):
    """Open the image at `path` and convert it to RGB; print and return None on failure."""
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except (IOError, OSError):
        # Narrowed from a bare `except`, which would also swallow programming errors.
        print("Can not open {0}".format(path))


class myDataset(Data.Dataset):
    """Dataset of (image, 40-attribute label vector) pairs from a CelebA list file.

    Each valid line of `img_txt` holds one file name followed by 40 flags in
    {-1, 1}; flags are remapped to class indices {0, 1}.
    Subclasses Dataset — the original subclassed DataLoader, which only
    happened to work because DataLoader.__init__ was never called.
    """

    def __init__(self, img_dir, img_txt, transform=None, loader=default_loader):
        img_list = []
        img_labels = []
        with open(img_txt, 'r') as fp:  # 'with' guarantees the handle is closed
            for line in fp.readlines():
                parts = line.split()
                if len(parts) != 41:  # 1 file name + 40 attribute flags
                    continue
                img_list.append(parts[0])
                img_label_single = []
                for value in parts[1:]:
                    if value == '-1':
                        img_label_single.append(0)
                    if value == '1':
                        img_label_single.append(1)
                img_labels.append(img_label_single)
        self.imgs = [os.path.join(img_dir, file) for file in img_list]
        self.labels = img_labels
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
        img = self.loader(img_path)
        if self.transform is not None:
            try:
                img = self.transform(img)
            except Exception:
                print('Cannot transform image: {}'.format(img_path))
        return img, label


# Resize short side to 40, center-crop 32x32, random mirror, normalize each
# channel to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Resize(40),
    transforms.CenterCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

train_dataset = myDataset(img_dir=img_root, img_txt=train_txt, transform=transform)
train_dataloader = Data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# print(len(train_dataset))
# print(len(train_dataloader))


def make_conv():
    """3 conv layers: (3, 32, 32) input -> (64, 4, 4) feature map."""
    return nn.Sequential(
        nn.Conv2d(3, 16, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),  # 32 -> 16
        nn.Conv2d(16, 32, 3, 1, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),  # 16 -> 8
        nn.Conv2d(32, 64, 3, 1, 1),
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.MaxPool2d(2),  # 8 -> 4
    )


def make_fc():
    """3 fully connected layers: 64*4*4 features -> 2 logits per attribute."""
    return nn.Sequential(
        nn.Linear(64 * 4 * 4, 128),
        nn.ReLU(),
        # nn.Dropout(0.5),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Dropout(0.5),  # dropout before the output layer limits overfitting
        nn.Linear(64, 2),
    )


class face_attr(nn.Module):
    """One conv+fc branch per attribute; forward returns a list of (batch, 2) logits.

    Replaces the original 80 hand-written attrN_layer1/attrN_layer2 fields
    and the 40 copy-pasted forward stanzas with two ModuleLists — identical
    structure and contract, far less code.
    """

    def __init__(self, num_attrs=NUM_ATTRS):
        super(face_attr, self).__init__()
        # ModuleList registers every branch's parameters with this module.
        self.convs = nn.ModuleList([make_conv() for _ in range(num_attrs)])
        self.fcs = nn.ModuleList([make_fc() for _ in range(num_attrs)])

    def forward(self, x):
        out_list = []
        for conv, fc in zip(self.convs, self.fcs):
            feat = conv(x)
            feat = feat.view(feat.size(0), -1)  # flatten to (batch, 64*4*4)
            out_list.append(fc(feat))
        return out_list


module = face_attr()
# print(module)
# Small weight_decay adds L2 regularization against overfitting.
optimizer = optim.Adam(module.parameters(), lr=0.001, weight_decay=1e-8)
# CrossEntropyLoss is stateless: one instance serves all 40 heads.
loss_func = nn.CrossEntropyLoss()

for Epoch in range(50):
    all_correct_num = 0
    for ii, (img, label) in enumerate(train_dataloader):
        # `Variable` wrapping removed: tensors are autograd-ready since PyTorch 0.4.
        output = module(img)  # list of 40 (batch, 2) logit tensors
        optimizer.zero_grad()
        total_loss = 0
        for i in range(NUM_ATTRS):
            # Sum the per-attribute losses and backpropagate once per batch.
            total_loss = total_loss + loss_func(output[i], label[:, i])
            _, predict = torch.max(output[i], 1)
            # .item() replaces the removed Tensor.data[0] accessor.
            all_correct_num += (predict == label[:, i]).sum().item()
        total_loss.backward()
        optimizer.step()
    Accuracy = all_correct_num * 1.0 / (len(train_dataset) * 40.0)
    print('Epoch ={0},all_correct_num={1},Accuracy={2}'.format(Epoch, all_correct_num, Accuracy))
    # Save a checkpoint after every epoch.
    torch.save(module, 'W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')

# --- Evaluation (disabled in the original via a triple-quoted string) ---
# test_txt = 'W:/pic_data/face/CelebA/Img/test1000.txt'
# test_dataset = myDataset(img_dir=img_root, img_txt=test_txt, transform=transform)
# test_dataloader = Data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
#
# module = torch.load('W:/pic_data/face/CelebA/Img/face_attr40dro1.pkl')
# module.eval()  # eval mode disables Dropout
# all_correct_num = 0
# with torch.no_grad():  # inference only: skip autograd bookkeeping
#     for ii, (img, label) in enumerate(test_dataloader):
#         output = module(img)
#         for i in range(NUM_ATTRS):
#             _, predict = torch.max(output[i], 1)
#             all_correct_num += (predict == label[:, i]).sum().item()
# Accuracy = all_correct_num * 1.0 / (len(test_dataset) * 40.0)
# print('all_correct_num={0},Accuracy={1}'.format(all_correct_num, Accuracy))
欢迎分享和留言。
分享请注明出处。