tf.nn.max_pool参数含义和用法

max pooling是CNN当中的最大值池化操作,其实用法和卷积很类似

有些地方可以从卷积去参考【TensorFlow】tf.nn.conv2d是怎样实现卷积的?

tf.nn.max_pool(value, ksize, strides, padding, name=None)
参数是四个,和卷积很类似:

第一个参数value:需要池化的输入,一般池化层接在卷积层后面,所以输入通常是feature map,依然是[batch, height, width, channels]这样的shape

第二个参数ksize:池化窗口的大小,取一个四维向量,一般是[1, height, width, 1],因为我们不想在batch和channels上做池化,所以这两个维度设为了1

第三个参数strides:和卷积类似,窗口在每一个维度上滑动的步长,一般也是[1, stride,stride, 1]

第四个参数padding:和卷积类似,可以取’VALID’ 或者’SAME’。padding = ‘SAME’ 时,输出并不一定和原图size一致,但会保证覆盖原图所有像素,不会舍弃边上的莫些元素;padding = ‘VALID’ 时,输出的size总比原图的size小,有时不会覆盖原图所有元素(既,可能舍弃边上的某些元素).

返回一个Tensor,类型不变,shape仍然是[batch, height, width, channels]这种形式

示例源码:

假设有这样一张图,双通道

第一个通道:
在这里插入图片描述
第二个通道:
在这里插入图片描述
用程序去做最大值池化:

     import tensorflow as tf  
      
    a=tf.constant([  
            [[1.0,2.0,3.0,4.0],  
            [5.0,6.0,7.0,8.0],  
            [8.0,7.0,6.0,5.0],  
            [4.0,3.0,2.0,1.0]],  
            [[4.0,3.0,2.0,1.0],  
             [8.0,7.0,6.0,5.0],  
             [1.0,2.0,3.0,4.0],  
             [5.0,6.0,7.0,8.0]]  
        ])  
      
    a=tf.reshape(a,[1,4,4,2])  
      
    pooling=tf.nn.max_pool(a,[1,2,2,1],[1,1,1,1],padding='VALID')  
    with tf.Session() as sess:  
        print("image:")  
        image=sess.run(a)  
        print (image)  
        print("reslut:")  
        result=sess.run(pooling)  
        print (result)  

这里步长为1,窗口大小2×2,输出结果:

     image:  
    [[[[ 1.  2.]  
       [ 3.  4.]  
       [ 5.  6.]  
       [ 7.  8.]]  
      
      [[ 8.  7.]  
       [ 6.  5.]  
       [ 4.  3.]  
       [ 2.  1.]]  
      
      [[ 4.  3.]  
       [ 2.  1.]  
       [ 8.  7.]  
       [ 6.  5.]]  
      
      [[ 1.  2.]  
       [ 3.  4.]  
       [ 5.  6.]  
       [ 7.  8.]]]]  
    reslut:  
    [[[[ 8.  7.]  
       [ 6.  6.]  
       [ 7.  8.]]  
      
      [[ 8.  7.]  
       [ 8.  7.]  
       [ 8.  7.]]  
      
      [[ 4.  4.]  
       [ 8.  7.]  
       [ 8.  8.]]]]  

池化后的图就是:
在这里插入图片描述

在这里插入图片描述

证明了程序的结果是正确的。

我们还可以改变步长

     pooling=tf.nn.max_pool(a,[1,2,2,1],[1,2,2,1],padding='VALID')  

最后的result就变成:

     reslut:  
    [[[[ 8.  7.]  
       [ 7.  8.]]  
      
      [[ 4.  4.]  
       [ 8.  8.]]]]  

猜你喜欢

转载自blog.csdn.net/lafengjinzi/article/details/84938488
今日推荐