PyTorch (Chapter 5)

读取数据集::
# -*- coding: utf-8 -*-
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

class DogCat(data.Dataset):
    """Dogs-vs-cats image dataset read lazily from a single directory.

    Construction only records file paths; each image is opened and decoded
    when it is indexed. Labels come from the file's base name: ``1`` if it
    contains ``'dog'``, otherwise ``0`` (cat).
    """

    def __init__(self, root):
        """Record the full path of every entry in *root*.

        Args:
            root: Directory whose entries are the image files.
        """
        # os.listdir() returns entries in arbitrary order; sort so that a
        # given index always refers to the same file across runs.
        # Only paths are stored here -- pixel data is read in __getitem__.
        self.imgs = [
            os.path.join(root, name) for name in sorted(os.listdir(root))
        ]

    @staticmethod
    def _label_for(img_path):
        """Return 1 for a dog image and 0 for a cat image, from the file name."""
        # Inspect the base name only, so a directory name containing 'dog'
        # (or a platform-specific path separator) cannot affect the label.
        return 1 if 'dog' in os.path.basename(img_path) else 0

    def __getitem__(self, index):
        """Return ``(tensor, label)`` for the image at *index*.

        Args:
            index: Position in ``self.imgs``.

        Returns:
            A tuple of the image as a ``torch.Tensor`` built from the NumPy
            array of the decoded PIL image, and the int label (dog=1, cat=0).

        Raises:
            IndexError: if *index* is out of range (from the list lookup).
        """
        img_path = self.imgs[index]
        label = self._label_for(img_path)
        # Close the file handle deterministically instead of leaving it to GC.
        with Image.open(img_path) as pil_img:
            # np.array makes a writable copy; np.asarray on a PIL image may be
            # read-only, which torch.from_numpy warns about.
            array = np.array(pil_img)
        # Named `tensor` (not `data`) to avoid shadowing torch.utils.data.
        tensor = t.from_numpy(array)
        return tensor, label

    def __len__(self):
        """Return the number of entries found in *root*."""
        return len(self.imgs)

dataset = DogCat('data/train')

# Subscripting is shorthand for calling dataset.__getitem__(0).
img, label = dataset[0]

# DogCat defines no __iter__, so iteration presumably falls back to calling
# __getitem__ with 0, 1, 2, ... until an IndexError ends the loop.
for img, label in dataset:
    print(img.size(), img.float().mean(), label)

转载自 www.cnblogs.com/shuimuqingyang/p/10309024.html