CBAM Attention Mechanism: Code Implementation (continued)

Click here to see the previous article
GitHub: kobiso/CBAM-tensorflow
The following code covers the channel attention module: the descriptors obtained from average pooling and max pooling are each fed into the MLP, and you can see in the code that the two descriptors share the MLP weights (via the reuse=True flag on the second pass).
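
For reference, the two attention maps defined in the CBAM paper (Woo et al., 2018) are:

    M_c(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F)))
    M_s(F) = σ(f^{7×7}([AvgPool(F); MaxPool(F)]))

where σ is the sigmoid function and f^{7×7} a 7×7 convolution; the code below follows this structure.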

import tensorflow as tf
import tensorflow.contrib.slim as slim


def cbam_block(input_feature, index, reduction_ratio=8):
    """Apply CBAM: channel attention followed by spatial attention."""
    with tf.variable_scope('cbam_%s' % index):
        attention_feature = channel_attention(input_feature, index, reduction_ratio)
        attention_feature = spatial_attention(attention_feature, index)
    return attention_feature


def channel_attention(input_feature, index, reduction_ratio=8):
    """Channel attention: squeeze the spatial dims, then run a shared two-layer MLP."""
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    bias_initializer = tf.constant_initializer(value=0.0)

    with tf.variable_scope('ch_attention_%s' % index):
        feature_map_shape = input_feature.get_shape().as_list()
        channel = feature_map_shape[-1]
        # Global average pooling: one descriptor per channel.
        avg_pool = tf.nn.avg_pool(value=input_feature,
                                  ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                  strides=[1, 1, 1, 1],
                                  padding='VALID')
        assert avg_pool.get_shape()[1:] == (1, 1, channel)
        # Shared MLP, first pass: the 'mlp_0'/'mlp_1' variables are created here (reuse=None).
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel // reduction_ratio,
                                   activation=tf.nn.relu,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_0',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel // reduction_ratio)
        avg_pool = tf.layers.dense(inputs=avg_pool,
                                   units=channel,
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name='mlp_1',
                                   reuse=None)
        assert avg_pool.get_shape()[1:] == (1, 1, channel)

        # Global max pooling: a second descriptor per channel.
        max_pool = tf.nn.max_pool(value=input_feature,
                                  ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                  strides=[1, 1, 1, 1],
                                  padding='VALID')
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        # Shared MLP, second pass: reuse=True reuses the same 'mlp_0'/'mlp_1' weights.
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel // reduction_ratio,
                                   activation=tf.nn.relu,
                                   name='mlp_0',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel // reduction_ratio)
        max_pool = tf.layers.dense(inputs=max_pool,
                                   units=channel,
                                   name='mlp_1',
                                   reuse=True)
        assert max_pool.get_shape()[1:] == (1, 1, channel)
        # Sum the two branches, squash to (0, 1), and rescale the input channel-wise.
        scale = tf.nn.sigmoid(avg_pool + max_pool)
    return input_feature * scale


def spatial_attention(input_feature, index):
    """Spatial attention: pool along the channel axis, then a 7x7 conv produces the mask."""
    kernel_size = 7
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    with tf.variable_scope("sp_attention_%s" % index):
        # Pool along the channel axis to get two H x W x 1 maps.
        avg_pool = tf.reduce_mean(input_feature, axis=3, keepdims=True)
        assert avg_pool.get_shape()[-1] == 1
        max_pool = tf.reduce_max(input_feature, axis=3, keepdims=True)
        assert max_pool.get_shape()[-1] == 1
        # Concatenate along the channel axis.
        concat = tf.concat([avg_pool, max_pool], axis=3)
        assert concat.get_shape()[-1] == 2

        # A 7x7 convolution with sigmoid activation yields the H x W x 1 spatial mask.
        concat = slim.conv2d(concat, num_outputs=1,
                             kernel_size=[kernel_size, kernel_size],
                             padding='SAME',
                             activation_fn=tf.nn.sigmoid,
                             weights_initializer=kernel_initializer,
                             scope='conv')
        assert concat.get_shape()[-1] == 1

    return input_feature * concat
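
To see the block in action, here is a minimal usage sketch of my own (not from the original post), assuming TensorFlow 1.x and NumPy; the 4x32x32x64 input shape is arbitrary:

import numpy as np

# Hypothetical input: a batch of four 32x32 feature maps with 64 channels.
inputs = tf.placeholder(tf.float32, shape=[4, 32, 32, 64], name='feature_map')
refined = cbam_block(inputs, index=0, reduction_ratio=8)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(4, 32, 32, 64).astype(np.float32)
    out = sess.run(refined, feed_dict={inputs: batch})
    print(out.shape)  # (4, 32, 32, 64): CBAM preserves the input shape
    # Each MLP layer appears only once in the variable list,
    # confirming that the two pooling branches share weights.
    for v in tf.trainable_variables():
        print(v.name)

Since both attention maps act multiplicatively and preserve the tensor shape, the block can be inserted between any two convolutional layers of an existing network.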


Source: blog.csdn.net/qq_43265072/article/details/106058693