龙良曲pytorch学习笔记_加载宝可梦数据集

import csv
import glob
import os
import random

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

 10 class Pokemon(Dataset):
 11     '''
 12         @param
 13         root:存储的根路径
 14         resize:将图片大小根据网络结构适配
 15         mode:train或者test模式
 16     '''
 17     def __init__(self,root,resize,mode):
 18         super(Pokemon,self).__init__()
 19         
 20         self.root = root
 21         self.resize = resize
 22         
 23         # 字典类型key:name value:label
 24         self.name2label = {}
 25         # listdir返回顺序不固定,用sorted将它固定,因为排序一次之后就固定了
 26         for name in sorted(os.listdir(os.path.join(root))):
 27             if not os.path.isdir(os.path.join(root,name)):
 28                 continue
 29                 
 30             self.name2label[name] = len(self.name2label.keys())
 31         
 32         # print(self.name2label)
 33         
 34         # image_path + image_label
 35         self.images,self.labels = self.load_csv('images.csv')
 36         
 37         if mode == 'train': # 60%
 38             self.images = self.images[:int(0.6*len(self.images))]
 39             self.labels = self.labels[:int(0.6*len(self.labels))]
 40         elif mode == 'val': # 20%
 41             self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
 42             self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
 43         elif mode == 'test': # 20% = 80% ->100%
 44             self.images = self.images[int(0.8*len(self.images)):]
 45             self.labels = self.labels[int(0.8*len(self.labels)):]
 46             
 47     def load_csv(self,filename):
 48         
 49         # 如果不存在再写入,存在的话直接读取就可以了
 50         if not os.path.exists(os.path.join(self.root,filename))
 51             images = []
 52             for name in self.name2label.keys():
 53                 # 'pokemon'\\mewtwo\\00001.png
 54                 images += glob.glob(os.path.join(self.root,name,'*.png'))
 55                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
 56                 images += glob.glob(os.path.join(self.root,name,'*.jpeg'))
 57                 
 58             # 1167,'pokemon\\bulbasaur\\00000000.png'
 59             print(len(images),images)
 60             
 61             random.shuffle(images)
 62             with open(os.path.join(self.root,filename),mode = 'w',newline='') as f:
 63                 writer = csv.writer(f)
 64                 for img in images: # 'pokemon\\bulbasaur\\00000000.png'
 65                     name = img.split(os.sep)[-2]
 66                     label = self.name2label[name]
 67                     # 'pokemon\\bulbasaur\\00000000.png',0
 68                     writer.writerow([img,label])
 69                 print('writen into csv file:',filename)
 70             
 71         # read from csv file
 72         images,labels = [],[]
 73         with open(os.path.join(self.root,filename))
 74             reader = csv.reader(f)
 75             for row in reader:
 76                 # 'pokemon\\bulbasaur\\00000000.png',0
 77                 img,label = row
 78                 label = int(label)
 79                 
 80                 images.append(img)
 81                 labels.append(label)
 82                 
 83         # 保证images和labels一一对应,长度相等
 84         assert len(images) == len(labels)
 85         return images,labels
 86             
 87     def __len__(self):
 88         
 89         return len(self.images)
 90         
 91     def denormalize(self,x_hat):
 92     
 93         mean=[0.485,0.456,0.406]
 94         std=[0.229,0.224,0.225]
 95         
 96         # x_hat = (x-mean)/std
 97         # x = x_hat*std+mean
 98         # x: [c,h,w]
 99         # mean: [3] --> [3,1,1]
100         mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
101         std  = torch.tensor(std).unsqueeze(1).unsqueeze(1)
102         
103         x = x_hat*std + mean
104         
105         return x
106         
107     
108     def __getitem__(self,idx):
109         # idx~[0~len(images)]
110         # self.images,self.labels
111         # img: pokemon\\bulbasaur\\00000000.png'
112         # label: 0
113         img,label = self.images[idx],self.labels[idx]
114         
115         tf = transforms.Compose([
116             lambda x:Image.open(x).convert('RGB'), # string path --> image data
117             transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
118             transforms.RandomRotation(15),
119             transforms.CenterCrop(self.resize),
120             transforms.ToTensor(),
121             transforms.Normalize(mean=[0.485,0.456,0.406],
122                                  std=[0.229,0.224,0.225])
123         ])
124         
125         img = tf(img)
126         label = torch.tensor(label)
127         
128         return img,label

猜你喜欢

转载自www.cnblogs.com/fxw-learning/p/12331522.html
今日推荐