Residual Attention Network
Paper: https://arxiv.org/pdf/1704.06904.pdf
Author's implementation: https://github.com/fwang91/residual-attention-network
This post re-implements the network in TensorFlow, following the author's original Caffe source code as closely as possible.
Network diagram (see the architecture figure and Table 1 in the paper for the overall layout)
Basic network building blocks
import tensorflow as tf
from tensorflow.contrib import layers


class BaseNet(object):
    def __init__(self, inputs, num_classes, keep_prob=1.0, is_training=True, trainable=True):
        self.inputs = inputs
        self.num_classes = num_classes
        self.keep_prob = keep_prob
        self.is_training = is_training
        self.trainable = trainable

    def _setup(self):
        raise NotImplementedError

    @staticmethod
    def make_cpu_variable(name, shape, initializer, trainable=True):
        # Variables are pinned to the CPU, as the method name suggests
        with tf.device("/cpu:0"), tf.variable_scope(name):
            return tf.get_variable(name, shape, initializer=initializer, trainable=trainable)
    def conv(self, x, k_h, k_w, s_h, s_w, c_out, name, relu, bias_term=False, padding="SAME", trainable=True):
        with tf.name_scope(name), tf.variable_scope(name):
            c_in = x.get_shape().as_list()[-1]
            weights = self.make_cpu_variable("weights", [k_h, k_w, c_in, c_out],
                                             initializer=layers.xavier_initializer_conv2d(),
                                             trainable=trainable)
            outputs = tf.nn.conv2d(x, weights, [1, s_h, s_w, 1], padding=padding)
            if bias_term:
                biases = self.make_cpu_variable("biases", [c_out],
                                                initializer=tf.constant_initializer(0.001),
                                                trainable=trainable)
                outputs = tf.nn.bias_add(outputs, biases)
            if relu:
                outputs = tf.nn.relu(outputs)
            return outputs
    @staticmethod
    def max_pool(x, k_h, k_w, s_h, s_w, name, padding="SAME"):
        with tf.name_scope(name):
            return tf.nn.max_pool(x, [1, k_h, k_w, 1], [1, s_h, s_w, 1], padding)

    @staticmethod
    def ave_pool(x, k_h, k_w, s_h, s_w, name, padding="VALID"):
        with tf.name_scope(name):
            return tf.nn.avg_pool(x, [1, k_h, k_w, 1], [1, s_h, s_w, 1], padding)
    def fc(self, x, n_out, name, relu, bias_term=False, trainable=True):
        with tf.name_scope(name), tf.variable_scope(name):
            input_shape = x.get_shape().as_list()
            assert len(input_shape) in [2, 4]  # either already-flattened features or a 4-D feature map
            if len(input_shape) == 4:
                dim = 1
                for d in input_shape[1:]:
                    dim *= d
                x = tf.reshape(x, [-1, dim])
            else:
                dim = input_shape[1]
            weights = self.make_cpu_variable("weights", [dim, n_out],
                                             initializer=tf.truncated_normal_initializer(stddev=0.001),
                                             trainable=trainable)
            outputs = tf.matmul(x, weights)
            if bias_term:
                biases = self.make_cpu_variable("biases", [n_out],
                                                initializer=tf.constant_initializer(0.001),
                                                trainable=trainable)
                outputs = tf.nn.bias_add(outputs, biases)
            if relu:
                outputs = tf.nn.relu(outputs)
            return outputs
    @staticmethod
    def relu(x, name):
        with tf.name_scope(name):
            return tf.nn.relu(x)

    @staticmethod
    def dropout(x, keep_prob, name):
        with tf.name_scope(name):
            return tf.nn.dropout(x, keep_prob)

    @staticmethod
    def upsample(x, name, size):
        with tf.name_scope(name):
            return tf.image.resize_bilinear(x, size)

    @staticmethod
    def softmax(x, name):
        with tf.name_scope(name):
            return tf.nn.softmax(x)
    @staticmethod
    def batch_normal(x, is_training, name, activation_fn=None):
        with tf.name_scope(name), tf.variable_scope(name):
            return layers.batch_norm(x, decay=0.9,
                                     zero_debias_moving_mean=True,
                                     activation_fn=activation_fn,
                                     is_training=is_training)
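To sanity-check these building blocks, here is a minimal usage sketch (not part of the original post; the placeholder shape, channel counts and layer names are assumptions for illustration):

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3], name="inputs")
net = BaseNet(inputs, num_classes=10, keep_prob=1.0, is_training=False)
conv1 = net.conv(inputs, 7, 7, 2, 2, 64, "demo_conv", relu=False)                     # (?, 112, 112, 64)
bn1 = net.batch_normal(conv1, net.is_training, "demo_bn", activation_fn=tf.nn.relu)
pool1 = net.max_pool(bn1, 3, 3, 2, 2, "demo_pool")                                    # (?, 56, 56, 64)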
Residual Unit
    def residual_unit(self, x, c_in, c_out, name, stride=1):
        with tf.name_scope(name), tf.variable_scope(name):
            bn_1 = self.batch_normal(x, self.is_training, name="bn_1", activation_fn=tf.nn.relu)
            # ReLU is already applied after BN, so relu=False for the convolutions
            conv_1 = self.conv(bn_1, 1, 1, 1, 1, c_out // 4, "conv_1", relu=False)
            bn_2 = self.batch_normal(conv_1, self.is_training, name="bn_2", activation_fn=tf.nn.relu)
            conv_2 = self.conv(bn_2, 3, 3, stride, stride, c_out // 4, "conv_2", relu=False)
            bn_3 = self.batch_normal(conv_2, self.is_training, name="bn_3", activation_fn=tf.nn.relu)
            conv_3 = self.conv(bn_3, 1, 1, 1, 1, c_out, "conv_3", relu=False)
            if c_out != c_in or stride > 1:
                skip = self.conv(bn_1, 1, 1, stride, stride, c_out, "conv_skip", relu=False)
            else:
                skip = x
            outputs = tf.add(conv_3, skip, name="fuse")
            return outputs
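The unit is a pre-activation bottleneck (BN-ReLU-1×1 conv, BN-ReLU-3×3 conv carrying the stride, BN-ReLU-1×1 conv); whenever the channel count or the spatial resolution changes, the identity shortcut is replaced by a 1×1 projection taken from bn_1. A quick shape check, assuming net is an instance that exposes residual_unit (for example the ResidualAttentionNetwork class shown later) and an assumed input shape:

x = tf.placeholder(tf.float32, [None, 56, 56, 64])
y = net.residual_unit(x, c_in=64, c_out=256, name="demo_res_1")             # (?, 56, 56, 256), projection shortcut
z = net.residual_unit(y, c_in=256, c_out=512, name="demo_res_2", stride=2)  # (?, 28, 28, 512), down-sampled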
Attention Module
Taking Attention Module A as an example:
    def attention_module_a(self, x, c_in, name, p=1, t=2, r=1):
        """
        Attention Module A: two skip connections (three levels of down/up-sampling)
        """
        with tf.name_scope(name), tf.variable_scope(name):
            with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
                pre_post = x
                for idx in range(p):
                    unit_name = "pre_post_{}".format(idx)
                    pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)
            with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
                trunks = pre_post
                for idx in range(t):
                    unit_name = "trunk_{}".format(idx + 2)
                    trunks = self.residual_unit(trunks, c_in, c_in, unit_name)
            with tf.name_scope("mask_branch"), tf.variable_scope("mask_branch"):
                size_1 = pre_post.get_shape().as_list()[1:3]
                # Max pooling
                pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")
                # Down res1
                down_res1 = pool_1
                for idx in range(r):
                    unit_name = "down_res1_{}".format(idx + 1)
                    down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)
                # Skip 1
                skip_1 = self.residual_unit(down_res1, c_in, c_in, "skip_conv_1")
                size_2 = down_res1.get_shape().as_list()[1:3]
                pool_2 = self.max_pool(down_res1, 3, 3, 2, 2, "pool_2")
                # Down res2
                down_res2 = pool_2
                for idx in range(r):
                    unit_name = "down_res2_{}".format(idx + 1)
                    down_res2 = self.residual_unit(down_res2, c_in, c_in, unit_name)
                # Skip 2
                skip_2 = self.residual_unit(down_res2, c_in, c_in, "skip_conv_2")
                size_3 = down_res2.get_shape().as_list()[1:3]
                pool_3 = self.max_pool(down_res2, 3, 3, 2, 2, "pool_3")
                # Down res3
                down_res3 = pool_3
                for idx in range(2 * r):
                    unit_name = "down_res3_{}".format(idx + 1)
                    down_res3 = self.residual_unit(down_res3, c_in, c_in, unit_name)
                # Interpolation 3
                interp_3 = self.upsample(down_res3, "interp_3", size_3)
                # Skip connection 1
                skip_connection_1 = tf.add(skip_2, interp_3, name="skip_connection_1")
                # Up res2
                up_res2 = skip_connection_1
                for idx in range(r):
                    unit_name = "up_res2_{}".format(idx + 1)
                    up_res2 = self.residual_unit(up_res2, c_in, c_in, unit_name)
                # Interpolation 2
                interp_2 = self.upsample(up_res2, "interp_2", size_2)
                # Skip connection 2
                skip_connection_2 = tf.add(skip_1, interp_2, name="skip_connection_2")
                # Up res1
                up_res1 = skip_connection_2
                for idx in range(r):
                    unit_name = "up_res1_{}".format(idx + 1)
                    up_res1 = self.residual_unit(up_res1, c_in, c_in, unit_name)
                # Interpolation 1
                interp_1 = self.upsample(up_res1, "interp_1", size_1)
                interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)
                # Linear (1x1 convolutions)
                linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
                linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")
                linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)
                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")
            # Fusing
            with tf.name_scope("fusing"), tf.variable_scope("fusing"):
                outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
                outputs = tf.add(trunks, outputs, name="fuse_add")
            # End post
            with tf.name_scope("end_post"), tf.variable_scope("end_post"):
                for idx in range(p):
                    unit_name = "end_post_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
            return outputs
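The fusing step is the attention residual learning from the paper: H(x) = (1 + M(x)) · T(x), where T is the trunk-branch output and M the sigmoid soft mask, so the mask re-weights trunk features without ever zeroing them out. Modules B and C in the complete code below follow the same scheme with one and zero mask-branch skip connections respectively, because they work on smaller feature maps and need fewer down/up-sampling levels. A toy numeric illustration of the fusing rule (values made up):

import numpy as np
T = np.array([1.0, 2.0, 3.0])   # trunk features
M = np.array([0.0, 0.5, 1.0])   # soft mask after the sigmoid
H = (1 + M) * T                 # -> [1., 3., 6.]: mask ~0 passes the trunk through, mask ~1 emphasises it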
Residual Attention Network
The network is assembled in ranet56 (called from _setup with base=32 in the complete code below); the three attention stages are stacked between strided residual units:
    def ranet56(self, base):
        pre_head = self.pre_head(self.inputs, base, "head")
        logger("pre_head: {}".format(pre_head.shape))
        # Pre_res_1
        pre_res_1 = self.residual_unit(pre_head, base, base * 4, "pre_res_1", stride=1)
        logger("pre_res_1: {}".format(pre_res_1.shape))
        # Attention A
        attention_a = self.attention_module_a(pre_res_1, base * 4, "attention_a")
        logger("attention_a: {}".format(attention_a.shape))
        # Pre_res_2
        pre_res_2 = self.residual_unit(attention_a, base * 4, base * 8, "pre_res_2", stride=2)
        logger("pre_res_2: {}".format(pre_res_2.shape))
        # Attention B
        attention_b = self.attention_module_b(pre_res_2, base * 8, "attention_b")
        logger("attention_b: {}".format(attention_b.shape))
        # Pre_res_3
        pre_res_3 = self.residual_unit(attention_b, base * 8, base * 16, "pre_res_3", stride=2)
        logger("pre_res_3: {}".format(pre_res_3.shape))
        # Attention C
        attention_c = self.attention_module_c(pre_res_3, base * 16, "attention_c")
        logger("attention_c: {}".format(attention_c.shape))
        # Pre_res_4
        pre_res_4 = self.residual_unit(attention_c, base * 16, base * 32, "pre_res_4_1", stride=2)
        logger("pre_res_4_1: {}".format(pre_res_4.shape))
        pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_2", stride=1)
        logger("pre_res_4_2: {}".format(pre_res_4.shape))
        pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_3", stride=1)
        logger("pre_res_4_3: {}".format(pre_res_4.shape))
        pool_size = pre_res_4.get_shape().as_list()[1]
        # BN & ReLU
        pre_res_4_bn = self.batch_normal(pre_res_4, self.is_training, "pre_res_4_bn", activation_fn=tf.nn.relu)
        # Average pool
        ave_pool = self.ave_pool(pre_res_4_bn, pool_size, pool_size, 1, 1, "ave_pool")
        logger("ave_pool: {}".format(ave_pool.shape))
        # Outputs
        outputs = self.fc(ave_pool, self.num_classes, name="outputs", relu=False, bias_term=True)
        logger("outputs: {}".format(outputs.shape))
        return outputs
Complete code
# `logger` is a small console-logging helper used throughout; a plain print works as a stand-in.
def logger(msg):
    print(msg)


class ResidualAttentionNetwork(BaseNet):
    def __init__(self, inputs, num_classes, keep_prob, is_training, trainable):
        super(ResidualAttentionNetwork, self).__init__(inputs, num_classes, keep_prob, is_training, trainable)
        self._setup()

    def _setup(self):
        self.outputs = self.ranet56(base=32)
    def ranet56(self, base):
        pre_head = self.pre_head(self.inputs, base, "head")
        logger("pre_head: {}".format(pre_head.shape))
        # Pre_res_1
        pre_res_1 = self.residual_unit(pre_head, base, base * 4, "pre_res_1", stride=1)
        logger("pre_res_1: {}".format(pre_res_1.shape))
        # Attention A
        attention_a = self.attention_module_a(pre_res_1, base * 4, "attention_a")
        logger("attention_a: {}".format(attention_a.shape))
        # Pre_res_2
        pre_res_2 = self.residual_unit(attention_a, base * 4, base * 8, "pre_res_2", stride=2)
        logger("pre_res_2: {}".format(pre_res_2.shape))
        # Attention B
        attention_b = self.attention_module_b(pre_res_2, base * 8, "attention_b")
        logger("attention_b: {}".format(attention_b.shape))
        # Pre_res_3
        pre_res_3 = self.residual_unit(attention_b, base * 8, base * 16, "pre_res_3", stride=2)
        logger("pre_res_3: {}".format(pre_res_3.shape))
        # Attention C
        attention_c = self.attention_module_c(pre_res_3, base * 16, "attention_c")
        logger("attention_c: {}".format(attention_c.shape))
        # Pre_res_4
        pre_res_4 = self.residual_unit(attention_c, base * 16, base * 32, "pre_res_4_1", stride=2)
        logger("pre_res_4_1: {}".format(pre_res_4.shape))
        pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_2", stride=1)
        logger("pre_res_4_2: {}".format(pre_res_4.shape))
        pre_res_4 = self.residual_unit(pre_res_4, base * 32, base * 32, "pre_res_4_3", stride=1)
        logger("pre_res_4_3: {}".format(pre_res_4.shape))
        pool_size = pre_res_4.get_shape().as_list()[1]
        # BN & ReLU
        pre_res_4_bn = self.batch_normal(pre_res_4, self.is_training, "pre_res_4_bn", activation_fn=tf.nn.relu)
        # Average pool
        ave_pool = self.ave_pool(pre_res_4_bn, pool_size, pool_size, 1, 1, "ave_pool")
        logger("ave_pool: {}".format(ave_pool.shape))
        # Outputs
        outputs = self.fc(ave_pool, self.num_classes, name="outputs", relu=False, bias_term=True)
        logger("outputs: {}".format(outputs.shape))
        return outputs
    def pre_head(self, x, c_out, name):
        with tf.name_scope(name), tf.variable_scope(name):
            conv = self.conv(x, 7, 7, 2, 2, c_out, "conv", relu=False)
            logger("pre_conv: {}".format(conv.shape))
            bn = self.batch_normal(conv, self.is_training, name="bn", activation_fn=tf.nn.relu)
            pool = self.max_pool(bn, 3, 3, 2, 2, name="pool", padding="SAME")
            logger("pre_pool: {}".format(pool.shape))
            return pool
    def residual_unit(self, x, c_in, c_out, name, stride=1):
        with tf.name_scope(name), tf.variable_scope(name):
            bn_1 = self.batch_normal(x, self.is_training, name="bn_1", activation_fn=tf.nn.relu)
            # ReLU is already applied after BN, so relu=False for the convolutions
            conv_1 = self.conv(bn_1, 1, 1, 1, 1, c_out // 4, "conv_1", relu=False)
            bn_2 = self.batch_normal(conv_1, self.is_training, name="bn_2", activation_fn=tf.nn.relu)
            conv_2 = self.conv(bn_2, 3, 3, stride, stride, c_out // 4, "conv_2", relu=False)
            bn_3 = self.batch_normal(conv_2, self.is_training, name="bn_3", activation_fn=tf.nn.relu)
            conv_3 = self.conv(bn_3, 1, 1, 1, 1, c_out, "conv_3", relu=False)
            if c_out != c_in or stride > 1:
                skip = self.conv(bn_1, 1, 1, stride, stride, c_out, "conv_skip", relu=False)
            else:
                skip = x
            outputs = tf.add(conv_3, skip, name="fuse")
            return outputs
    def attention_module_a(self, x, c_in, name, p=1, t=2, r=1):
        """
        Attention Module A: two skip connections (three levels of down/up-sampling)
        """
        with tf.name_scope(name), tf.variable_scope(name):
            with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
                pre_post = x
                for idx in range(p):
                    unit_name = "pre_post_{}".format(idx)
                    pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)
            with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
                trunks = pre_post
                for idx in range(t):
                    unit_name = "trunk_{}".format(idx + 2)
                    trunks = self.residual_unit(trunks, c_in, c_in, unit_name)
            with tf.name_scope("mask_branch"), tf.variable_scope("mask_branch"):
                size_1 = pre_post.get_shape().as_list()[1:3]
                # Max pooling
                pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")
                # Down res1
                down_res1 = pool_1
                for idx in range(r):
                    unit_name = "down_res1_{}".format(idx + 1)
                    down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)
                # Skip 1
                skip_1 = self.residual_unit(down_res1, c_in, c_in, "skip_conv_1")
                size_2 = down_res1.get_shape().as_list()[1:3]
                pool_2 = self.max_pool(down_res1, 3, 3, 2, 2, "pool_2")
                # Down res2
                down_res2 = pool_2
                for idx in range(r):
                    unit_name = "down_res2_{}".format(idx + 1)
                    down_res2 = self.residual_unit(down_res2, c_in, c_in, unit_name)
                # Skip 2
                skip_2 = self.residual_unit(down_res2, c_in, c_in, "skip_conv_2")
                size_3 = down_res2.get_shape().as_list()[1:3]
                pool_3 = self.max_pool(down_res2, 3, 3, 2, 2, "pool_3")
                # Down res3
                down_res3 = pool_3
                for idx in range(2 * r):
                    unit_name = "down_res3_{}".format(idx + 1)
                    down_res3 = self.residual_unit(down_res3, c_in, c_in, unit_name)
                # Interpolation 3
                interp_3 = self.upsample(down_res3, "interp_3", size_3)
                # Skip connection 1
                skip_connection_1 = tf.add(skip_2, interp_3, name="skip_connection_1")
                # Up res2
                up_res2 = skip_connection_1
                for idx in range(r):
                    unit_name = "up_res2_{}".format(idx + 1)
                    up_res2 = self.residual_unit(up_res2, c_in, c_in, unit_name)
                # Interpolation 2
                interp_2 = self.upsample(up_res2, "interp_2", size_2)
                # Skip connection 2
                skip_connection_2 = tf.add(skip_1, interp_2, name="skip_connection_2")
                # Up res1
                up_res1 = skip_connection_2
                for idx in range(r):
                    unit_name = "up_res1_{}".format(idx + 1)
                    up_res1 = self.residual_unit(up_res1, c_in, c_in, unit_name)
                # Interpolation 1
                interp_1 = self.upsample(up_res1, "interp_1", size_1)
                interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)
                # Linear (1x1 convolutions)
                linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
                linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")
                linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)
                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")
            # Fusing
            with tf.name_scope("fusing"), tf.variable_scope("fusing"):
                outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
                outputs = tf.add(trunks, outputs, name="fuse_add")
            # End post
            with tf.name_scope("end_post"), tf.variable_scope("end_post"):
                for idx in range(p):
                    unit_name = "end_post_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
            return outputs
    def attention_module_b(self, x, c_in, name, p=1, t=2, r=1):
        """
        Attention Module B: one skip connection (two levels of down/up-sampling)
        """
        with tf.name_scope(name), tf.variable_scope(name):
            with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
                pre_post = x
                for idx in range(p):
                    unit_name = "pre_post_{}".format(idx)
                    pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)
            with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
                trunks = pre_post
                for idx in range(t):
                    unit_name = "trunk_{}".format(idx + 2)
                    trunks = self.residual_unit(trunks, c_in, c_in, unit_name)
            with tf.name_scope("mask_branch"), tf.variable_scope("mask_branch"):
                size_1 = pre_post.get_shape().as_list()[1:3]
                # Max pooling
                pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")
                # Down res1
                down_res1 = pool_1
                for idx in range(r):
                    unit_name = "down_res1_{}".format(idx + 1)
                    down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)
                # Skip 1
                skip_1 = self.residual_unit(down_res1, c_in, c_in, "skip_conv_1")
                size_2 = down_res1.get_shape().as_list()[1:3]
                pool_2 = self.max_pool(down_res1, 3, 3, 2, 2, "pool_2")
                # Down res2
                down_res2 = pool_2
                for idx in range(2 * r):
                    unit_name = "down_res2_{}".format(idx + 1)
                    down_res2 = self.residual_unit(down_res2, c_in, c_in, unit_name)
                # Interpolation 2
                interp_2 = self.upsample(down_res2, "interp_2", size_2)
                # Skip connection
                skip_connection_2 = tf.add(skip_1, interp_2, name="skip_connection_2")
                # Up res1
                up_res1 = skip_connection_2
                for idx in range(r):
                    unit_name = "up_res1_{}".format(idx + 1)
                    up_res1 = self.residual_unit(up_res1, c_in, c_in, unit_name)
                # Interpolation 1
                interp_1 = self.upsample(up_res1, "interp_1", size_1)
                interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)
                # Linear (1x1 convolutions)
                linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
                linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")
                linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)
                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")
            # Fusing
            with tf.name_scope("fusing"), tf.variable_scope("fusing"):
                outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
                outputs = tf.add(trunks, outputs, name="fuse_add")
            # End post
            with tf.name_scope("end_post"), tf.variable_scope("end_post"):
                for idx in range(p):
                    unit_name = "end_post_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
            return outputs
    def attention_module_c(self, x, c_in, name, p=1, t=2, r=1):
        """
        Attention Module C: no skip connection (a single level of down/up-sampling)
        """
        with tf.name_scope(name), tf.variable_scope(name):
            with tf.name_scope("pre_post"), tf.variable_scope("pre_post"):
                pre_post = x
                for idx in range(p):
                    unit_name = "pre_post_{}".format(idx)
                    pre_post = self.residual_unit(pre_post, c_in, c_in, unit_name)
            with tf.name_scope("trunk_branch"), tf.variable_scope("trunk_branch"):
                trunks = pre_post
                for idx in range(t):
                    unit_name = "trunk_{}".format(idx + 2)
                    trunks = self.residual_unit(trunks, c_in, c_in, unit_name)
            with tf.name_scope("mask_branch"), tf.variable_scope("mask_branch"):
                size_1 = pre_post.get_shape().as_list()[1:3]
                # Max pooling
                pool_1 = self.max_pool(pre_post, 3, 3, 2, 2, "pool_1")
                # Down res1
                down_res1 = pool_1
                for idx in range(2 * r):
                    unit_name = "down_res1_{}".format(idx + 1)
                    down_res1 = self.residual_unit(down_res1, c_in, c_in, unit_name)
                # Interpolation 1
                interp_1 = self.upsample(down_res1, "interp_1", size_1)
                interp_1 = self.batch_normal(interp_1, self.is_training, "interp_1_bn", tf.nn.relu)
                # Linear (1x1 convolutions)
                linear_1 = self.conv(interp_1, 1, 1, 1, 1, c_in, "linear_1", relu=False)
                linear_1 = self.batch_normal(linear_1, self.is_training, "linear_1_bn")
                linear_2 = self.conv(linear_1, 1, 1, 1, 1, c_in, "linear_2", relu=False)
                # Sigmoid
                sigmoid = tf.nn.sigmoid(linear_2, name="sigmoid")
            # Fusing
            with tf.name_scope("fusing"), tf.variable_scope("fusing"):
                outputs = tf.multiply(trunks, sigmoid, name="fuse_mul")
                outputs = tf.add(trunks, outputs, name="fuse_add")
            # End post
            with tf.name_scope("end_post"), tf.variable_scope("end_post"):
                for idx in range(p):
                    unit_name = "end_post_{}".format(idx + 1)
                    outputs = self.residual_unit(outputs, c_in, c_in, unit_name)
            return outputs
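Finally, a minimal training sketch (the placeholders, optimizer and hyper-parameters are illustrative assumptions, not the original Caffe training setup). Since layers.batch_norm registers its moving-average updates in tf.GraphKeys.UPDATE_OPS, those ops have to run together with the training step:

images = tf.placeholder(tf.float32, [None, 224, 224, 3], name="images")
labels = tf.placeholder(tf.int64, [None], name="labels")

net = ResidualAttentionNetwork(images, num_classes=1000, keep_prob=1.0,
                               is_training=True, trainable=True)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=net.outputs))

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batch-norm moving averages
with tf.control_dependencies(update_ops):
    train_op = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9).minimize(loss)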