谷歌colab之使用pytorch读取自己数据集(猫狗数据集)

之前在:https://www.cnblogs.com/xiximayou/p/12398285.html创建好了数据集,将它上传到谷歌colab

在colab上的目录如下:

在utils中的rdata.py定义了读取该数据集的代码:

from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torch
#预处理
transform = transforms.Compose([transforms.ToTensor()])
path = "/content/drive/My Drive/colab notebooks/data/dogcat"
train_path=path+"/train"
test_path=path+"/test"
#使用torchvision.datasets.ImageFolder读取数据集指定train和test文件夹
train_data = torchvision.datasets.ImageFolder(train_path, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=1)
 
test_data = torchvision.datasets.ImageFolder(test_path, transform=transform)
test_loader = DataLoader(test_data, batch_size=32, shuffle=True, num_workers=1)
print(train_data.classes)  #根据分的文件夹的名字来确定的类别
print(train_data.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(train_data.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

print(test_data.classes)  #根据分的文件夹的名字来确定的类别
print(test_data.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(test_data.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

ImageFolder可以读取我们的train或test下面的文件夹,并为每一个标签进行编码,同时将图片与标签进行对应。

在test.ipynb中运行rdata.py

说明我们创建的数据集是可以用的了。

有了数据集,接下来就是网络的搭建以及训练和测试了。 

猜你喜欢

转载自www.cnblogs.com/xiximayou/p/12422827.html