tf.keras (TensorFlow) implementation of the VGG-11, VGG-13, VGG-16 and VGG-19 networks

Pay attention to a few points:

  • 1. The data set used here was downloaded from ImageNet, but it is relatively small while the VGG network is very large, so there is definitely not enough data to train the network from scratch. If you do want to train the network, you can use transfer learning.
  • 2. All key points and difficult parts are explained in detail in the code comments.
  • 3. The code consists of three modules: the model script, the training script and the prediction script. Apart from the model definition, the rest is largely the same as the earlier AlexNet code.

Model script:

from tensorflow.keras import layers, models, Model, Sequential

def VGG(feature, im_height=224, im_width=224, class_num=1000):
    """Attach the fully connected classification head to a feature extractor.

    Args:
        feature: the feature-extraction network; it is called on the input
            tensor and its output is fed into the classification head.
        im_height: height of the input image.
        im_width: width of the input image.
        class_num: number of output classes.

    Returns:
        A Keras ``Model`` mapping an ``(im_height, im_width, 3)`` float32
        image to softmax probabilities over ``class_num`` classes.
    """
    # Tensors in TensorFlow use NHWC channel ordering.
    image_in = layers.Input(shape=(im_height, im_width, 3), dtype="float32")

    # Classification head. Each factory is invoked only when its turn comes,
    # so layers are created and applied strictly one after another.
    head_factories = (
        lambda: layers.Flatten(),                         # flatten the feature maps
        lambda: layers.Dropout(rate=0.5),                 # dropout to reduce overfitting
        lambda: layers.Dense(2048, activation='relu'),    # half the units of the original paper, to save parameters
        lambda: layers.Dropout(rate=0.5),
        lambda: layers.Dense(2048, activation='relu'),
        lambda: layers.Dense(class_num),
        lambda: layers.Softmax(),
    )

    tensor = feature(image_in)  # run the feature extractor first
    for make_layer in head_factories:
        tensor = make_layer()(tensor)

    return models.Model(inputs=image_in, outputs=tensor)

def features(cfg):
    """Build the feature-extraction network described by a configuration list.

    Args:
        cfg: iterable of layer descriptors. The string ``"M"`` stands for a
            2x2 max-pooling layer with stride 2; any other entry is used as
            the number of filters of a 3x3 ReLU convolution layer.

    Returns:
        A ``Sequential`` model named ``"feature"`` holding the layers in the
        order given by ``cfg``.
    """
    def build_layer(entry):
        # Anything that is not the pooling marker is a convolution width.
        if entry != "M":
            return layers.Conv2D(entry, kernel_size=3, padding="SAME", activation="relu")
        return layers.MaxPool2D(pool_size=2, strides=2)

    return Sequential([build_layer(entry) for entry in cfg], name="feature")

# Layer configurations for the supported model variants.
# Key: model name. Value: list in which a number is the filter count of a
# convolution layer and "M" marks a max-pooling layer.
# Each row of a list below is one convolution stage ending in a pooling layer.
cfgs = {
    "vgg11": [
        64, "M",
        128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    "vgg13": [
        64, 64, "M",
        128, 128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    "vgg16": [
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, "M",
        512, 512, 512, "M",
        512, 512, 512, "M",
    ],
    "vgg19": [
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, 256, "M",
        512, 512, 512, 512, "M",
        512, 512, 512, 512, "M",
    ],
}


def vgg(model_name="vgg16", im_height=224, im_width=224, class_num=1000):
    """Instantiate a VGG model by configuration name.

    Args:
        model_name: key into ``cfgs`` selecting the layer configuration.
        im_height: height of the input image.
        im_width: width of the input image.
        class_num: number of output classes.

    Returns:
        The model built by ``VGG`` on top of ``features(cfg)``.

    Raises:
        SystemExit: with code -1, after printing a warning, when
            ``model_name`` is not a key of ``cfgs``.
    """
    try:
        cfg = cfgs[model_name]  # look up the layer configuration
    except KeyError:
        # Only a missing key means "unknown model"; any other error is a real
        # bug and must not be reported as a bad model name.
        print("Warning: model number {} not in cfgs dict!".format(model_name))
        # Raise SystemExit directly: the ``exit`` builtin is injected by the
        # ``site`` module for interactive use and is not always available.
        raise SystemExit(-1)
    return VGG(features(cfg), im_height=im_height, im_width=im_width, class_num=class_num)

# Build a VGG-11 instance using vgg()'s default input size and class count.
# NOTE(review): as a top-level statement this also runs when the file is
# imported (e.g. via `from model import vgg`) — confirm that is intended.
model = vgg("vgg11")

Training script:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from model import vgg
import tensorflow as tf
import json
import os


# Training script: reads the flower data set from disk, trains a VGG-16
# classifier on it and keeps the weights with the lowest validation loss.

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
train_dir = os.path.join(image_path, "train")
validation_dir = os.path.join(image_path, "val")

# create directory for saving weights (no error if it already exists)
os.makedirs("save_weights", exist_ok=True)

im_height = 224
im_width = 224
batch_size = 10
epochs = 10

# Preprocessing: the training generator rescales pixels to [0, 1] and applies
# random horizontal flips; the validation generator only rescales.
train_image_generator = ImageDataGenerator(rescale=1. / 255,
                                           horizontal_flip=True)
validation_image_generator = ImageDataGenerator(rescale=1. / 255)

# Read the training images.
train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
                                                           batch_size=batch_size,
                                                           shuffle=True,
                                                           target_size=(im_height, im_width),
                                                           class_mode='categorical')
total_train = train_data_gen.n  # number of training samples

# Dictionary mapping each class name to its index.
class_indices = train_data_gen.class_indices

# Invert keys and values so that a predicted index can be mapped straight back
# to its class name at prediction time.
inverse_dict = {val: key for key, val in class_indices.items()}
# Serialize the mapping to a JSON string and write it to class_indices.json.
json_str = json.dumps(inverse_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

# Read the validation images. They must go through the validation generator
# (no augmentation) and are not shuffled, so the same validation_steps batches
# are evaluated every epoch and val_loss is comparable between epochs.
val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                              batch_size=batch_size,
                                                              shuffle=False,
                                                              target_size=(im_height, im_width),
                                                              class_mode='categorical')
total_val = val_data_gen.n  # number of validation samples

# Instantiate the network; input size and class count follow the data set
# instead of being hard-coded.
model = vgg("vgg16", im_height, im_width, len(class_indices))
model.summary()

# using keras high level api for training
# The model ends in a Softmax layer, hence from_logits=False.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
              metrics=["accuracy"])

# Keep only the best weights (lowest val_loss) in one fixed file. The
# prediction script loads ./save_weights/myVGG.h5, so the name must match.
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myVGG.h5',
                                                save_best_only=True,
                                                save_weights_only=True,
                                                monitor='val_loss')]

# tensorflow2.1 recommend to using fit
history = model.fit(x=train_data_gen,
                    steps_per_epoch=total_train // batch_size,
                    epochs=epochs,
                    validation_data=val_data_gen,
                    validation_steps=total_val // batch_size,
                    callbacks=callbacks)

Prediction script:

from model import vgg
from PIL import Image
import numpy as np
import json
import matplotlib.pyplot as plt

# Prediction script: classifies a single image with the trained VGG-16
# weights and prints the predicted class name and its probability.

im_height = 224
im_width = 224

# load image
# Force three channels: a grayscale, RGBA or CMYK file would otherwise not
# match the (im_height, im_width, 3) model input.
img = Image.open("../tulip.jpg").convert("RGB")
# resize image to 224x224
img = img.resize((im_width, im_height))
plt.imshow(img)

# scaling pixel value to (0-1)
img = np.array(img) / 255.

# Add the image to a batch where it's the only member.
img = np.expand_dims(img, 0)

# read class_indict (index -> class name mapping written by the training script)
try:
    # The context manager closes the file even if parsing fails.
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:
    # OSError: file missing/unreadable; ValueError: invalid JSON or encoding.
    print(e)
    raise SystemExit(-1)

# Build the same architecture that was trained; the class count comes from
# the saved class mapping rather than a hard-coded number.
model = vgg("vgg16", im_height, im_width, len(class_indict))
# NOTE(review): this file name must match the checkpoint path used by the
# training script — confirm before running.
model.load_weights("./save_weights/myVGG.h5")
result = np.squeeze(model.predict(img))
predict_class = np.argmax(result)
# JSON object keys are strings, so the predicted index is converted with str().
print(class_indict[str(predict_class)], result[predict_class])
plt.show()

Guess you like

Origin blog.csdn.net/qq_42308217/article/details/110350766