一种图像列表传入神经网络的方法

一种图像列表传入神经网络的方法

首先感谢Faster-RCNN_TF的作者,本文主要参考了该代码的一部分内容。


阅读Faster_RCNN-TF的源码对我产生了很大的帮助,在Kaggle的肺炎图片分类挑战中,需要把数据feed到网络中,一般来说我会先把图片数据读入一个数组中,然后通过feed_dict传入网络。
但是如果使用list类型传入数据会报错,所以参考了Faster-RCNN_TF中的方法,具体请参考:
./roi_data_layer/layer.py
./roi_data_layer/minibatch.py
./utils/blob.py
其中最关键的函数是:

#ims是一个list,[img0,img1...],其中img0的shape为[img_width,img_height,channels]
def im_list_to_blob(self,ims):  
    #每张图像的shape都是[img_width,img_height,channels],但是每张图像的大小可能不同
    max_shape = np.array([im.shape for im in ims]).max(axis=0)
    num_images = len(ims)
    #初始化blob,为了解决不同大小图像都可以装入blob,使用max_shape初始化
    blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),
                    dtype=np.float32)
    for i in range(num_images):
        im = ims[i]
        #每个blob只使用和imgs[i]一样的大小的区域存储图像,多余的像素点为0
        blob[i, 0:im.shape[0], 0:im.shape[1], :] = im
    return blob

测试代码

import numpy as np

ims = []
b = np.array([[[1,1,1],[1,1,1],[1,1,1]],
              [[2,2,2],[2,2,2],[2,2,2]],
              [[3,3,3],[3,3,3],[3,3,3]]])

b1 = np.array([[[1,1,1],[1,1,1],[1,1,1]],
               [[2,2,2],[2,2,2],[2,2,2]],
               [[3,3,3],[3,3,3],[3,3,3]],
               [[4,4,4],[4,4,4],[4,4,4]]])

for i in range(2):
    ims.append(b)
ims.append(b1)

max_shape = np.array([im.shape for im in ims]).max(axis=0)
num_images = len(ims)
blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),
                dtype=np.float32)
for i in range(num_images):
    im = ims[i]
    blob[i, 0:im.shape[0], 0:im.shape[1], :] = im

print(blob)

测试结果

[[[[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[2. 2. 2.]
   [2. 2. 2.]
   [2. 2. 2.]]

  [[3. 3. 3.]
   [3. 3. 3.]
   [3. 3. 3.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]     #<<这是第一个img,由于blob大于imgshape,所以blob的第4行为0


 [[[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[2. 2. 2.]
   [2. 2. 2.]
   [2. 2. 2.]]

  [[3. 3. 3.]
   [3. 3. 3.]
   [3. 3. 3.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]    #同上


 [[[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[2. 2. 2.]
   [2. 2. 2.]
   [2. 2. 2.]]

  [[3. 3. 3.]
   [3. 3. 3.]
   [3. 3. 3.]]

  [[4. 4. 4.]
   [4. 4. 4.]
   [4. 4. 4.]]]]    #<<imgblob的大小相同,所以填满了

猜你喜欢

转载自blog.csdn.net/wangdongwei0/article/details/81212648
今日推荐