Recognizing CIFAR-10 Images with a Keras Convolutional Neural Network


GitHub: https://github.com/fz861062923/Keras/new/master

Downloading the data

Load the CIFAR-10 data

y is the label, x is the image

from keras.datasets import cifar10
import numpy as np
(x_train,y_train),\
(x_test,y_test)=cifar10.load_data()
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 384s 2us/step

Inspecting the training data

Each label value corresponds to a different image class: 0: airplane, 1: automobile, 2: bird, 3: cat, 4: deer, 5: dog, 6: frog, 7: horse, 8: ship, 9: truck

Check the shapes to see how the data is organized

x_train.shape#shape of the images
(50000, 32, 32, 3)
y_train.shape#shape of the labels
(50000, 1)

Viewing the contents of the first image

x_train[0]
array([[[ 59,  62,  63],
        [ 43,  46,  45],
        [ 50,  48,  43],
        ...,
        [158, 132, 108],
        [152, 125, 102],
        [148, 124, 103]],

       [[ 16,  20,  20],
        [  0,   0,   0],
        [ 18,   8,   0],
        ...,
        [123,  88,  55],
        [119,  83,  50],
        [122,  87,  57]],

       [[ 25,  24,  21],
        [ 16,   7,   0],
        [ 49,  27,   8],
        ...,
        [118,  84,  50],
        [120,  84,  50],
        [109,  73,  42]],

       ...,

       [[208, 170,  96],
        [201, 153,  34],
        [198, 161,  26],
        ...,
        [160, 133,  70],
        [ 56,  31,   7],
        [ 53,  34,  20]],

       [[180, 139,  96],
        [173, 123,  42],
        [186, 144,  30],
        ...,
        [184, 148,  94],
        [ 97,  62,  34],
        [ 83,  53,  34]],

       [[177, 144, 116],
        [168, 129,  94],
        [179, 142,  87],
        ...,
        [216, 184, 140],
        [151, 118,  84],
        [123,  92,  72]]], dtype=uint8)

Visualizing the images

Build a dictionary mapping each label to its class name

dict={0:'airplane',1:'automobile',2:'bird',3:'cat',4:'deer',5:'dog',6:'frog',7:'horse',8:'ship',9:'truck'}#note: this name shadows Python's built-in dict
import matplotlib.pyplot as plt
def plot_images_labels_prediction(images,labels,
                                  prediction,idx,num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    if num>25: num=25 
    for i in range(0, num):
        ax=plt.subplot(5,5, 1+i)
        
        ax.imshow(images[idx],cmap='binary')           
        title= str(i)+'.'+dict[labels[idx][0]]#index with idx so titles stay aligned with the images
        if len(prediction)>0:
            title+='=>'+dict[prediction[idx]] 
            
        ax.set_title(title,fontsize=10) 
        ax.set_xticks([]);ax.set_yticks([])        
        idx+=1 
    plt.show()#same visualization function as for MNIST, with the title changed to be more readable

View the first ten images

plot_images_labels_prediction(x_train,y_train,prediction=[],idx=0)


Preprocessing the images

Normalize the data; the idea is the same as for the MNIST dataset

x_train_normalize=x_train/255.0
x_test_normalize=x_test/255.0
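As a quick sanity check (run here on a small synthetic array so it does not depend on the downloaded dataset), dividing uint8 pixels by 255.0 maps them into [0, 1] and promotes the dtype to float:

```python
import numpy as np

# synthetic stand-in for a batch of uint8 images (values 0-255)
fake_images = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

normalized = fake_images / 255.0  # same operation applied to x_train/x_test above

print(normalized.dtype)  # float64
print(normalized.min() >= 0.0, normalized.max() <= 1.0)  # True True
```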

View the normalized result

x_train_normalize
array([[[[0.23137255, 0.24313725, 0.24705882],
         [0.16862745, 0.18039216, 0.17647059],
         [0.19607843, 0.18823529, 0.16862745],
         ...,
         [0.61960784, 0.51764706, 0.42352941],
         [0.59607843, 0.49019608, 0.4       ],
         [0.58039216, 0.48627451, 0.40392157]],

        [[0.0627451 , 0.07843137, 0.07843137],
         [0.        , 0.        , 0.        ],
         [0.07058824, 0.03137255, 0.        ],
         ...,
         [0.48235294, 0.34509804, 0.21568627],
         [0.46666667, 0.3254902 , 0.19607843],
         [0.47843137, 0.34117647, 0.22352941]],

        [[0.09803922, 0.09411765, 0.08235294],
         [0.0627451 , 0.02745098, 0.        ],
         [0.19215686, 0.10588235, 0.03137255],
         ...,
         [0.4627451 , 0.32941176, 0.19607843],
         [0.47058824, 0.32941176, 0.19607843],
         [0.42745098, 0.28627451, 0.16470588]],

        ...,

        [[0.81568627, 0.66666667, 0.37647059],
         [0.78823529, 0.6       , 0.13333333],
         [0.77647059, 0.63137255, 0.10196078],
         ...,
         [0.62745098, 0.52156863, 0.2745098 ],
         [0.21960784, 0.12156863, 0.02745098],
         [0.20784314, 0.13333333, 0.07843137]],

        [[0.70588235, 0.54509804, 0.37647059],
         [0.67843137, 0.48235294, 0.16470588],
         [0.72941176, 0.56470588, 0.11764706],
         ...,
         [0.72156863, 0.58039216, 0.36862745],
         [0.38039216, 0.24313725, 0.13333333],
         [0.3254902 , 0.20784314, 0.13333333]],

        [[0.69411765, 0.56470588, 0.45490196],
         [0.65882353, 0.50588235, 0.36862745],
         [0.70196078, 0.55686275, 0.34117647],
         ...,
         [0.84705882, 0.72156863, 0.54901961],
         [0.59215686, 0.4627451 , 0.32941176],
         [0.48235294, 0.36078431, 0.28235294]]],


       [[[0.60392157, 0.69411765, 0.73333333],
         [0.49411765, 0.5372549 , 0.53333333],
         [0.41176471, 0.40784314, 0.37254902],
         ...,
         [0.35686275, 0.37254902, 0.27843137],
         [0.34117647, 0.35294118, 0.27843137],
         [0.30980392, 0.31764706, 0.2745098 ]],

        [[0.54901961, 0.62745098, 0.6627451 ],
         [0.56862745, 0.6       , 0.60392157],
         [0.49019608, 0.49019608, 0.4627451 ],
         ...,
         [0.37647059, 0.38823529, 0.30588235],
         [0.30196078, 0.31372549, 0.24313725],
         [0.27843137, 0.28627451, 0.23921569]],

        [[0.54901961, 0.60784314, 0.64313725],
         [0.54509804, 0.57254902, 0.58431373],
         [0.45098039, 0.45098039, 0.43921569],
         ...,
         [0.30980392, 0.32156863, 0.25098039],
         [0.26666667, 0.2745098 , 0.21568627],
         [0.2627451 , 0.27058824, 0.21568627]],

        ...,

        [[0.68627451, 0.65490196, 0.65098039],
         [0.61176471, 0.60392157, 0.62745098],
         [0.60392157, 0.62745098, 0.66666667],
         ...,
         [0.16470588, 0.13333333, 0.14117647],
         [0.23921569, 0.20784314, 0.22352941],
         [0.36470588, 0.3254902 , 0.35686275]],

        [[0.64705882, 0.60392157, 0.50196078],
         [0.61176471, 0.59607843, 0.50980392],
         [0.62352941, 0.63137255, 0.55686275],
         ...,
         [0.40392157, 0.36470588, 0.37647059],
         [0.48235294, 0.44705882, 0.47058824],
         [0.51372549, 0.4745098 , 0.51372549]],

        [[0.63921569, 0.58039216, 0.47058824],
         [0.61960784, 0.58039216, 0.47843137],
         [0.63921569, 0.61176471, 0.52156863],
         ...,
         [0.56078431, 0.52156863, 0.54509804],
         [0.56078431, 0.5254902 , 0.55686275],
         [0.56078431, 0.52156863, 0.56470588]]],


       [[[1.        , 1.        , 1.        ],
         [0.99215686, 0.99215686, 0.99215686],
         [0.99215686, 0.99215686, 0.99215686],
         ...,
         [0.99215686, 0.99215686, 0.99215686],
         [0.99215686, 0.99215686, 0.99215686],
         [0.99215686, 0.99215686, 0.99215686]],

        [[1.        , 1.        , 1.        ],
         [1.        , 1.        , 1.        ],
         [1.        , 1.        , 1.        ],
         ...,
         [1.        , 1.        , 1.        ],
         [1.        , 1.        , 1.        ],
         [1.        , 1.        , 1.        ]],

        [[1.        , 1.        , 1.        ],
         [0.99607843, 0.99607843, 0.99607843],
         [0.99607843, 0.99607843, 0.99607843],
         ...,
         [0.99607843, 0.99607843, 0.99607843],
         [0.99607843, 0.99607843, 0.99607843],
         [0.99607843, 0.99607843, 0.99607843]],

        ...,

        [[0.44313725, 0.47058824, 0.43921569],
         [0.43529412, 0.4627451 , 0.43529412],
         [0.41176471, 0.43921569, 0.41568627],
         ...,
         [0.28235294, 0.31764706, 0.31372549],
         [0.28235294, 0.31372549, 0.30980392],
         [0.28235294, 0.31372549, 0.30980392]],

        [[0.43529412, 0.4627451 , 0.43137255],
         [0.40784314, 0.43529412, 0.40784314],
         [0.38823529, 0.41568627, 0.38431373],
         ...,
         [0.26666667, 0.29411765, 0.28627451],
         [0.2745098 , 0.29803922, 0.29411765],
         [0.30588235, 0.32941176, 0.32156863]],

        [[0.41568627, 0.44313725, 0.41176471],
         [0.38823529, 0.41568627, 0.38431373],
         [0.37254902, 0.4       , 0.36862745],
         ...,
         [0.30588235, 0.33333333, 0.3254902 ],
         [0.30980392, 0.33333333, 0.3254902 ],
         [0.31372549, 0.3372549 , 0.32941176]]],


       ...,


       [[[0.1372549 , 0.69803922, 0.92156863],
         [0.15686275, 0.69019608, 0.9372549 ],
         [0.16470588, 0.69019608, 0.94509804],
         ...,
         [0.38823529, 0.69411765, 0.85882353],
         [0.30980392, 0.57647059, 0.77254902],
         [0.34901961, 0.58039216, 0.74117647]],

        [[0.22352941, 0.71372549, 0.91764706],
         [0.17254902, 0.72156863, 0.98039216],
         [0.19607843, 0.71764706, 0.94117647],
         ...,
         [0.61176471, 0.71372549, 0.78431373],
         [0.55294118, 0.69411765, 0.80784314],
         [0.45490196, 0.58431373, 0.68627451]],

        [[0.38431373, 0.77254902, 0.92941176],
         [0.25098039, 0.74117647, 0.98823529],
         [0.27058824, 0.75294118, 0.96078431],
         ...,
         [0.7372549 , 0.76470588, 0.80784314],
         [0.46666667, 0.52941176, 0.57647059],
         [0.23921569, 0.30980392, 0.35294118]],

        ...,

        [[0.28627451, 0.30980392, 0.30196078],
         [0.20784314, 0.24705882, 0.26666667],
         [0.21176471, 0.26666667, 0.31372549],
         ...,
         [0.06666667, 0.15686275, 0.25098039],
         [0.08235294, 0.14117647, 0.2       ],
         [0.12941176, 0.18823529, 0.19215686]],

        [[0.23921569, 0.26666667, 0.29411765],
         [0.21568627, 0.2745098 , 0.3372549 ],
         [0.22352941, 0.30980392, 0.40392157],
         ...,
         [0.09411765, 0.18823529, 0.28235294],
         [0.06666667, 0.1372549 , 0.20784314],
         [0.02745098, 0.09019608, 0.1254902 ]],

        [[0.17254902, 0.21960784, 0.28627451],
         [0.18039216, 0.25882353, 0.34509804],
         [0.19215686, 0.30196078, 0.41176471],
         ...,
         [0.10588235, 0.20392157, 0.30196078],
         [0.08235294, 0.16862745, 0.25882353],
         [0.04705882, 0.12156863, 0.19607843]]],


       [[[0.74117647, 0.82745098, 0.94117647],
         [0.72941176, 0.81568627, 0.9254902 ],
         [0.7254902 , 0.81176471, 0.92156863],
         ...,
         [0.68627451, 0.76470588, 0.87843137],
         [0.6745098 , 0.76078431, 0.87058824],
         [0.6627451 , 0.76078431, 0.8627451 ]],

        [[0.76078431, 0.82352941, 0.9372549 ],
         [0.74901961, 0.81176471, 0.9254902 ],
         [0.74509804, 0.80784314, 0.92156863],
         ...,
         [0.67843137, 0.75294118, 0.8627451 ],
         [0.67058824, 0.74901961, 0.85490196],
         [0.65490196, 0.74509804, 0.84705882]],

        [[0.81568627, 0.85882353, 0.95686275],
         [0.80392157, 0.84705882, 0.94117647],
         [0.8       , 0.84313725, 0.9372549 ],
         ...,
         [0.68627451, 0.74901961, 0.85098039],
         [0.6745098 , 0.74509804, 0.84705882],
         [0.6627451 , 0.74901961, 0.84313725]],

        ...,

        [[0.81176471, 0.78039216, 0.70980392],
         [0.79607843, 0.76470588, 0.68627451],
         [0.79607843, 0.76862745, 0.67843137],
         ...,
         [0.52941176, 0.51764706, 0.49803922],
         [0.63529412, 0.61960784, 0.58823529],
         [0.65882353, 0.63921569, 0.59215686]],

        [[0.77647059, 0.74509804, 0.66666667],
         [0.74117647, 0.70980392, 0.62352941],
         [0.70588235, 0.6745098 , 0.57647059],
         ...,
         [0.69803922, 0.67058824, 0.62745098],
         [0.68627451, 0.6627451 , 0.61176471],
         [0.68627451, 0.6627451 , 0.60392157]],

        [[0.77647059, 0.74117647, 0.67843137],
         [0.74117647, 0.70980392, 0.63529412],
         [0.69803922, 0.66666667, 0.58431373],
         ...,
         [0.76470588, 0.72156863, 0.6627451 ],
         [0.76862745, 0.74117647, 0.67058824],
         [0.76470588, 0.74509804, 0.67058824]]],


       [[[0.89803922, 0.89803922, 0.9372549 ],
         [0.9254902 , 0.92941176, 0.96862745],
         [0.91764706, 0.9254902 , 0.96862745],
         ...,
         [0.85098039, 0.85882353, 0.91372549],
         [0.86666667, 0.8745098 , 0.91764706],
         [0.87058824, 0.8745098 , 0.91372549]],

        [[0.87058824, 0.86666667, 0.89803922],
         [0.9372549 , 0.9372549 , 0.97647059],
         [0.91372549, 0.91764706, 0.96470588],
         ...,
         [0.8745098 , 0.8745098 , 0.9254902 ],
         [0.89019608, 0.89411765, 0.93333333],
         [0.82352941, 0.82745098, 0.8627451 ]],

        [[0.83529412, 0.80784314, 0.82745098],
         [0.91764706, 0.90980392, 0.9372549 ],
         [0.90588235, 0.91372549, 0.95686275],
         ...,
         [0.8627451 , 0.8627451 , 0.90980392],
         [0.8627451 , 0.85882353, 0.90980392],
         [0.79215686, 0.79607843, 0.84313725]],

        ...,

        [[0.58823529, 0.56078431, 0.52941176],
         [0.54901961, 0.52941176, 0.49803922],
         [0.51764706, 0.49803922, 0.47058824],
         ...,
         [0.87843137, 0.87058824, 0.85490196],
         [0.90196078, 0.89411765, 0.88235294],
         [0.94509804, 0.94509804, 0.93333333]],

        [[0.5372549 , 0.51764706, 0.49411765],
         [0.50980392, 0.49803922, 0.47058824],
         [0.49019608, 0.4745098 , 0.45098039],
         ...,
         [0.70980392, 0.70588235, 0.69803922],
         [0.79215686, 0.78823529, 0.77647059],
         [0.83137255, 0.82745098, 0.81176471]],

        [[0.47843137, 0.46666667, 0.44705882],
         [0.4627451 , 0.45490196, 0.43137255],
         [0.47058824, 0.45490196, 0.43529412],
         ...,
         [0.70196078, 0.69411765, 0.67843137],
         [0.64313725, 0.64313725, 0.63529412],
         [0.63921569, 0.63921569, 0.63137255]]]])

Preprocessing the labels

One-hot encode the labels, exactly as with the MNIST dataset

from keras.utils import np_utils
y_train_onehot=np_utils.to_categorical(y_train)
y_test_onehot=np_utils.to_categorical(y_test)

View the converted label field

y_train_onehot.shape
y_train_onehot[0]
array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0.], dtype=float32)
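The `to_categorical` conversion can be reproduced with plain NumPy, which makes clear what the encoding does (a sketch independent of Keras; the sample labels are illustrative):

```python
import numpy as np

labels = np.array([6, 9, 9, 4, 1])             # a few CIFAR-10 labels (6 = frog, 9 = truck, ...)
onehot = np.eye(10, dtype=np.float32)[labels]  # row i of the identity matrix one-hot encodes label i

print(onehot[0])
# [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]  -- same pattern as y_train_onehot[0] above for label 6
```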

Building the convolutional neural network

from keras.models import Sequential
from keras.layers import Dense,Dropout,Activation,Flatten
from keras.layers import Conv2D,MaxPooling2D,ZeroPadding2D
Create a linear stack (Sequential) model
model=Sequential()
Add convolutional layer 1
model.add(Conv2D(filters=32,kernel_size=(3,3),
         input_shape=(32,32,3),activation='relu',
         padding='same'))
model.add(Dropout(rate=0.25))#set dropout to 25%
Add pooling layer 1
model.add(MaxPooling2D(pool_size=(2,2)))
#downsamples the 32*32 feature maps to 16*16; still 32 feature maps
Add convolutional layer 2
model.add(Conv2D(filters=64,kernel_size=(3,3),
                activation='relu',padding='same'))
#increases the number of feature maps to 64; size stays 16*16
model.add(Dropout(rate=0.25))#guard against overfitting
Add pooling layer 2
model.add(MaxPooling2D(pool_size=(2,2)))
Add the flatten layer
model.add(Flatten())
model.add(Dropout(rate=0.25))
Add the hidden layer
model.add(Dense(1024,activation='relu'))
model.add(Dropout(rate=0.25))#guard against overfitting
Add the output layer
model.add(Dense(10,activation='softmax'))
View the model summary
print(model.summary())
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 32, 32, 32)        896       
_________________________________________________________________
dropout_1 (Dropout)          (None, 32, 32, 32)        0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 16, 16, 64)        18496     
_________________________________________________________________
dropout_2 (Dropout)          (None, 16, 16, 64)        0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 8, 8, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 4096)              0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 4096)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              4195328   
_________________________________________________________________
dropout_4 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                10250     
=================================================================
Total params: 4,224,970
Trainable params: 4,224,970
Non-trainable params: 0
_________________________________________________________________
None
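The parameter counts in the summary can be verified by hand: a Conv2D layer has (kernel_h * kernel_w * in_channels + 1) * filters parameters (the +1 is the per-filter bias), and a Dense layer has (inputs + 1) * units:

```python
def conv2d_params(kh, kw, in_ch, filters):
    # one kh*kw*in_ch kernel plus one bias per output filter
    return (kh * kw * in_ch + 1) * filters

def dense_params(inputs, units):
    # one weight per input plus one bias, per unit
    return (inputs + 1) * units

conv1  = conv2d_params(3, 3, 3, 32)       # 896
conv2  = conv2d_params(3, 3, 32, 64)      # 18496
dense1 = dense_params(8 * 8 * 64, 1024)   # Flatten gives 4096 inputs -> 4195328
dense2 = dense_params(1024, 10)           # 10250

print(conv1 + conv2 + dense1 + dense2)    # 4224970, matching "Total params" above
```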

Training the model

Define the training configuration
model.compile(loss='categorical_crossentropy',#loss function
             optimizer='adam',metrics=['accuracy'])#optimizer and evaluation metric
train_history=model.fit(x_train_normalize,y_train_onehot,
                       validation_split=0.2,
                       epochs=10,batch_size=128,verbose=2)#10 epochs, batch size 128
Train on 40000 samples, validate on 10000 samples
Epoch 1/10
 - 310s - loss: 1.5489 - acc: 0.4436 - val_loss: 1.3615 - val_acc: 0.5575
Epoch 2/10
 - 294s - loss: 1.1747 - acc: 0.5818 - val_loss: 1.1727 - val_acc: 0.6203
Epoch 3/10
 - 277s - loss: 1.0407 - acc: 0.6336 - val_loss: 1.0580 - val_acc: 0.6687
Epoch 4/10
 - 257s - loss: 0.9329 - acc: 0.6701 - val_loss: 0.9902 - val_acc: 0.6756
Epoch 5/10
 - 257s - loss: 0.8434 - acc: 0.7012 - val_loss: 0.9386 - val_acc: 0.6915
Epoch 6/10
 - 3895s - loss: 0.7594 - acc: 0.7327 - val_loss: 0.8906 - val_acc: 0.7139
Epoch 7/10
 - 338s - loss: 0.6715 - acc: 0.7650 - val_loss: 0.8480 - val_acc: 0.7154
Epoch 8/10
 - 335s - loss: 0.6106 - acc: 0.7878 - val_loss: 0.8086 - val_acc: 0.7290
Epoch 9/10
 - 361s - loss: 0.5374 - acc: 0.8103 - val_loss: 0.7926 - val_acc: 0.7290
Epoch 10/10
 - 394s - loss: 0.4856 - acc: 0.8293 - val_loss: 0.7823 - val_acc: 0.7344
Visualizing accuracy and loss
import matplotlib.pyplot as plt
def show_train_history(train_history,train,validation):#train_history: the object returned by model.fit;
                                                       #train / validation: the history keys to plot
    plt.plot(train_history.history[train])
    plt.plot(train_history.history[validation])
    plt.title('Train History')
    plt.ylabel(train)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()
show_train_history(train_history,'acc','val_acc')
train_history.history


{'acc': [0.443575,
  0.581775,
  0.63355,
  0.670125,
  0.7012,
  0.732675,
  0.764975,
  0.787775,
  0.810325,
  0.82925],
 'loss': [1.5489341194152833,
  1.1746543694496154,
  1.0407079723358155,
  0.9328960960388184,
  0.8434274671554566,
  0.7594253509521485,
  0.671516370010376,
  0.6106255568504334,
  0.5374301322937012,
  0.48556572380065915],
 'val_acc': [0.5575,
  0.6203,
  0.6687,
  0.6756,
  0.6915,
  0.7139,
  0.7154,
  0.729,
  0.729,
  0.7344],
 'val_loss': [1.3614615459442139,
  1.1727069728851318,
  1.0579507133483887,
  0.9902402221679687,
  0.9386354804992676,
  0.8906289207458497,
  0.8480340557098389,
  0.8086261804580689,
  0.7925947729110718,
  0.7822921306610108]}
show_train_history(train_history,'loss','val_loss')


Evaluating model accuracy on the test set

score=model.evaluate(x_test_normalize,y_test_onehot)
score[1]
10000/10000 [==============================] - 31s 3ms/step

0.7284
Making predictions
prediction=model.predict_classes(x_test_normalize)
View the prediction results
prediction[:10]
array([3, 8, 8, 0, 6, 6, 1, 2, 3, 1], dtype=int64)
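Note that `predict_classes` belongs to the older Keras Sequential API and was removed in later versions. The equivalent operation, sketched here on a small synthetic probability array rather than real model output, is an argmax over `model.predict`:

```python
import numpy as np

# synthetic stand-in for model.predict(x_test_normalize): one row of class probabilities per image
probability = np.array([[0.05, 0.10, 0.05, 0.60, 0.05, 0.05, 0.04, 0.03, 0.02, 0.01],
                        [0.01, 0.01, 0.02, 0.02, 0.02, 0.02, 0.05, 0.05, 0.70, 0.10]])

prediction = np.argmax(probability, axis=1)  # pick the class with the highest probability per row
print(prediction)  # [3 8]
```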
Plot the first ten predictions
plot_images_labels_prediction(x_test,y_test,prediction,0,10)


Viewing the prediction probabilities
probability=model.predict(x_test_normalize)
probability.shape
(10000, 10)
def show_probability(y,prediction,x,#y: true labels, prediction: predicted classes, x: the images
                     probability,i):#probability: predicted probabilities, i: index of the item to show
    print('label:',dict[y[i][0]],
         'prediction:',dict[prediction[i]])
    plt.figure(figsize=(2,2))
    plt.imshow(x[i])#use the x parameter rather than the global x_test
    plt.show()
    for j in range(10):
        print(dict[j]+'\t\tprobability:%f'%(probability[i][j]))
    
show_probability(y_test,prediction,x_test,probability,0)#prediction details for item 0
label: cat prediction: cat


airplane		probability:0.000951
automobile		probability:0.001414
bird		probability:0.015849
cat		probability:0.662395
deer		probability:0.006565
dog		probability:0.283635
frog		probability:0.017295
horse		probability:0.003373
ship		probability:0.006642
truck		probability:0.001881
show_probability(y_test,prediction,x_test,probability,3)#prediction probabilities for item 3
label: airplane prediction: airplane


airplane		probability:0.780524
automobile		probability:0.010881
bird		probability:0.056743
cat		probability:0.000491
deer		probability:0.003982
dog		probability:0.000053
frog		probability:0.000369
horse		probability:0.000228
ship		probability:0.145283
truck		probability:0.001446

Displaying the confusion matrix

Build the confusion matrix
prediction
array([3, 8, 8, ..., 5, 1, 7], dtype=int64)
prediction.shape
(10000,)
y_test
array([[3],
       [8],
       [8],
       ...,
       [5],
       [1],
       [7]])
y_test.reshape(-1)#flatten y_test to a 1-D array
array([3, 8, 8, ..., 5, 1, 7])
import pandas as pd
print(dict)
pd.crosstab(y_test.reshape(-1),prediction,
           rownames=['label'],colnames=['prediction'])
{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
prediction    0    1    2    3    4    5    6    7    8    9
label
0           788   17   52    6    9    6   22   11   55   34
1             6  818   14    6    3    4   17    4   26  102
2            58    3  627   24  111   54   78   24   13    8
3            24    9   83  416   94  182  136   26   12   18
4            18    2   65   22  719   25   80   57   10    2
5            16    2   45  118   68  622   64   46    7   12
6             5    5   38   21   25   15  879    4    5    3
7            15    1   44   21   65   46   18  773    4   13
8            51   32   24    7   10    5    8    2  831   30
9            32   62   18    6    2    8   14   12   35  811

Conclusions

  • Frogs are the hardest class to confuse; cats are the easiest
  • Dogs are easily mistaken for cats: true label 5 (dog) was predicted as 3 (cat) 118 times
  • By the same reading of the table, cats are also very often predicted as dogs
  • Very few images of classes 2-6 are predicted as 1 (automobile)
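These observations can be checked numerically. Re-entering the matrix printed by pd.crosstab above as a NumPy array (so this sketch runs without re-training), per-class accuracy is the diagonal divided by the row sums:

```python
import numpy as np

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# confusion matrix as printed above (rows: true label, columns: prediction)
cm = np.array([
    [788,  17,  52,   6,   9,   6,  22,  11,  55,  34],
    [  6, 818,  14,   6,   3,   4,  17,   4,  26, 102],
    [ 58,   3, 627,  24, 111,  54,  78,  24,  13,   8],
    [ 24,   9,  83, 416,  94, 182, 136,  26,  12,  18],
    [ 18,   2,  65,  22, 719,  25,  80,  57,  10,   2],
    [ 16,   2,  45, 118,  68, 622,  64,  46,   7,  12],
    [  5,   5,  38,  21,  25,  15, 879,   4,   5,   3],
    [ 15,   1,  44,  21,  65,  46,  18, 773,   4,  13],
    [ 51,  32,  24,   7,  10,   5,   8,   2, 831,  30],
    [ 32,  62,  18,   6,   2,   8,  14,  12,  35, 811],
])

per_class_acc = np.diag(cm) / cm.sum(axis=1)  # each row sums to the 1000 test images per class
for name, acc in zip(classes, per_class_acc):
    print('%-10s %.3f' % (name, acc))

print('overall:', np.diag(cm).sum() / cm.sum())  # 0.7284, matching model.evaluate above
```

Frog comes out best (0.879) and cat worst (0.416), confirming the first bullet, and the overall accuracy reconstructed from the matrix equals the test-set score.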

A final reflection: the accuracy here is not high because the model is deliberately simple; the goal was to see the experiment's output quickly, and on my slow processor training is already expensive. Given more compute, the fix is to add more convolutional layers, stack several hidden layers after the Flatten layer, and increase the number of neurons.
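A sketch of the kind of deeper network suggested above, with three convolution/pooling blocks and two hidden layers; the filter counts and layer sizes are my own illustrative choices, not tuned values:

```python
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', padding='same',
                 input_shape=(32, 32, 3)))
model.add(Dropout(0.25))
model.add(MaxPooling2D((2, 2)))            # 32x32 -> 16x16
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(Dropout(0.25))
model.add(MaxPooling2D((2, 2)))            # 16x16 -> 8x8
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(Dropout(0.25))
model.add(MaxPooling2D((2, 2)))            # 8x8 -> 4x4
model.add(Flatten())                       # 4 * 4 * 128 = 2048 features
model.add(Dropout(0.25))
model.add(Dense(2048, activation='relu'))  # hidden layer 1
model.add(Dropout(0.25))
model.add(Dense(1024, activation='relu'))  # hidden layer 2
model.add(Dropout(0.25))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam', metrics=['accuracy'])
```

Training proceeds exactly as before (model.fit with the normalized images and one-hot labels); expect each epoch to take noticeably longer than with the two-block model.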

Reposted from blog.csdn.net/weixin_41503009/article/details/86745356