一、数据集准备与导入方法。
1、建立文件夹方式。(原始图片数据首选)
http://www.bubuko.com/infodetail-2304938.html
https://blog.csdn.net/u014380165/article/details/79058479
2、建立TensorDataset方式(文件格式数据)
http://www.pytorchtutorial.com/3-5-data-loader/
<注>
1)、与方法1不同(方法1不用管标签的命名),使用方法2时需要先把标签标准化为从0开始的连续整数,否则在后面训练时会出现下面链接中描述的错误:https://blog.csdn.net/qq_27292549/article/details/82261040
def label_normal(train_labels):
    """Remap arbitrary label values to consecutive class indices starting at 0.

    Every distinct label is assigned the index of its first appearance, so
    labels such as (2, 7, ..., 1500) become (0, 1, 2, ..., 750).

    Args:
        train_labels: Iterable of hashable label values (a list or a 1-D
            array of integers, for example).

    Returns:
        A numpy array with one remapped index per input label, in input
        order.
    """
    # A dict keyed by label keeps the mapping correct even when equal labels
    # are not contiguous in the input; a single running counter would hand
    # out a new index every time the label changes and eventually run past
    # the list of distinct labels.
    label_to_index = {}
    for label in train_labels:
        # len() is evaluated before the insertion, so the first new label
        # gets 0, the next one 1, and so on.
        label_to_index.setdefault(label, len(label_to_index))
    print('train_labels_temp', list(label_to_index))

    new_train_labels = [label_to_index[label] for label in train_labels]
    print('new_y_train', new_train_labels)
    print('\n')
    print('y_train_trans:', new_train_labels)
    print('\n')

    return np.array(new_train_labels)
# Load the two .mat result files used below (result_softout is loaded here
# but not read again within this snippet).
result_feature = scipy.io.loadmat(
    './evaluation/features/market_ResNet/market_ResNet_result.mat')
result_softout = scipy.io.loadmat(
    './evaluation/features/view_Branch_softmaxout/Market_softmaxout.mat')

# Target labels for train and val. For each split: take row 0 of the stored
# label array, remap it with label_normal (the key step), wrap it in a
# LongTensor and reshape it to a single column so it fits
# torch.utils.data.TensorDataset.
train_labels = torch.LongTensor(
    label_normal(result_feature['train_label'][0])).view(-1, 1)
val_labels = torch.LongTensor(
    label_normal(result_feature['val_label'][0])).view(-1, 1)

# r_f_train / r_f_val and train_yaw_back / val_yaw_back are tensors defined
# outside this snippet.
image_datasets = {
    'train': torch.utils.data.TensorDataset(r_f_train, train_yaw_back, train_labels),
    'val': torch.utils.data.TensorDataset(r_f_val, val_yaw_back, val_labels),
}

# One shuffled loader per split, plus the number of samples in each split.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset, batch_size=opt.batchsize, shuffle=True, num_workers=16)
    for split, dataset in image_datasets.items()
}
dataset_sizes = {
    split: len(dataset) for split, dataset in image_datasets.items()
}
3、建立dataset类
扫描二维码关注公众号,回复:
3601665 查看本文章
参考:https://blog.csdn.net/lqp888888/article/details/80481456
import torch.utils.data as data
class MyDataset(data.Dataset):
    """Dataset that pairs each sample with the label at the same position."""

    def __init__(self, data, labels):
        """Store the indexable samples and their matching labels."""
        self.data, self.labels = data, labels

    def __getitem__(self, index):
        """Return the ``(sample, label)`` tuple stored at ``index``."""
        return self.data[index], self.labels[index]

    def __len__(self):
        """Return the number of samples, taken from the length of ``data``."""
        sample_count = len(self.data)
        return sample_count
二、预处理
1、使用torchvision中的transforms对自建数据进行预处理。
参考:https://pytorch.org/docs/0.4.0/_modules/torchvision/datasets/folder.html#DatasetFolder
import torch.utils.data as data
class MyDataset(data.Dataset):
    """Dataset of ``(sample, label)`` pairs with an optional sample transform.

    Args:
        data: Indexable collection of samples.
        labels: Indexable collection of labels, aligned with ``data``.
        transforms: Optional callable applied to each sample when it is
            fetched (e.g. a ``torchvision.transforms`` pipeline). When left
            as ``None`` the sample is returned unchanged.
    """

    def __init__(self, data, labels, transforms=None):
        self.data = data
        self.labels = labels
        self.transforms = transforms

    def __getitem__(self, index):
        """Return the (possibly transformed) sample and its label."""
        img, target = self.data[index], self.labels[index]
        # Key step: preprocess the sample. The guard keeps the default
        # transforms=None usable instead of failing on a None call.
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        """Return the number of samples, taken from the length of ``data``."""
        return len(self.data)