PyTorch包括一个名为torchvision的包,用于加载和准备数据集。它包括两个基本函数,即Dataset和DataLoader,用于数据集的转换和加载。
**Dataset(数据集)**
数据集用于从给定的数据集中读取和转换数据点。实现的基本语法如下所示:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
**DataLoader(数据加载器)**
DataLoader用于对数据进行随机排列和分批处理。它可以与多进程工作器一起并行加载数据。
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
示例:加载CSV文件
我们使用Python包Pandas来加载CSV文件。原始文件具有以下格式:(图像名称,68个标记点 - 每个标记点有x、y坐标)。
landmarks_frame = pd.read_csv('faces/face_landmarks.csv')
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].values
landmarks = landmarks.astype('float').reshape(-1, 2)