用tensorflow对cifar10数据集进行分类

用tensorflow对cifar10数据集进行分类


1.使用软件版本
①Anaconda3.8.11
②PyCharm Community Edition 2021.2
③tensorflow2.4.0

2.所使用数据集(cifar10)

3.训练代码实现

import tensorflow as tf
from tensorflow import keras
cifar10 = keras.datasets.cifar10
(train_images,train_lables),_=cifar10.load_data()#先下载数据
train_images=train_images/255.0#将0-255的像素值转化为0.0-1.0范围内的实数
train_images=tf.image.resize(train_images[:20000],(224,224))#将图片转为224*224大小的图片,[:20000]为取20000张图片进行训练
train_lables=train_lables[:20000]

alexNet=keras.Sequential(layers=[
    #conv1 and pool1
    #96个卷积核,卷积核大小11*11,步长为4,填充方式为valid填充,激活函数为relu函数
    keras.layers.Conv2D(96,11,strides=(4,4),padding='valid',activation='relu'),
    keras.layers.BatchNormalization(),#对前一层的激活进行每个batch的归一化,使得激活后的参数近似于(0,1)正态分布
    #最大池化层,池化窗口大小为3*3,步长为2,填充方式为valid填充
    keras.layers.MaxPool2D(pool_size=(3,3),strides=(2,2),padding='valid'),

    #conv2 and pool2
    keras.layers.Conv2D(256,5,strides=(1,1),padding='same',activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3),strides=(2,2),padding='valid'),

    #conv3
    keras.layers.Conv2D(384,3,strides=(1*1),padding='same',activation='relu'),

    #conv4
    keras.layers.Conv2D(384,3,strides=(1,1),padding='same',activation='relu'),

    #conv5 and pool3
    keras.layers.Conv2D(256,3,strides=(1,1),padding='same',activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3),strides=(2,2),padding='valid'),

    #Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡
    keras.layers.Flatten(),

    #全连接层第一层
    keras.layers.Dense(4096,activation='relu'),#把维度降到4096
    keras.layers.Dropout(rate=0.5),

    #全连接层第二层
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dropout(rate=0.5),
    keras.layers.Dense(10)
])#构建alexNet模型

#模型配置
alexNet.compile(optimizer='adam',
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )

#训练模型
alexNet.fit(x=train_images,y=train_lables,batch_size=16,epochs=10)

4.运行结果
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/Airmilan/article/details/120982957