Unet学习笔记

最近在看Unet,记录一下。

论文地址

https://arxiv.org/pdf/1505.04597.pdf

网络结构


图上画的还是很清晰的,但是对于不了解Unet结构的人来说,可能还是有一些不清楚的地方。我这里结合我看的时候的疑问,来讲一下Unet的结构的一些问题。

  1. 可以看到,输入是572x572的,但是输出变成了388x388,这说明经过网络以后,输出的结果和原图不是完全对应的,这在计算loss和输出结果都可以得到体现。
  2. 蓝色箭头代表3x3的卷积操作,并且stride是1,padding策略是vaild,因此,每个该操作以后,featuremap的大小会减2。
  3. 红色箭头代表2x2的maxpooling操作,需要注意的是,此时的padding策略也是vaild(same 策略会在边缘填充0,保证featuremap的每个值都会被取到,vaild会忽略掉不能进行下去的pooling操作,而不是进行填充),这就会导致如果pooling之前featuremap的大小是奇数,那么就会损失一些信息 。
  4. 绿色箭头代表2x2的反卷积操作,这个只要理解了反卷积操作,就没什么问题,操作会将featuremap的大小乘2。
  5. 灰色箭头表示复制和剪切操作,可以发现,在同一层左边的最后一层要比右边的第一层要大一些,这就导致了,想要利用浅层的feature,就要进行一些剪切,也导致了最终的输出是输入的中心某个区域。
  6. 输出的最后一层,使用了1x1的卷积层做了分类。

网络结构代码

纸上得来终觉浅,还是要写个代码,才能了解细节部分。

下面这个代码,我参考了
https://github.com/jakeret/tf_unet

主要把每一层拆开了一步一步写了下来,虽然写的很啰嗦,但是看起来比较容易理解。另外,我忽略了一些对featuremap大小没有影响的操作,比如relu,bn,dropout等。

# conding:utf-8
from __future__ import print_function
import tensorflow as tf
import numpy as np

def weight_variable(shape, stddev=0.1, name="weight"):
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial, name=name)

def bias_variable(shape, name="bias"):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=name)

def get_w_and_b(kernel_size, input_feature_size, output_feature_size, name):
    w = weight_variable([kernel_size, kernel_size, input_feature_size, output_feature_size], name=name + '_w')
    b = bias_variable([output_feature_size], name=name + '_b')
    return w, b

def get_deconv_w_and_b(kernel_size, input_feature_size, output_feature_size, name):
    w = weight_variable([kernel_size, kernel_size, input_feature_size, output_feature_size], name=name + '_w')
    b = bias_variable([input_feature_size], name=name + '_b')
    return w, b

def conv2d(x, W, b):
    with tf.name_scope("conv2d"):
        conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID')
        conv_2d_b = tf.nn.bias_add(conv_2d, b)
        return conv_2d_b

def max_pool(x, n):
    return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='VALID')

def deconv2d(x, W, stride):
    with tf.name_scope("deconv2d"):
        x_shape = tf.shape(x)
        output_shape = tf.stack([x_shape[0], x_shape[1]*2, x_shape[2]*2, x_shape[3]//2])
        return tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding='VALID', name="conv2d_transpose")

def crop_and_concat(x1, x2):
    with tf.name_scope("crop_and_concat"):
        x1_shape = tf.shape(x1)
        x2_shape = tf.shape(x2)
        # offsets for the top left corner of the crop
        offsets = [0, (x1_shape[1] - x2_shape[1]) // 2, (x1_shape[2] - x2_shape[2]) // 2, 0]
        size = [-1, x2_shape[1], x2_shape[2], -1]
        x1_crop = tf.slice(x1, offsets, size)
        return tf.concat([x1_crop, x2], 3)

batch_size=1
# 16x+124
image_w=572
image_h=572
channel=3
pool_size=2
nclass=2
down_layer = {}


x = tf.placeholder(tf.float32, shape=(batch_size, image_w, image_h, channel))
print ('input x:\t\t\t', x)
# layer 0:
w0, b0 = get_w_and_b(3, channel, 64, 'layer0')
x = conv2d(x, w0, b0)
print ('layer 0 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 64, 64, 'layer0_1')
x = conv2d(x, w01, b01)
print ('layer 0 conv2 output:\t\t', x)
down_layer['layer0'] = x
# layer 1:
print ('='*100)
x = max_pool(x, pool_size)
print ('layer 1 pool output:\t\t', x)
w0, b0 = get_w_and_b(3, 64, 128, 'layer1')
x = conv2d(x, w0, b0)
print ('layer 1 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 128, 128, 'layer1_1')
x = conv2d(x, w01, b01)
print ('layer 1 conv2 output:\t\t', x)
down_layer['layer1'] = x
# layer 2:
print ('='*100)
x = max_pool(x, pool_size)
print ('layer 2 pool output:\t\t', x)
w0, b0 = get_w_and_b(3, 128, 256, 'layer2')
x = conv2d(x, w0, b0)
print ('layer 2 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 256, 256, 'layer2_1')
x = conv2d(x, w01, b01)
print ('layer 2 conv2 output:\t\t', x)
down_layer['layer2'] = x
# layer 3:
print ('='*100)
x = max_pool(x, pool_size)
print ('layer 3 pool output:\t\t', x)
w0, b0 = get_w_and_b(3, 256, 512, 'layer3')
x = conv2d(x, w0, b0)
print ('layer 3 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 512, 512, 'layer3_1')
x = conv2d(x, w01, b01)
print ('layer 3 conv2 output:\t\t', x)
down_layer['layer3'] = x
# layer 4:
print ('='*100)
x = max_pool(x, pool_size)
print ('layer 4 pool output:\t\t', x)
w0, b0 = get_w_and_b(3, 512, 1024, 'layer4')
x = conv2d(x, w0, b0)
print ('layer 4 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 1024, 1024, 'layer4_1')
x = conv2d(x, w01, b01)
print ('layer 4 conv2 output:\t\t', x)
down_layer['layer4'] = x
# up layer 3:
print ('='*100)
w0, b0 = get_deconv_w_and_b(pool_size, 512, 1024, 'uplayer_3')
x = deconv2d(x, w0, pool_size) + b0
print ('uplayer 3 deconv2d output:\t', x)
x = crop_and_concat(down_layer['layer3'], x)
print ('uplayer 3 crop&concat output:\t', x)
w0, b0 = get_w_and_b(3, 1024, 512, 'uplayer3')
x = conv2d(x, w0, b0)
print ('uplayer 3 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 512, 512, 'uplayer3_1')
x = conv2d(x, w01, b01)
print ('uplayer 3 conv2 output:\t\t', x)
# up layer 2:
print ('='*100)
w0, b0 = get_deconv_w_and_b(pool_size, 256, 512, 'uplayer_2')
x = deconv2d(x, w0, pool_size) + b0
print ('uplayer 2 deconv2d output:\t', x)
x = crop_and_concat(down_layer['layer2'], x)
print ('uplayer 2 crop&concat output:\t', x)
w0, b0 = get_w_and_b(3, 512, 256, 'uplayer2')
x = conv2d(x, w0, b0)
print ('uplayer 2 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 256, 256, 'uplayer2_1')
x = conv2d(x, w01, b01)
print ('uplayer 2 conv2 output:\t\t', x)
# up layer 1:
print ('='*100)
w0, b0 = get_deconv_w_and_b(pool_size, 128, 256, 'uplayer_1')
x = deconv2d(x, w0, pool_size) + b0
print ('uplayer 1 deconv2d output:\t', x)
x = crop_and_concat(down_layer['layer1'], x)
print ('uplayer 1 crop&concat output:\t', x)
w0, b0 = get_w_and_b(3, 256, 128, 'uplayer1')
x = conv2d(x, w0, b0)
print ('uplayer 1 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 128, 128, 'uplayer1_1')
x = conv2d(x, w01, b01)
print ('uplayer 1 conv2 output:\t\t', x)
# up layer 0:
print ('='*100)
w0, b0 = get_deconv_w_and_b(pool_size, 64, 128, 'uplayer_0')
x = deconv2d(x, w0, pool_size) + b0
print ('uplayer 0 deconv2d output:\t', x)
x = crop_and_concat(down_layer['layer0'], x)
print ('uplayer 0 crop&concat output:\t', x)
w0, b0 = get_w_and_b(3, 128, 64, 'uplayer0')
x = conv2d(x, w0, b0)
print ('uplayer 0 conv1 output:\t\t', x)
w01,b01 = get_w_and_b(3, 64, 64, 'uplayer0_1')
x = conv2d(x, w01, b01)
print ('uplayer 0 conv2 output:\t\t', x)
# output layer
print ('='*100)
w0,b0 = get_w_and_b(1, 64, nclass, 'output_layer')
x = conv2d(x, w0, b0)
print ('output layer out:\t\t', x)

输出结果:

我在前面说过,输入的大小最好满足一个条件,就是可以让每一层pooling操作前的featuremap的大小是偶数,这样就不会损失一些信息,并且crop的时候不会产生误差,这个条件也不难算,只要满足:
S = 16 x + 124 S = 16x + 124
其中S是featuremap的大小,x是一个自变量,可以用来计算想要的某个范围内合理的输入图片大小。

图片切分

Unet的结构中没有全连接,这就表示Unet的输入图片的大小其实是可以不固定的。无论训练还是测试的时候,都可以放一整张图片进去。不过呢,通常来说,一张图片扔进去,对显存还是有一定挑战的,并且,Unet最开始是为了处理医疗图像的,一般医疗图像都非常大。

因此,还可以使用另一个方法,那就是用一个滑动窗口把原图扫一边,使用原图的切片进行训练或测试。

上面图中,蓝色框可以认为是输入图片的大小,黄色区域就是过了网络以后预测的区域大小。因此,做预测的时候,只需要用滑动窗口让蓝色的区域都得到覆盖即可。

损失函数

论文中的损失函数首先是用了个pixel-wise softmax,其实这个没什么特别的,就是每个像素对应的输出单独做softmax,也就是做了w*h个softmax。

接下来的loss计算,我刚开始看的时候还是比较迷糊的,主要是写法的问题,其实就是一个交叉熵乘一个权重。

其中, x x 可以看作是某一个像素点, l ( x ) l(x) 表示 x x 这个点对应的类别label, p k ( x ) p_k(x) 表示在 x x 这个点的输出在类别k的softmax的激活值, p l ( x ) ( x ) p_{l(x)}(x) 代表什么呢?根据前面的说明就可以推断出来:点 x x 在对应的label给出的那个类别的输出的激活值。
回忆一下正常的交叉熵定义:
C = i y i ln a i C = -\sum_i{y_i \ln {a_i}}
可以发现除了 w ( x ) w(x) 和负号(应该需要负号的,我猜想是隐含在 w ( x ) w(x) 中了),在分类这个问题上,两个公式的意义其实是相同的,前面的直接选取了label对应的那个激活值,后面的公式在外面把非label对应的结果乘0了。所以就是一样的。

剩下的就是 w ( x ) w(x) 是个什么玩意儿了。这东西在论文里有对应的解释,但是那个公式应该还是使用于论文中的场景的。在我看来,这个权重更类似于一个超参数,你可以调整图像中某个区域的重要程度,在论文的场景中,大概是分割细胞,作者认为细胞边缘需要更大的权重,于是设计了对应的权重map,实际情况中,应该是需要自己来设计或这调整这个权重的。

结语

大概就这样,其实对于我在训的场景,还没训的很好,后面有更好的想法,再写吧。

发布了443 篇原创文章 · 获赞 149 · 访问量 55万+

猜你喜欢

转载自blog.csdn.net/qian99/article/details/85084686