Notes on the CBAM block and the SENet block

CBAM 先做通道注意力：在图像的宽、高维度上分别做平均池化和最大池化，两个结果各自经过一个共享的两层 MLP 后相加，再过 sigmoid，得到每个通道的权重。

然后再做空间注意力：在通道维度上分别求平均和最大，把两张图拼接后经过一个 7x7 卷积压成单通道，再过 sigmoid，得到每个空间位置的权重。

SENet 只做通道注意力（给每个通道不同的权重），所以 CBAM 的通道注意力部分和 SENet 确实很相似，CBAM 可以看作在 SENet 基础上又加了空间注意力。

def cbam_block(input_feature, name, ratio=8):
    """Convolutional Block Attention Module (CBAM).

    Refines the input by applying channel attention first and spatial
    attention second, as described in https://arxiv.org/abs/1807.06521.

    Args:
        input_feature: 4-D NHWC tensor (e.g. (2, 35, 35, 384)).
        name: variable-scope name for this block.
        ratio: channel-reduction ratio of the shared MLP in the
            channel-attention sub-module.

    Returns:
        A tensor of the same shape as ``input_feature``, re-weighted by
        both attention maps.
    """
    with tf.variable_scope(name):
        refined = channel_attention(input_feature, 'ch_at', ratio)
        refined = spatial_attention(refined, 'sp_at')
    return refined

def channel_attention(input_feature, name, ratio=8):
    """Channel-attention sub-module of CBAM.

    Global average- and max-pooled channel descriptors are pushed
    through a shared two-layer MLP; the two outputs are summed and
    squashed with a sigmoid to give one weight per channel.

    Args:
        input_feature: 4-D NHWC tensor (e.g. (2, 35, 35, 384)).
        name: variable-scope name for this sub-module.
        ratio: reduction ratio of the MLP bottleneck layer.

    Returns:
        ``input_feature`` scaled channel-wise by the attention weights.
    """
    w_init = tf.contrib.layers.variance_scaling_initializer()
    b_init = tf.constant_initializer(value=0.0)

    with tf.variable_scope(name):
        channel = input_feature.get_shape()[-1]  # e.g. 384 for (2, 35, 35, 384)

        def shared_mlp(pooled, reuse):
            # Two dense layers named 'mlp_0'/'mlp_1'. The first call
            # (reuse=None) creates the variables; the second call
            # (reuse=True) shares them, so both pooling branches go
            # through the SAME MLP.
            hidden = tf.layers.dense(inputs=pooled,
                                     units=channel // ratio,
                                     activation=tf.nn.relu,
                                     kernel_initializer=w_init,
                                     bias_initializer=b_init,
                                     name='mlp_0',
                                     reuse=reuse)
            assert hidden.get_shape()[1:] == (1, 1, channel // ratio)
            # No activation on the second layer (per the CBAM paper).
            out = tf.layers.dense(inputs=hidden,
                                  units=channel,
                                  kernel_initializer=w_init,
                                  bias_initializer=b_init,
                                  name='mlp_1',
                                  reuse=reuse)
            assert out.get_shape()[1:] == (1, 1, channel)
            return out

        # Global average pooling over H and W -> (N, 1, 1, C).
        avg_pool = tf.reduce_mean(input_feature, axis=[1, 2], keepdims=True)
        assert avg_pool.get_shape()[1:] == (1, 1, channel)
        avg_branch = shared_mlp(avg_pool, reuse=None)   # creates mlp vars

        # Global max pooling over H and W -> (N, 1, 1, C).
        max_pool = tf.reduce_max(input_feature, axis=[1, 2], keepdims=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        max_branch = shared_mlp(max_pool, reuse=True)   # reuses mlp vars

        # Sum the two descriptors and squash to (0, 1) weights.
        scale = tf.sigmoid(avg_branch + max_branch, 'sigmoid')

    # Broadcast multiply: each channel gets its own weight.
    return input_feature * scale

def spatial_attention(input_feature, name):
    """Spatial-attention sub-module of CBAM.

    Channel-wise mean and max maps are concatenated and reduced to a
    single-channel map by a 7x7 convolution; a sigmoid turns that map
    into one weight per spatial position.

    Args:
        input_feature: 4-D NHWC tensor (e.g. (2, 35, 35, 384)).
        name: variable-scope name for this sub-module.

    Returns:
        ``input_feature`` scaled position-wise by the attention map.
    """
    kernel_size = 7
    conv_init = tf.contrib.layers.variance_scaling_initializer()
    with tf.variable_scope(name):
        # Collapse the channel axis two ways -> two (N, H, W, 1) maps.
        mean_map = tf.reduce_mean(input_feature, axis=[3], keepdims=True)
        assert mean_map.get_shape()[-1] == 1
        max_map = tf.reduce_max(input_feature, axis=[3], keepdims=True)
        assert max_map.get_shape()[-1] == 1

        stacked = tf.concat([mean_map, max_map], 3)  # (N, H, W, 2)
        assert stacked.get_shape()[-1] == 2

        # 7x7 conv fuses the two maps into a single-channel score map.
        attention = tf.layers.conv2d(stacked,
                                     filters=1,
                                     kernel_size=[kernel_size, kernel_size],
                                     strides=[1, 1],
                                     padding="same",
                                     activation=None,
                                     kernel_initializer=conv_init,
                                     use_bias=False,
                                     name='conv')
        assert attention.get_shape()[-1] == 1
        attention = tf.sigmoid(attention, 'sigmoid')  # (N, H, W, 1)

    # Broadcast multiply: every channel is scaled by the same (H, W) map.
    return input_feature * attention

猜你喜欢

转载自blog.csdn.net/candy134834/article/details/85160899