将网络模型文件(imagenet-vgg-verydeep-19.mat)和测试图片(horse.jpg)放在运行目录下的 data 文件夹中(代码用 os.getcwd() 加上 'data' 子目录拼接路径)。
# coding=utf-8
"""Visualise what each layer of a pretrained VGG-19 network responds to.

Loads the MatConvNet VGG-19 weights, runs one image through the network
and plots the first channel of every layer's output in a single figure.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import tensorflow as tf

try:
    # scipy.misc.imread was removed in SciPy 1.2; fall back to matplotlib then.
    from scipy.misc import imread as _scipy_imread
except ImportError:
    _scipy_imread = None

# Layer names in the order they appear in the .mat file's 'layers' array.
_VGG19_LAYERS = (
    'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
    'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
    'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
    'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
    'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
    'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
    'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
    'relu5_3', 'conv5_4', 'relu5_4',
)


def nets(data_path, input_img):
    """Build the frozen VGG-19 convolutional stack on top of ``input_img``.

    Pretrained weights are read from a MatConvNet ``.mat`` file and embedded
    as ``tf.constant`` tensors, so nothing in the network is trainable.

    Args:
        data_path: path to ``imagenet-vgg-verydeep-19.mat``.
        input_img: 4-D input tensor, presumably (batch, height, width, 3)
            with ``mean_pixel`` already subtracted -- TODO confirm for
            callers other than the script below.

    Returns:
        ``(net, layers, mean_pixel)``:
        ``net`` maps every layer name to its output tensor;
        ``layers`` is the ordered tuple of layer names;
        ``mean_pixel`` is the per-channel mean taken from the file's
        'normalization' entry, to be subtracted from the input image.
    """
    data = scipy.io.loadmat(data_path)
    # 'normalization' presumably holds the mean training image; averaging
    # over the two spatial axes leaves one value per channel.
    mean = data['normalization'][0][0][0]
    mean_pixel = np.mean(mean, axis=(0, 1))
    weights = data['layers'][0]

    net = {}  # output tensor of each layer, filled in forward-pass order
    current = input_img
    for i, name in enumerate(_VGG19_LAYERS):
        kind = name[:4]
        if kind == 'conv':
            kernels, bias = weights[i][0][0][0][0]
            # Swap the two spatial axes to convert the MatConvNet kernel
            # layout to the (height, width, in, out) layout tf.nn.conv2d
            # expects.
            kernels = np.transpose(kernels, (1, 0, 2, 3))
            bias = np.reshape(bias, -1)
            # The pretrained kernels are fixed, so embed them as constants.
            current = tf.nn.conv2d(current, tf.constant(kernels),
                                   strides=[1, 1, 1, 1], padding='SAME')
            current = tf.nn.bias_add(current, bias)
        elif kind == 'relu':
            current = tf.nn.relu(current)
        elif kind == 'pool':
            current = tf.nn.max_pool(current, ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1], padding='SAME')
        net[name] = current

    assert len(net) == len(_VGG19_LAYERS)
    return net, _VGG19_LAYERS, mean_pixel


def _read_image(path):
    """Load an image as a float32 (height, width, 3) array in the 0-255 range.

    Grayscale images are replicated to three channels and an alpha channel,
    if present, is dropped, so the result always matches ``mean_pixel``.
    """
    if _scipy_imread is not None:
        img = _scipy_imread(path)
    else:
        img = plt.imread(path)
        # plt.imread returns floats in [0, 1] for PNG files, but VGG's mean
        # pixel is on the 0-255 scale.
        if np.issubdtype(img.dtype, np.floating):
            img = img * 255.0
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] < 3:  # grayscale (optionally + alpha): replicate luminance
        img = np.repeat(img[:, :, :1], 3, axis=2)
    return img[:, :, :3]  # drop an alpha channel if present


def main():
    """Plot channel 0 of every VGG-19 layer for data/horse.jpg."""
    cwd = os.getcwd()
    data_path = os.path.join(cwd, 'data', 'imagenet-vgg-verydeep-19.mat')
    img_path = os.path.join(cwd, 'data', 'horse.jpg')
    input_img = _read_image(img_path)
    # (batch_size, height, width, channels)
    shape = (1,) + input_img.shape

    # NOTE: this uses the TensorFlow 1.x graph API (tf.Session /
    # tf.placeholder); under TensorFlow 2.x use tf.compat.v1 with eager
    # execution disabled.
    with tf.Session() as sess:
        image = tf.placeholder(tf.float32, shape=shape)
        net, layers, mean_pixel = nets(data_path, image)
        img_preprocessed = (input_img - mean_pixel)[np.newaxis].astype(np.float32)

        # One subplot per layer; 9 columns gives a 4 x 9 grid for VGG-19.
        n_cols = 9
        n_rows = (len(layers) + n_cols - 1) // n_cols
        figure = plt.figure(figsize=(24, 12))
        for i, layer in enumerate(layers):
            print('[%d/%d] %s' % (i + 1, len(layers), layer))
            features = sess.run(net[layer], feed_dict={image: img_preprocessed})
            print('type of feature:{},shape is {}'.format(type(features), features.shape))
            ax = figure.add_subplot(n_rows, n_cols, i + 1)
            ax.imshow(features[0, :, :, 0], cmap=plt.cm.gray)
            ax.set_title(layer)

        # savefig must be called before plt.show(); once the window is
        # closed the figure is released and a later save would be empty.
        plt.savefig('2.png')
        plt.show()
    print('Done')


if __name__ == '__main__':
    main()
下面是输入图片以及最终的测试结果,从中可以看到该网络各层学习到的特征。
保存生成的图片时要注意格式:这里使用 png 格式,其他格式似乎保存不下来。另外,plt.savefig() 必须在 plt.show() 之前调用。