在 TensorFlow 下加载一个 VGG19 网络

网络模型文件（imagenet-vgg-verydeep-19.mat）和测试图片（horse.jpg）都放在当前工作目录下的 data 文件夹中。

#coding=utf-8
import tensorflow as tf
import scipy.io
import scipy.misc
import os
import numpy as np
import matplotlib.pyplot as plt

def nets(data_path, input_img):
    """Build a fixed-weight VGG19 graph on top of ``input_img``.

    Weights are read from a MatConvNet ``.mat`` file and baked into the graph
    as constants (``tf.constant``), so nothing in the returned network is a
    trainable variable.

    Args:
        data_path: Path to the VGG19 ``.mat`` file, loaded with
            ``scipy.io.loadmat``.
        input_img: Tensor fed to the first convolution; presumably a float32
            NHWC batch of RGB images -- TODO confirm against callers.

    Returns:
        A ``(net, layers, mean_pixel)`` tuple:
          * ``net`` -- dict mapping each layer name to its output tensor;
          * ``layers`` -- the ordered tuple of layer names;
          * ``mean_pixel`` -- the stored normalization image averaged over its
            first two axes; the caller subtracts it from the input image.
    """
    layers = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
        'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4',
    )
    vgg = scipy.io.loadmat(data_path)
    normalization_img = vgg['normalization'][0][0][0]
    mean_pixel = np.mean(normalization_img, (0, 1))
    # Indexed in lockstep with ``layers``; only the entries at conv positions
    # are actually read.
    params = vgg['layers'][0]

    net = {}
    tensor = input_img
    for idx, layer_name in enumerate(layers):
        op_type = layer_name[:4]
        if op_type == 'conv':
            kernel, bias = params[idx][0][0][0][0]
            # Swap the two spatial axes so the kernel matches conv2d's
            # [height, width, in_channels, out_channels] filter layout
            # (the .mat file is assumed to store width first).
            kernel = np.transpose(kernel, (1, 0, 2, 3))
            bias = np.reshape(bias, -1)
            conv = tf.nn.conv2d(tensor, tf.constant(kernel), [1, 1, 1, 1], padding='SAME')
            tensor = tf.nn.bias_add(conv, bias)
        elif op_type == 'relu':
            tensor = tf.nn.relu(tensor)
        elif op_type == 'pool':
            tensor = tf.nn.max_pool(tensor, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
        net[layer_name] = tensor
    # Sanity check: every layer name produced exactly one entry.
    assert len(net) == len(layers)
    return net, layers, mean_pixel

def _load_rgb(path):
    """Return the image at ``path`` as a float32 (H, W, 3) array in 0-255.

    ``scipy.misc.imread`` was removed in SciPy 1.2, so matplotlib's reader is
    used instead. ``plt.imread`` returns PNGs as floats in [0, 1] and other
    formats as integers, so float results are rescaled to the 0-255 range the
    old reader produced. Gray (and gray+alpha) images are expanded to three
    channels and any alpha channel is dropped, because the mean-pixel
    subtraction and the first convolution presumably expect RGB input.
    """
    img = plt.imread(path)
    if np.issubdtype(img.dtype, np.floating):
        img = img * 255.0
    img = img.astype(np.float32)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:
        # Gray or gray+alpha: replicate the luminance channel into RGB.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    return img[:, :, :3]


cwd = os.getcwd()
data_path = os.path.join(cwd, 'data', 'imagenet-vgg-verydeep-19.mat')
img_path = os.path.join(cwd, 'data', 'horse.jpg')

input_img = _load_rgb(img_path)
# (batch_size, height, width, channels)
shape = (1,) + input_img.shape

with tf.Session() as sess:
    image = tf.placeholder(tf.float32, shape=shape)
    net, layers, mean_pixel = nets(data_path, image)
    img_preprocess = np.array([input_img - mean_pixel])

    # Size the thumbnail grid from the layer count instead of hard-coding 4 x 9.
    n_cols = 9
    n_rows = (len(layers) + n_cols - 1) // n_cols
    figure = plt.figure(figsize=(24, 12))

    for i, layer in enumerate(layers):
        print('[%d/%d] %s' % (i + 1, len(layers), layer))
        # Fetch one layer at a time so only one feature map is held in host
        # memory at once.
        features = sess.run(net[layer], feed_dict={image: img_preprocess})
        print('type of feature:{},shape is {}'.format(type(features), features.shape))

        axis = figure.add_subplot(n_rows, n_cols, i + 1)
        # Visualize only channel 0 of the single batch element.
        axis.imshow(features[0, :, :, 0], cmap=plt.cm.gray)
        axis.set_title(layer)

    # savefig() must run before show(): once show() returns the figure is
    # typically closed, and saving afterwards produces a blank image.
    plt.savefig('2.png')
    plt.show()
print('Done')
 
 

以下是输入的图片以及最终的测试结果，从中可以看到这个网络各层学习到的特征。


此处保存图片时一定要注意格式：使用 png 格式可以正常保存，其他格式似乎保存不下来。另外，plt.savefig() 必须在 plt.show() 之前调用。


猜你喜欢

转载自blog.csdn.net/dyh02016/article/details/80026045
今日推荐