Residual Attention Network: TensorFlow Low-Level API Implementation


Residual Attention Network

Paper: https://arxiv.org/pdf/1704.06904.pdf
Author's implementation: https://github.com/fwang91/residual-attention-network
This post reimplements the network in TensorFlow, following the author's Caffe source code exactly.

Network Architecture

[Network architecture and Attention Module diagrams from the paper]

Basic Network Components

import tensorflow as tf
from tensorflow.contrib import layers

# Stand-in for the author's logging helper used throughout the snippets below.
logger = print


class BaseNet:
    def __init__(self, inputs, num_classes, is_training=True, trainable=True):
        self.inputs = inputs
        self.num_classes = num_classes
        self.is_training = is_training
        self.trainable = trainable

    def _setup(self):
        raise NotImplementedError

    @staticmethod
    def make_cpu_variable(name, shape, initializer, trainable=True):
        # Pin variables to the CPU, as the name suggests, so they can be shared across devices.
        with tf.device("/cpu:0"):
            return tf.get_variable(name, shape, initializer=initializer, trainable=trainable)

    def conv(self, x, k_h, k_w, s_h, s_w, c_out, name, relu, bias_term=False, padding="SAME", trainable=True):
        with tf.name_scope(name), tf.variable_scope(name):
            c_in = x.get_shape().as_list()[-1]
            weights = self.make_cpu_variable("weights", [k_h, k_w, c_in, c_out],
                                             initializer=layers.xavier_initializer_conv2d(),
                                             trainable=trainable)

            outputs = tf.nn.conv2d(x, weights, [1, s_h, s_w, 1], padding=padding)

            if bias_term:
                biases = self.make_cpu_variable("biases", [c_out],
                                                initializer=tf.constant_initializer(0.001),
                                                trainable=trainable)
                outputs = tf.nn.bias_add(outputs, biases)

            if relu:
                outputs = tf.nn.relu(outputs)

            return outputs

    @staticmethod
    def max_pool(x, k_h, k_w, s_h, s_w, name, padding="SAME"):
        with tf.name_scope(name):
            return tf.nn.max_pool(x, [1, k_h, k_w, 1], [1, s_h, s_w, 1], padding)

    @staticmethod
    def ave_pool(x, k_h, k_w, s_h, s_w, name, padding="VALID"):
        with tf.name_scope(name):
            return tf.nn.avg_pool(x, [1, k_h, k_w, 1], [1, s_h, s_w, 1], padding)

    def fc(self, x, n_out, name, relu, bias_term=False, trainable=True):
        with tf.name_scope(name), tf.variable_scope(name):
            input_shape = x.get_shape().as_list()
            assert len(input_shape) in [2, 4]  # flattened [batch, dim] or NHWC feature maps
            if len(input_shape) == 4:
                dim = 1
                for d in input_shape[1:]:
                    dim *= d
                x = tf.reshape(x, [-1, dim])
            else:
                dim = input_shape[1]

            weights = self.make_cpu_variable("weights", [dim, n_out],
                                             initializer=tf.truncated_normal_initializer(stddev=0.001),
                                             trainable=trainable)
            outputs = tf.matmul(x, weights)
            if bias_term:
                biases = self.make_cpu_variable("biases", [n_out],
                                                initializer=tf.constant_initializer(0.001),
                                                trainable=trainable)
                outputs = tf.nn.bias_add(outputs, biases)

            if relu:
                outputs = tf.nn.relu(outputs)
            return outputs

    @staticmethod
    def relu(x, name):
        with tf.name_scope(name):
            return tf.nn.relu(x)

    @staticmethod
    def dropout(x, keep_prob, name):
        with tf.name_scope(name):
            return tf.nn.dropout(x, keep_prob)

    @staticmethod
    def upsample(x, name, size):
        with tf.name_scope(name):
            return tf.image.resize_bilinear(x, size)

    @staticmethod
    def softmax(x, name):
        with tf.name_scope(name):
            return tf.nn.softmax(x)

    @staticmethod
    def batch_normal(x, is_training, name, activation_fn=None):
        with tf.name_scope(name), tf.variable_scope(name):
            # layers.batch_norm puts its moving-average updates into
            # tf.GraphKeys.UPDATE_OPS; they must be run alongside the train op.
            return layers.batch_norm(x, decay=0.9,
                                     zero_debias_moving_mean=True,
                                     activation_fn=activation_fn,
                                     is_training=is_training)
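
Because batch_normal relies on layers.batch_norm, the moving-average updates are only collected, not executed. A minimal training sketch, assuming a scalar loss tensor built on the network outputs and an illustrative optimizer, would group them with the train op:

# `loss` is assumed to be a scalar loss tensor defined elsewhere;
# the optimizer choice here is only an example.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9).minimize(loss)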

Residual Unit

def residual_unit(self, x, c_in, c_out, name, stride=1):
    with tf.name_scope(name), tf.variable_scope(name):
        bn_1 = self.batch_normal(x, self.is_training, name="bn_1", activation_fn=tf.nn.relu)

        conv_1 = self.conv(bn_1, 1, 1, 1, 1, c_out // 4, "conv_1", relu=False)  # ReLU is already applied by the preceding BN, so relu=False here
        bn_2 = self.batch_normal(conv_1, self.is_training, name="bn_2", activation_fn=tf.nn.relu)

        conv_2 = self.conv(bn_2, 3, 3, stride, stride, c_out // 4, "conv_2", relu=False)
        bn_3 = self.batch_normal(conv_2, self.is_training, name="bn_3", activation_fn=tf.nn.relu)

        conv_3 = self.conv(bn_3, 1, 1, 1, 1, c_out, "conv_3", relu=False)

        if c_out != c_in or stride > 1:
            skip = self.conv(bn_1, 1, 1, stride, stride, c_out, "conv_skip", relu=False)
        else:
            skip = x

        outputs = tf.add(conv_3, skip, name="fuse")
        return outputs
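
As a concrete example of the shapes involved: a 56×56×64 input with c_out = 256 and stride = 1 comes out as 56×56×256, since the bottleneck convolutions run at c_out // 4 = 64 channels and, because c_in differs from c_out, the skip path goes through the 1×1 conv_skip projection instead of the identity. With stride = 2 the same unit additionally halves the spatial size to 28×28.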

Attention Module

Taking Attention Module A as an example.
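
The fusing step at the end of the module implements the paper's attention residual learning, H(x) = (1 + M(x)) · T(x): T(x) is the trunk-branch output, M(x) is the sigmoid mask from the mask branch, and the added identity term keeps the trunk features from being suppressed when the mask is close to zero.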

def attention_module_a(self, x, c_in, name, p=1, t=2, r=1):
    """
    两个skip connection
    """
    with tf.name_scope(name), tf.variable_scope(name):
        with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
            pre_post = x
            for idx in range(p):
                unit_name = "pre_post_{}".format(idx)
                pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)

        with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
            trunks = pre_post
            for idx in range(t):
                unit_name = "trunk_{}".format(idx + 2)
                trunks = self.residual_unit(trunks, c_in, c_in, unit_name)

        with tf.name_scope("mask_branch"), tf.name_scope("mask_branch"):
            size_1 = pre_post.get_shape().as_list()[1:3]

            # Max_pooling
            pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")

            # Down res1
            down_res1 = pool_1
            for idx in range(r):
                unit_name = "down_res1_{}".format(idx + 1)
                down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)

            # skip_1
            skip_1 = self.residual_unit(down_res1, c_in, c_in, "skip_conv_1")

            size_2 = down_res1.get_shape().as_list()[1:3]

            pool_2 = self.max_pool(down_res1, 3, 3, 2, 2, "pool_2")

            # Down res2
            down_res2 = pool_2
            for idx in range(r):
                unit_name = "down_res2_{}".format(idx + 1)
                down_res2 = self.residual_unit(down_res2, c_in, c_in, unit_name)

            # skip_2
            skip_2 = self.residual_unit(down_res2, c_in, c_in, "skip_conv_2")

            size_3 = down_res2.get_shape().as_list()[1:3]

            pool_3 = self.max_pool(down_res2, 3, 3, 2, 2, "pool_3")

            # Down res3
            down_res3 = pool_3
            for idx in range(2 * r):
                unit_name = "down_res3_{}".format(idx + 1)
                down_res3 = self.residual_unit(down_res3, c_in, c_in, unit_name)

            # Interpolation 3
            interp_3 = self.upsample(down_res3, "interp_3", size_3)

            # Skip connection_1
            skip_connection_1 = tf.add(skip_2, interp_3, name="skip_connection_1")

            # Up res2
            up_res2 = skip_connection_1
            for idx in range(r):
                unit_name = "up_res2_{}".format(idx + 1)
                up_res2 = self.residual_unit(up_res2, c_in, c_in, unit_name)

            # Interpolation 2
            interp_2 = self.upsample(up_res2, "interp_2", size_2)

            # Skip connection_2
            skip_connection_2 = tf.add(skip_1, interp_2, name="skip_connection_2")

            # Up res1
            up_res1 = skip_connection_2
            for idx in range(r):
                unit_name = "up_res1_{}".format(idx + 1)
                up_res1 = self.residual_unit(up_res1, c_in, c_in, unit_name)

            # Interpolation 1
            interp_1 = self.upsample(up_res1, "interp_1", size_1)
            interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)

            # Linear
            linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
            linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")

            linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)

            # Sigmoid
            sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")

        # Fusing
        with tf.name_scope("fusing"), tf.variable_scope("fusing"):
            outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
            outputs = tf.add(trunks, outputs, name="fuse_add")

        # End_post
        with tf.name_scope("end_post"), tf.name_scope("end_post"):
            for idx in range(p):
                unit_name = "end_post_{}".format(idx + 1)
                outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
        return outputs
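
To trace the mask branch concretely (assuming 224×224 network inputs, so Attention A sees 56×56 feature maps): the three 3×3/stride-2 max pools take the mask branch to 28×28, 14×14 and 7×7, and the bilinear upsamples then restore 14×14, 28×28 and finally 56×56, with skip_2 and skip_1 added back in at the matching 14×14 and 28×28 resolutions. Modules B and C in the complete code below are the same construction with the deepest one and two pyramid levels removed, which is why B has a single skip connection and C has none.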

Residual Attention Network

def _setup(self):
    base = 32  # base channel width; the complete code below passes this in via ranet56(base=32)
    pre_head = self.pre_head(self.inputs, base, "head")
    logger("pre_head: {}".format(pre_head.shape))

    # Pre_res_1
    pre_res_1 = self.residual_unit(pre_head, base, base * 4, "pre_res_1", stride=1)
    logger("pre_res_1: {}".format(pre_res_1.shape))

    # Attention A
    attention_a = self.attention_module_a(pre_res_1, base * 4, "attention_a")
    logger("attention_a: {}".format(attention_a.shape))

    # Pre_res_2
    pre_res_2 = self.residual_unit(attention_a, base * 4, base * 8, "pre_res_2", stride=2)
    logger("pre_res_2: {}".format(pre_res_2.shape))

    # Attention B
    attention_b = self.attention_module_b(pre_res_2, base * 8, "attention_b")
    logger("attention_b: {}".format(attention_b.shape))

    # Pre_res_3
    pre_res_3 = self.residual_unit(attention_b, base * 8, base * 16, "pre_res_3", stride=2)
    logger("pre_res_4: {}".format(pre_res_3.shape))

    # Attention C
    attention_c = self.attention_module_c(pre_res_3, base * 16, "attention_c")
    logger("attention_c: {}".format(attention_c.shape))

    # Pre_res_4
    pre_res_4 = self.residual_unit(attention_c, base * 16, base * 32, "pre_res_4_1", stride=2)
    logger("pre_res_4_1: {}".format(pre_res_4.shape))
    pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_2", stride=1)
    logger("pre_res_4_2: {}".format(pre_res_4.shape))
    pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_3", stride=1)
    logger("pre_res_4_3: {}".format(pre_res_4.shape))

    pool_size = pre_res_4.get_shape().as_list()[1]

    # BN & ReLU
    pre_res_4_bn = self.batch_normal(pre_res_4, self.is_training, "pre_res_4_bn", activation_fn=tf.nn.relu)

    # Average pool
    ave_pool = self.ave_pool(pre_res_4_bn, pool_size, pool_size, 1, 1, "ave_pool")
    logger("ave_pool: {}".format(ave_pool.shape))

    # Outputs
    outputs = self.fc(ave_pool, self.num_classes, name="outputs", relu=False, bias_term=True)
    logger("outputs: {}".format(outputs.shape))

Complete Code

class ResidualAttentionNetwork(BaseNet):
    def __init__(self, inputs, num_classes, keep_prob, is_training, trainable):
        super(ResidualAttentionNetwork, self).__init__(inputs, num_classes, is_training, trainable)
        self.keep_prob = keep_prob
        self._setup()

    def _setup(self):
        self.outputs = self.ranet56(base=32)

    def ranet56(self, base):
        pre_head = self.pre_head(self.inputs, base, "head")
        logger("pre_head: {}".format(pre_head.shape))

        # Pre_res_1
        pre_res_1 = self.residual_unit(pre_head, base, base * 4, "pre_res_1", stride=1)
        logger("pre_res_1: {}".format(pre_res_1.shape))

        # Attention A
        attention_a = self.attention_module_a(pre_res_1, base * 4, "attention_a")
        logger("attention_a: {}".format(attention_a.shape))

        # Pre_res_2
        pre_res_2 = self.residual_unit(attention_a, base * 4, base * 8, "pre_res_2", stride=2)
        logger("pre_res_2: {}".format(pre_res_2.shape))

        # Attention B
        attention_b = self.attention_module_b(pre_res_2, base * 8, "attention_b")
        logger("attention_b: {}".format(attention_b.shape))

        # Pre_res_3
        pre_res_3 = self.residual_unit(attention_b, base * 8, base * 16, "pre_res_3", stride=2)
        logger("pre_res_4: {}".format(pre_res_3.shape))

        # Attention C
        attention_c = self.attention_module_c(pre_res_3, base * 16, "attention_c")
        logger("attention_c: {}".format(attention_c.shape))

        # Pre_res_4
        pre_res_4 = self.residual_unit(attention_c, base * 16, base * 32, "pre_res_4_1", stride=2)
        logger("pre_res_4_1: {}".format(pre_res_4.shape))
        pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_2", stride=1)
        logger("pre_res_4_2: {}".format(pre_res_4.shape))
        pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_3", stride=1)
        logger("pre_res_4_3: {}".format(pre_res_4.shape))

        pool_size = pre_res_4.get_shape().as_list()[1]

        # BN & ReLU
        pre_res_4_bn = self.batch_normal(pre_res_4, self.is_training, "pre_res_4_bn", activation_fn=tf.nn.relu)

        # Average pool
        ave_pool = self.ave_pool(pre_res_4_bn, pool_size, pool_size, 1, 1, "ave_pool")
        logger("ave_pool: {}".format(ave_pool.shape))

        # Outputs
        outputs = self.fc(ave_pool, self.num_classes, name="outputs", relu=False, bias_term=True)
        logger("outputs: {}".format(outputs.shape))
        return outputs

    def pre_head(self, x, c_out, name):
        with tf.name_scope(name), tf.variable_scope(name):
            conv = self.conv(x, 7, 7, 2, 2, c_out, "conv", relu=False)
            logger("pre_conv: {}".format(conv.shape))
            bn = self.batch_normal(conv, self.is_training, name="bn", activation_fn=tf.nn.relu)
            pool = self.max_pool(bn, 3, 3, 2, 2, name="pool", padding="SAME")
            logger("pre_pool: {}".format(pool.shape))
            return pool

    def residual_unit(self, x, c_in, c_out, name, stride=1):
        with tf.name_scope(name), tf.variable_scope(name):
            bn_1 = self.batch_normal(x, self.is_training, name="bn_1", activation_fn=tf.nn.relu)

            conv_1 = self.conv(bn_1, 1, 1, 1, 1, c_out // 4, "conv_1", relu=False)  # ReLU is already applied by the preceding BN, so relu=False here
            bn_2 = self.batch_normal(conv_1, self.is_training, name="bn_2", activation_fn=tf.nn.relu)

            conv_2 = self.conv(bn_2, 3, 3, stride, stride, c_out // 4, "conv_2", relu=False)
            bn_3 = self.batch_normal(conv_2, self.is_training, name="bn_3", activation_fn=tf.nn.relu)

            conv_3 = self.conv(bn_3, 1, 1, 1, 1, c_out, "conv_3", relu=False)

            if c_out != c_in or stride > 1:
                skip = self.conv(bn_1, 1, 1, stride, stride, c_out, "conv_skip", relu=False)
            else:
                skip = x

            outputs = tf.add(conv_3, skip, name="fuse")
            return outputs

    def attention_module_a(self, x, c_in, name, p=1, t=2, r=1):
        """
        两个skip connection
        """
        with tf.name_scope(name), tf.variable_scope(name):
            with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
                pre_post = x
                for idx in range(p):
                    unit_name = "pre_post_{}".format(idx)
                    pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)

            with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
                trunks = pre_post
                for idx in range(t):
                    unit_name = "trunk_{}".format(idx + 2)
                    trunks = self.residual_unit(trunks, c_in, c_in, unit_name)

            with tf.name_scope("mask_branch"), tf.name_scope("mask_branch"):
                size_1 = pre_post.get_shape().as_list()[1:3]

                # Max_pooling
                pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")

                # Down res1
                down_res1 = pool_1
                for idx in range(r):
                    unit_name = "down_res1_{}".format(idx + 1)
                    down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)

                # skip_1
                skip_1 = self.residual_unit(down_res1, c_in, c_in, "skip_conv_1")

                size_2 = down_res1.get_shape().as_list()[1:3]

                pool_2 = self.max_pool(down_res1, 3, 3, 2, 2, "pool_2")

                # Down res2
                down_res2 = pool_2
                for idx in range(r):
                    unit_name = "down_res2_{}".format(idx + 1)
                    down_res2 = self.residual_unit(down_res2, c_in, c_in, unit_name)

                # skip_2
                skip_2 = self.residual_unit(down_res2, c_in, c_in, "skip_conv_2")

                size_3 = down_res2.get_shape().as_list()[1:3]

                pool_3 = self.max_pool(down_res2, 3, 3, 2, 2, "pool_3")

                # Down res3
                down_res3 = pool_3
                for idx in range(2 * r):
                    unit_name = "down_res3_{}".format(idx + 1)
                    down_res3 = self.residual_unit(down_res3, c_in, c_in, unit_name)

                # Interpolation 3
                interp_3 = self.upsample(down_res3, "interp_3", size_3)

                # Skip connection_1
                skip_connection_1 = tf.add(skip_2, interp_3, name="skip_connection_1")

                # Up res2
                up_res2 = skip_connection_1
                for idx in range(r):
                    unit_name = "up_res2_{}".format(idx + 1)
                    up_res2 = self.residual_unit(up_res2, c_in, c_in, unit_name)

                # Interpolation 2
                interp_2 = self.upsample(up_res2, "interp_2", size_2)

                # Skip connection_2
                skip_connection_2 = tf.add(skip_1, interp_2, name="skip_connection_2")

                # Up res1
                up_res1 = skip_connection_2
                for idx in range(r):
                    unit_name = "up_res1_{}".format(idx + 1)
                    up_res1 = self.residual_unit(up_res1, c_in, c_in, unit_name)

                # Interpolation 1
                interp_1 = self.upsample(up_res1, "interp_1", size_1)
                interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)

                # Linear
                linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
                linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")

                linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)

                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")

            # Fusing
            with tf.name_scope("fusing"), tf.variable_scope("fusing"):
                outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
                outputs = tf.add(trunks, outputs, name="fuse_add")

            # End_post
            with tf.name_scope("end_post"), tf.name_scope("end_post"):
                for idx in range(p):
                    unit_name = "end_post_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
            return outputs

    def attention_module_b(self, x, c_in, name, p=1, t=2, r=1):
        """
        两个skip connection
        """
        with tf.name_scope(name), tf.variable_scope(name):
            with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
                pre_post = x
                for idx in range(p):
                    unit_name = "pre_post_{}".format(idx)
                    pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)

            with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
                trunks = pre_post
                for idx in range(t):
                    unit_name = "trunk_{}".format(idx + 2)
                    trunks = self.residual_unit(trunks, c_in, c_in, unit_name)

            with tf.name_scope("mask_branch"), tf.name_scope("mask_branch"):
                size_1 = pre_post.get_shape().as_list()[1:3]

                # Max_pooling
                pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")

                # Down res1
                down_res1 = pool_1
                for idx in range(r):
                    unit_name = "down_res1_{}".format(idx + 1)
                    down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)

                # skip_1
                skip_1 = self.residual_unit(down_res1, c_in, c_in, "skip_conv_1")

                size_2 = down_res1.get_shape().as_list()[1:3]

                pool_2 = self.max_pool(down_res1, 3, 3, 2, 2, "pool_2")

                # Down res2
                down_res2 = pool_2
                for idx in range(2 * r):
                    unit_name = "down_res2_{}".format(idx + 1)
                    down_res2 = self.residual_unit(down_res2, c_in, c_in, unit_name)

                # Interpolation 2
                interp_2 = self.upsample(down_res2, "interp_2", size_2)

                # Skip connection_2
                skip_connection_2 = tf.add(skip_1, interp_2, name="skip_connection_2")

                # Up res1
                up_res1 = skip_connection_2
                for idx in range(r):
                    unit_name = "up_res1_{}".format(idx + 1)
                    up_res1 = self.residual_unit(up_res1, c_in, c_in, unit_name)

                # Interpolation 1
                interp_1 = self.upsample(up_res1, "interp_1", size_1)
                interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)

                # Linear
                linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
                linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")

                linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)

                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")

            # Fusing
            with tf.name_scope("fusing"), tf.variable_scope("fusing"):
                outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
                outputs = tf.add(trunks, outputs, name="fuse_add")

            # End_post
            with tf.name_scope("end_post"), tf.name_scope("end_post"):
                for idx in range(p):
                    unit_name = "end_post_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
            return outputs

    def attention_module_c(self, x, c_in, name, p=1, t=2, r=1):
        """
        两个skip connection
        """
        with tf.name_scope(name), tf.variable_scope(name):
            with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
                pre_post = x
                for idx in range(p):
                    unit_name = "pre_post_{}".format(idx)
                    pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)

            with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
                trunks = pre_post
                for idx in range(t):
                    unit_name = "trunk_{}".format(idx + 2)
                    trunks = self.residual_unit(trunks, c_in, c_in, unit_name)

            with tf.name_scope("mask_branch"), tf.name_scope("mask_branch"):
                size_1 = pre_post.get_shape().as_list()[1:3]

                # Max_pooling
                pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")

                # Down res1
                down_res1 = pool_1
                for idx in range(2 * r):
                    unit_name = "down_res1_{}".format(idx + 1)
                    down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)

                # Interpolation 1
                interp_1 = self.upsample(down_res1, "interp_1", size_1)
                interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)

                # Linear
                linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
                linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")

                linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)

                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")

            # Fusing
            with tf.name_scope("fusing"), tf.variable_scope("fusing"):
                outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
                outputs = tf.add(trunks, outputs, name="fuse_add")

            # End_post
            with tf.name_scope("end_post"), tf.name_scope("end_post"):
                for idx in range(p):
                    unit_name = "end_post_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
            return outputs
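
A minimal usage sketch for building the graph, assuming 224×224 RGB inputs, 1000 classes, and a standard softmax cross-entropy loss (none of which are fixed by the code above):

# Hypothetical placeholders; the input size and class count are assumptions.
images = tf.placeholder(tf.float32, [None, 224, 224, 3], name="images")
labels = tf.placeholder(tf.int64, [None], name="labels")
is_training = tf.placeholder(tf.bool, name="is_training")

# keep_prob is accepted by the constructor but not used by ranet56.
net = ResidualAttentionNetwork(images, num_classes=1000, keep_prob=1.0,
                               is_training=is_training, trainable=True)

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=net.outputs))

The train op would then be built under the tf.GraphKeys.UPDATE_OPS control dependency shown after BaseNet, so that the batch-norm statistics are updated during training.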
