[ PyTorch ] Basic Usage (1): Dataset Preparation and Loading + Image Preprocessing

I. Ways to prepare and load a dataset

1. Folder-based approach (preferred for raw image data)

http://www.bubuko.com/infodetail-2304938.html 

https://blog.csdn.net/u014380165/article/details/79058479

2. TensorDataset approach (for data already stored in files, e.g. .mat feature files)

http://www.pytorchtutorial.com/3-5-data-loader/

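<Note>

    1) Unlike method 1, where labels are derived from the folder names and need no handling, method 2 requires normalizing the labels to contiguous integers 0, 1, ..., C-1 (loss functions such as nn.CrossEntropyLoss expect class indices in the range [0, C)). Otherwise an error like the one described here occurs later during training: https://blog.csdn.net/qq_27292549/article/details/82261040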
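A compact alternative for this remapping, sketched with NumPy's `np.unique(..., return_inverse=True)` rather than the author's loop; the helper name `label_normal_unique` is hypothetical. Unlike a first-appearance mapping, it numbers classes in sorted order of the original label values.

```python
import numpy as np

def label_normal_unique(labels):
    # np.unique returns the sorted distinct values; return_inverse=True also gives,
    # for every element, its position in that sorted array, i.e. a label in 0..C-1
    classes, new_labels = np.unique(np.asarray(labels), return_inverse=True)
    return new_labels, classes

raw_labels = [2, 2, 7, 7, 7, 1500, 2]
new_labels, classes = label_normal_unique(raw_labels)
print(new_labels.tolist())  # [0, 0, 1, 1, 1, 2, 0]
print(classes.tolist())     # [2, 7, 1500]
```

Since the mapping depends only on which label values occur, train and val label arrays that contain the same identities receive the same indices.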

import numpy as np

def label_normal(train_labels):
    # Remap the label values in y_train, e.g. (2, 7, ..., 1500) -> (0, 1, 2, ..., 750),
    # assigning indices in order of first appearance.
    # A dict lookup keeps this correct even when identical labels are not adjacent.
    label_to_index = {}
    for label in train_labels:
        if label not in label_to_index:
            label_to_index[label] = len(label_to_index)
    print('label mapping:', label_to_index)

    new_train_labels = [label_to_index[label] for label in train_labels]
    print('y_train_trans:', new_train_labels)
    return np.array(new_train_labels)

import scipy.io
import torch

result_feature = scipy.io.loadmat('./evaluation/features/market_ResNet/market_ResNet_result.mat')
result_softout = scipy.io.loadmat('./evaluation/features/view_Branch_softmaxout/Market_softmaxout.mat')

# target labels: train and val
train_labels = result_feature['train_label'][0]
train_labels = label_normal(train_labels)   # key step: remap labels to 0..C-1
train_labels = torch.LongTensor(train_labels)
train_labels = train_labels.view(-1, 1)   # shape (N, 1); TensorDataset itself only needs a matching first dimension
val_labels = result_feature['val_label'][0]
val_labels = label_normal(val_labels)  # key step
val_labels = torch.LongTensor(val_labels)
val_labels = val_labels.view(-1, 1)    # shape (N, 1); nn.CrossEntropyLoss expects (N,) targets, so squeeze if needed

# r_f_train, train_yaw_back, r_f_val and val_yaw_back are feature tensors prepared
# elsewhere in the original script (not shown); opt is the script's options object
image_datasets = {}
image_datasets['train'] = torch.utils.data.TensorDataset(r_f_train, train_yaw_back, train_labels)
image_datasets['val'] = torch.utils.data.TensorDataset(r_f_val, val_yaw_back, val_labels)

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                              shuffle=True, num_workers=16)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
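Because the snippet above depends on .mat files and feature tensors that are not shown, here is a self-contained sketch of the same TensorDataset/DataLoader pattern using random tensors. The sizes (100 samples, 2048-dim features, 10 classes) and `batch_size=32` are arbitrary placeholders, not values from the original.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

num_samples, feature_dim, num_classes = 100, 2048, 10

# Placeholder features and labels; labels are already contiguous 0..num_classes-1
features = torch.randn(num_samples, feature_dim)
labels = torch.randint(0, num_classes, (num_samples,))  # 1-D int64 tensor

# TensorDataset only requires every tensor to share the same first dimension
dataset = TensorDataset(features, labels)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

batch_features, batch_labels = next(iter(loader))
print(batch_features.shape)  # torch.Size([32, 2048])
print(batch_labels.shape)    # torch.Size([32])

# 1-D integer targets can be passed straight to cross-entropy
logits = torch.nn.Linear(feature_dim, num_classes)(batch_features)
loss = F.cross_entropy(logits, batch_labels)
print(len(dataset), len(loader))  # 100 4
```

Keeping the labels 1-D avoids having to squeeze a (N, 1) target before computing the loss.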

3. Writing a custom Dataset class

Reference: https://blog.csdn.net/lqp888888/article/details/80481456

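import torch.utils.data as data

class MyDataset(data.Dataset):
    # Minimal custom dataset: wraps in-memory samples and their labels
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __getitem__(self, index):
        # Return one (sample, label) pair; DataLoader calls this for each index
        img, target = self.data[index], self.labels[index]
        return img, target

    def __len__(self):
        # Number of samples; DataLoader uses it to generate indices
        return len(self.data)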
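A usage sketch showing how such a class plugs into a DataLoader. The class is repeated so the snippet runs on its own, and the 10 random 4-feature samples with labels `0, 1, 2, 0, ...` are hypothetical data.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    # Same minimal class as above, repeated so this snippet is self-contained
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)

# Hypothetical data: 10 samples of 4 float features, integer labels 0..2
samples = np.random.rand(10, 4).astype(np.float32)
targets = np.arange(10) % 3

dataset = MyDataset(samples, targets)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# The default collate function stacks the per-sample NumPy values into tensors
batches = list(loader)
for batch_data, batch_labels in batches:
    print(batch_data.shape, batch_labels.tolist())
```

With 10 samples and `batch_size=4`, the last batch holds the remaining 2 samples unless `drop_last=True` is set.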

II. Preprocessing

1. Preprocessing custom data with torchvision transforms

Reference: https://pytorch.org/docs/0.4.0/_modules/torchvision/datasets/folder.html#DatasetFolder

import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform  # e.g. a torchvision.transforms.Compose pipeline

    def __getitem__(self, index):
        img, target = self.data[index], self.labels[index]

        # Key step: preprocess each sample as it is fetched; skip when no transform is given
        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.data)

Reposted from blog.csdn.net/jdzwanghao/article/details/82154184