tensorflow:双线性插值反卷积

首先生成 3×3×3 的灰度渐变图像(像素 (i, j) 的值为 (i+j)/12,并非纯黑)
def produce_image():
    """Build a 3x3x3 grayscale gradient image, display it, and return it.

    Pixel (i, j) carries the value (i + j) / 12 in each of the 3 channels,
    producing a diagonal dark-to-light gradient.  NOTE(review): the original
    text calls this a "black" image, but it is a gradient, not uniform black.
    """
    dim = 3
    rows, cols = ogrid[:dim, :dim]           # rows: shape (3, 1); cols: shape (1, 3)
    gradient = (rows + cols)[:, :, newaxis]  # broadcast sum, then add channel axis -> (3, 3, 1)
    img = repeat(gradient, 3, 2) / 12        # replicate along axis 2 -> (3, 3, 3), scale to [0, 1/3]
    io.imshow(img, interpolation='none')     # assumes skimage.io is in scope — TODO confirm imports
    io.show()
    return img

打印结果:

双线性插值反卷积代码如下:

def produce_image():
    """Return (and display) a 3x3x3 gradient test image.

    Each channel holds the same (3, 3) plane whose entry at (i, j) is
    (i + j) / 12 — a diagonal grayscale ramp used as the deconvolution input.
    """
    n = 3
    row_idx, col_idx = ogrid[:n, :n]     # column vector (3,1) and row vector (1,3)
    plane = row_idx + col_idx            # (3, 3) via broadcasting
    plane = plane[:, :, newaxis]         # -> (3, 3, 1)
    # Copy the single plane into 3 identical channels and normalize.
    img = repeat(plane, 3, 2) / 12
    io.imshow(img, interpolation='none')
    io.show()
    return img

def upsampling_bilinear(factor=3, number_of_classes=3):
    """Build transposed-convolution weights that perform bilinear upsampling.

    Generalized from the original hard-coded call ``bilinear_upsample_weights(3, 3)``:
    both the upsampling factor and the channel count are now parameters whose
    defaults preserve the original behavior.

    Args:
        factor: Integer upsampling factor (default 3).
        number_of_classes: Number of channels; each channel is upsampled
            independently, with no cross-channel mixing (default 3).

    Returns:
        ``np.float32`` array of shape ``(k, k, number_of_classes,
        number_of_classes)`` with ``k = 2*factor - factor % 2``.  The slice
        ``weights[:, :, i, i]`` holds the 2-D bilinear kernel; all
        off-diagonal channel pairs are zero.
    """
    def _kernel_size(factor):
        # Standard FCN rule for the kernel size of a given upsampling factor.
        return 2 * factor - factor % 2

    def _bilinear_kernel(size):
        # 2-D bilinear interpolation kernel of shape (size, size); peaks at
        # the (possibly half-integer) center and falls off linearly.
        half = (size + 1) // 2
        center = half - 1 if size % 2 == 1 else half - 0.5
        og = np.ogrid[:size, :size]
        return ((1 - abs(og[0] - center) / half) *
                (1 - abs(og[1] - center) / half))

    size = _kernel_size(factor)
    weights = np.zeros((size, size, number_of_classes, number_of_classes),
                       dtype=np.float32)
    kernel = _bilinear_kernel(size)
    for i in range(number_of_classes):
        weights[:, :, i, i] = kernel  # identical kernel per channel, diagonal only
    return weights
if __name__ == '__main__':
    import tensorflow as tf  # TF 1.x API (tf.Session) — presumably run under TF1; verify
    # Build the 3x3x3 input image and add a batch dimension (NHWC).
    source = produce_image()
    batch = tf.expand_dims(tf.cast(source, dtype=tf.float32), 0)
    # Bilinear-interpolation kernel for the transposed convolution.
    kernel = upsampling_bilinear()
    # Transposed convolution: upsample 3x3 -> 9x9 with stride 3, SAME padding.
    upsampled = tf.nn.conv2d_transpose(batch, kernel,
                                       output_shape=[1, 9, 9, 3],
                                       strides=[1, 3, 3, 1],
                                       padding='SAME')
    with tf.Session() as sess:
        result = sess.run(upsampled)
    io.imshow(result[0, :, :, :], interpolation='none')
    io.show()

打印结果:能较好恢复原图像

猜你喜欢

转载自blog.csdn.net/fanzonghao/article/details/81104094