Loading and Visualizing FashionMNIST
In this notebook, we load and view images from the Fashion-MNIST database.
The first step in any classification problem is to look at the data you are working with. This gives you some details about the image and label formats, as well as some insight into how you might define a network to recognize patterns in such a set of images.
PyTorch has some built-in datasets that you can use, and FashionMNIST is one of them. It has already been downloaded into the data/ directory of this notebook, so all we need to do is load these images using the FashionMNIST dataset class and load the data in batches with a DataLoader.
Loading data
The Dataset class and tensors
torch.utils.data.Dataset is an abstract class representing a dataset. The FashionMNIST class is an extension of this Dataset class. It allows us to load batches of image/label data and to uniformly apply transformations to our data, such as turning all our images into tensors for training a neural network. Tensors are similar to numpy arrays, but can also be used on a GPU to accelerate computing.
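To make the idea of extending Dataset more concrete, here is a minimal sketch of a custom map-style dataset. The class name RandomImageDataset and its randomly generated contents are hypothetical, invented only for illustration; the real FashionMNIST class reads image files from disk. The sketch shows the two methods such a class provides, __len__ and __getitem__, and how an optional transform is applied to each sample.

```python
import torch
from torch.utils.data import Dataset


class RandomImageDataset(Dataset):
    """Hypothetical stand-in dataset: random 28x28 'images' with labels 0-9."""

    def __init__(self, num_samples=100, transform=None):
        # Seeded generator so the fake data is reproducible
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 28, 28, generator=generator)
        self.labels = torch.randint(0, 10, (num_samples,), generator=generator)
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one (image, label) pair, applying the transform if given
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, int(self.labels[idx])


# The transform here adds a channel dimension, giving a 1x28x28 tensor
toy_data = RandomImageDataset(num_samples=100,
                              transform=lambda img: img.unsqueeze(0))
image, label = toy_data[0]
print(len(toy_data), image.shape, label)
```

Because the transform is applied inside __getitem__, every sample passes through the same conversion no matter how it is accessed later.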
Let's take a look at how to build a training data set.
# our basic libraries
import torch
import torchvision

# data loading and transforming
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision import transforms

# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors for input into a CNN

## Define a transform to read the data in as a tensor
data_transform = transforms.ToTensor()

# choose the training and test datasets
train_data = FashionMNIST(root='./data', train=True,
                          download=False, transform=data_transform)

# Print out some stats about the training data
print('Train data, number of images: ', len(train_data))
Train data, number of images: 60000
Iterating and batching data
Next, we use torch.utils.data.DataLoader, an iterator that lets us batch and shuffle the data.
In the next cell, we shuffle the data and load the image/label data in batches of size 20.
# prepare data loaders, set the batch_size
## TODO: you can try changing the batch_size to be larger or smaller
## when you get to training your network, see how batch_size affects the loss
batch_size = 20

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# specify the image classes
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
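As a rough illustration of what batching does, the sketch below builds a DataLoader over a small fake dataset and inspects one batch. The names fake_images, fake_labels, toy_data, and toy_loader are hypothetical stand-ins, and the 100 all-zero images are an assumption made so the example runs without the FashionMNIST files.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for train_data: 100 fake 1x28x28 images, labels 0-9
fake_images = torch.zeros(100, 1, 28, 28)
fake_labels = torch.arange(100) % 10
toy_data = TensorDataset(fake_images, fake_labels)

toy_loader = DataLoader(toy_data, batch_size=20, shuffle=True)

# len() of a DataLoader is the number of batches per epoch
print(len(toy_loader))

# Each iteration yields one batch of images and the matching labels
batch_images, batch_labels = next(iter(toy_loader))
print(batch_images.shape, batch_labels.shape)
```

By the same arithmetic, the 60000 training images loaded in batches of 20 give 3000 batches per pass over the data.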
Visualize some training data
This cell iterates over the training dataset and loads one random batch of image/label data. It then plots the batch of images and their labels in a 2 x batch_size/2 grid.
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# obtain one batch of training images
dataiter = iter(train_loader)
# use the built-in next(); the .next() method is not available on
# DataLoader iterators in recent PyTorch versions
images, labels = next(dataiter)
images = images.numpy()

# plot the images in the batch, along with the corresponding labels
fig = plt.figure(figsize=(25, 4))
for idx in np.arange(batch_size):
    # integer division: add_subplot expects whole numbers of rows and columns
    ax = fig.add_subplot(2, batch_size // 2, idx + 1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(images[idx]), cmap='gray')
    ax.set_title(classes[labels[idx].item()])
View the images in more detail
Each image in this dataset is a 28x28 pixel, normalized, grayscale image.
A note on normalization
Normalization ensures that, as we go through the feedforward and backpropagation steps while training our CNN, each image feature falls within a similar range of values and does not overly activate any particular layer in the network. During the feedforward step, the network takes in an input image, multiplies each input pixel by some convolutional filter weights and adds biases, and then applies activation and pooling functions. Without normalization, the gradients calculated in the backpropagation step are much more likely to be very large, causing the loss to increase instead of converge.
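The scaling itself can be sketched in a few lines of numpy. This assumes the normalization referred to here is the kind that transforms.ToTensor() performs on 8-bit images, mapping raw values in [0, 255] to floats in [0, 1]; the 2x2 array raw_patch is a hypothetical example, not data taken from FashionMNIST.

```python
import numpy as np

# Hypothetical 2x2 grayscale patch with raw 8-bit pixel values
raw_patch = np.array([[0, 51], [102, 255]], dtype=np.uint8)

# Convert to float and divide by the maximum 8-bit value to land in [0, 1]
scaled_patch = raw_patch.astype(np.float32) / 255.0

print(scaled_patch.min(), scaled_patch.max())
```

Keeping every pixel in the same small range means no single input dominates the weighted sums computed in the first layer.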
# select an image by index
idx = 2
img = np.squeeze(images[idx])

# display the pixel values in that image
fig = plt.figure(figsize = (12,12))
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
width, height = img.shape
thresh = img.max()/2.5
for x in range(width):
    for y in range(height):
        val = round(img[x][y],2) if img[x][y] !=0 else 0
        ax.annotate(str(val), xy=(y,x),
                    horizontalalignment='center',
                    verticalalignment='center',
                    color='white' if img[x][y]<thresh else 'black')