Kaggle比赛——Digit Recognizer——Part 1(Pytorch 数据集的建立)

       首先从Kaggle官网下载数据集https://www.kaggle.com/c/digit-recognizer/data里面包含三个CSV文档。train.csv是带标签的数据,用于训练和调参,test.csv是无标签的数据,在提交测试文档的时候才需要用到。

        这里,我先把train里面的数据又随机划分为两个表,一个用于训练一个用于交叉验证,代码很简单,主要是pandas的一些简单功能。

import numpy as np
import pandas as pd
from sklearn.model_selection  import train_test_split
#读取从kaggle上下载的训练集和测试集
train = pd.read_csv('train.csv')
from sklearn.model_selection import train_test_split
#train为数据集含有Feature和label.
train_set, val_set = train_test_split(train, test_size = 0.2)
train_set.to_csv('train_set.csv',index = False )
val_set.to_csv('val_set.csv',index = False )
print(train_set.shape)
print(train.shape)
print(val_set.shape)
运行结果为:
(33600, 785)
(42000, 785)
(8400, 785)

        这样,训练数据和交叉验证的数据就分别存在两个表里面了。

        下一步,我们需要重写Pytorch的数据集类,构建我们的数据集。

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable
#这里我自己定义了一个标准化处理函数,把图像的数据从0~255映射到-0.5~0.5
def data_tf(x):
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5 # 标准化
    x = torch.from_numpy(x)
    return x
class MyMNIST(torch.utils.data.Dataset): #创建自己的类:MyMNIST,这个类是继承的torch.utils.data.Dataset
    def __init__(self, datatxt, transform=None, target_transform=None): #初始化一些需要传入的参数
        self.data = pd.read_csv(datatxt)       
        self.X = self.data.iloc[:,1:]
        self.X = np.array(self.X)
        self.y = self.data.iloc[:,0]
        self.y = np.array(self.y)
        self.transform = transform
        
    def __getitem__(self, index):    #这个方法是必须要有的,用于按照索引读取每个元素的具体内容
        im = torch.tensor(self.X[index], dtype = torch.float)
        label = torch.tensor(self.y[index], dtype = torch.long )
        if self.transform is not None:
            im = self.transform(im)
        return im, label

    def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.data)
 
#利用数据类建立数据集
X_train = MyMNIST(datatxt = 'train_set.csv',transform = data_tf)
X_val = MyMNIST(datatxt= 'val_set.csv',transform = data_tf)
#train_data和val_data为可迭代对象,用于训练时分批读取数据
train_data = DataLoader(X_train,batch_size=64, shuffle=True)
val_data = DataLoader(X_val, batch_size=64, shuffle=False)

以上是构建数据集的全部代码。

下一步就可以构建网络结构。

猜你喜欢

转载自blog.csdn.net/qq_35654046/article/details/82252779
今日推荐