Loading and visualizing FashionMNIST with PyTorch (Udacity)

Loading and visualizing FashionMNIST


In this notebook, we load and look at images from the Fashion-MNIST database.

The first step in any classification problem is to look at the data you are working with. This tells you some details about the image and label formats, and gives some insight into how to define a network that can recognize patterns in such an image set.

PyTorch has some built-in datasets that you can use, and FashionMNIST is one of them. It has already been downloaded into the data/ directory of this notebook, so all we need to do is load these images with the FashionMNIST dataset class and load the data in batches with a DataLoader.

Loading data

The Dataset class and tensors

torch.utils.data.Dataset is an abstract class representing a dataset. The FashionMNIST class extends this Dataset class; it lets us load image/label data and uniformly apply transforms to it, for example converting every image into a tensor that can be used to train a neural network. Tensors are similar to numpy arrays, but they can also be used on a GPU to speed up computation.
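To make the comparison with numpy arrays concrete, here is a minimal sketch of moving data between the two, and onto a GPU when one happens to be available. The array holds random placeholder values, not a FashionMNIST image, and the variable names are only for illustration.

```python
import numpy as np
import torch

# Placeholder data shaped like one 28x28 grayscale image (values are random).
array = np.random.rand(28, 28).astype(np.float32)

# torch.from_numpy wraps the array without copying it.
tensor = torch.from_numpy(array)

# Use a GPU if one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tensor_on_device = tensor.to(device)

# To get back to numpy, the tensor has to be on the CPU first.
round_trip = tensor_on_device.cpu().numpy()

print(tensor.shape, tensor.dtype, device)
```

The same code runs on machines with and without a GPU, since the device is chosen at run time.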

Let's take a look at how to build a training data set.

# our basic libraries
import torch
import torchvision

# data loading and transforming
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision import transforms

# torchvision datasets return PIL images with pixel values in [0, 255].
# ToTensor converts them to float tensors in the range [0, 1] for input into a CNN

## Define a transform to read the data in as a tensor
data_transform = transforms.ToTensor()

# choose the training dataset (already in ./data, so no download is needed)
train_data = FashionMNIST(root='./data', train=True,
                          download=False, transform=data_transform)

# Print out some stats about the training data
print('Train data, number of images: ', len(train_data))

 Train data, number of images: 60000

Data iteration and batching

Next, we use torch.utils.data.DataLoader, an iterator that lets us batch and shuffle the data.

In the next cell, we shuffle the data and load the image/label data in batches of size 20.

# prepare data loaders, set the batch_size
## TODO: you can try changing the batch_size to be larger or smaller
## when you get to training your network, see how batch_size affects the loss
batch_size = 20

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# specify the image classes
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
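To see what a DataLoader yields without touching the dataset on disk, the following sketch wraps random tensors in a TensorDataset. This is an assumption made for illustration (the notebook itself uses FashionMNIST), and the names fake_images, fake_labels and fake_loader are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 random "images" shaped like FashionMNIST tensors, with random labels 0..9.
fake_images = torch.rand(100, 1, 28, 28)
fake_labels = torch.randint(0, 10, (100,))
fake_data = TensorDataset(fake_images, fake_labels)

fake_loader = DataLoader(fake_data, batch_size=20, shuffle=True)

# Each iteration returns one batch of images and the matching labels.
batch_images, batch_labels = next(iter(fake_loader))

print(batch_images.shape)  # torch.Size([20, 1, 28, 28])
print(batch_labels.shape)  # torch.Size([20])
print(len(fake_loader))    # 5 batches of 20 cover the 100 samples
```

With the real training set of 60000 images and a batch size of 20, the loader would hold 3000 batches per epoch.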

 

Visualize some training data

This cell iterates over the training dataset and pulls one random batch of image/label data from the loader's iterator. It then plots the images of that batch, with their labels, in a 2 x batch_size/2 grid.

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
    
# obtain one batch of training images
dataiter = iter(train_loader)
# use the built-in next(); the iterator's .next() method is gone in recent PyTorch releases
images, labels = next(dataiter)
images = images.numpy()

# plot the images in the batch, along with the corresponding labels
fig = plt.figure(figsize=(25, 4))
for idx in np.arange(batch_size):
    # subplot counts must be integers, hence // rather than /
    ax = fig.add_subplot(2, batch_size//2, idx+1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(images[idx]), cmap='gray')
    ax.set_title(classes[labels[idx]])
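The grid arithmetic behind that plot can be sketched in plain Python. Subplot positions are numbered row by row, so a batch of 20 images fills two rows of ten; the names n_rows, n_cols and cells are only for illustration.

```python
batch_size = 20                  # same value as in the notebook
n_rows = 2
n_cols = batch_size // n_rows    # integer division keeps the column count an int

# Map each position in the batch to its (row, column) cell, row by row.
cells = [(idx // n_cols, idx % n_cols) for idx in range(batch_size)]

print(n_cols)                    # 10
print(cells[0], cells[9])        # (0, 0) (0, 9)
print(cells[10], cells[19])      # (1, 0) (1, 9)
```

An odd batch_size would leave the last image without a cell, so it is worth keeping the batch size even when using this layout.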

 

View the images in more detail

Each image in this dataset is a 28x28 pixel, normalized, grayscale image.

A note on normalization

Normalization ensures that, as we go through a feedforward and then a backpropagation step in training our CNN, each image feature falls within a similar range of values and does not overly activate any particular layer in the network. During the feedforward step, the network takes in an input image, multiplies each input pixel by some convolutional filter weights and adds biases, then applies some activation and pooling functions. Without normalization, the gradients calculated in the backpropagation step can become very large and cause the loss to increase instead of converge.
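As a rough sketch of the arithmetic involved, the snippet below rescales a few made-up 8-bit pixel values to [0, 1], which is the range ToTensor produces. The second step, centering around zero with a mean and standard deviation of 0.5, is an assumption added for illustration and is not something this notebook applies.

```python
import numpy as np

# Four made-up 8-bit pixel values, from black (0) to white (255).
raw = np.array([[0, 64], [128, 255]], dtype=np.uint8)

# Step 1: rescale to [0, 1] by dividing by the largest possible value.
scaled = raw.astype(np.float32) / 255.0

# Step 2 (assumed, not used in this notebook): shift and scale to [-1, 1].
centered = (scaled - 0.5) / 0.5

print(scaled.min(), scaled.max())
print(centered.min(), centered.max())
```

Either way, every pixel ends up in a small, fixed range, so no single input can dominate the weighted sums computed by the first layer.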

# select an image by index
idx = 2
img = np.squeeze(images[idx])

# display the pixel values in that image
fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
ax.imshow(img, cmap='gray')
n_rows, n_cols = img.shape  # numpy shape is (rows, columns), i.e. (height, width)
thresh = img.max()/2.5
for x in range(n_rows):
    for y in range(n_cols):
        val = round(float(img[x][y]), 2) if img[x][y] != 0 else 0
        # annotate takes xy as (column, row); use white text on dark pixels
        ax.annotate(str(val), xy=(y,x),
                    horizontalalignment='center',
                    verticalalignment='center',
                    color='white' if img[x][y]<thresh else 'black')

 

 


Origin www.cnblogs.com/wangyarui/p/11087918.html