Unsupervised Monocular Depth Estimation with Left-Right Consistency: Paper Notes and Implementation


Paper link: https://arxiv.org/abs/1609.03677

 

Background:

       Depth estimation, as an important component of 3D reconstruction, has a range of applications. For example, Baidu Maps has a scene-reconstruction feature that lets us take a 3D "walk" from a 2D viewpoint. A basic requirement for that is determining the depth ("distance") of the objects in the walking direction. This post gives a depth-estimation example based on the KITTI dataset (whose capture setting is similar to that of Baidu Maps).

 

Model structure:

       First, a look at the principle behind solving depth estimation without supervision, see the figure below:

       Depth estimation can be learned either with supervision or without it (in a sense, self-supervised). The basic supervised setup is: given images from a single camera and the depth of each pixel, perform "regression". The difficulties are estimating 3D structure from "2D" features and obtaining the labels (depth-measuring sensors are easily affected by the environment and are expensive).

      The principle of the usual unsupervised approach is shown in the figure above: with two horizontally aligned cameras capturing an image pair, a network estimates the disparity; the disparity can then be used to enforce consistency between left and right pixel positions after shifting (the corresponding loss). Once training satisfies this consistency, the estimated disparity relates to depth through the following expression:

       The quantities in the formula are illustrated in the figure above.
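For reference, the relation is the standard rectified-stereo one: with baseline $b$ between the two cameras and focal length $f$, depth is recovered from the predicted disparity $d$ as

$$\hat{d} = \frac{b\,f}{d}$$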

        The network architecture is shown below:

         Here I is the input image and d the disparity estimate; an encoder-decoder over the image produces disparity estimates at 4 scales. The most basic constraint one can think of is consistency of the images after shifting by the disparity:
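Written out, this is the paper's appearance matching term (shown for the left view; $\tilde{I}^{l}$ is the left image reconstructed by sampling the right image with the predicted disparity, and $\alpha = 0.85$):

$$C_{ap}^{l} = \frac{1}{N}\sum_{i,j}\;\alpha\,\frac{1-\mathrm{SSIM}\!\left(I^{l}_{ij},\,\tilde{I}^{l}_{ij}\right)}{2} \;+\; (1-\alpha)\,\left\lVert I^{l}_{ij}-\tilde{I}^{l}_{ij}\right\rVert$$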

       Here SSIM is the structural similarity index:
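Its usual form, with local means $\mu$, variances $\sigma^{2}$ and covariance $\sigma_{xy}$ (the constants $c_{1}, c_{2}$ stabilize the division):

$$\mathrm{SSIM}(x,y)=\frac{(2\mu_{x}\mu_{y}+c_{1})(2\sigma_{xy}+c_{2})}{(\mu_{x}^{2}+\mu_{y}^{2}+c_{1})(\sigma_{x}^{2}+\sigma_{y}^{2}+c_{2})}$$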

        The definition can be read as a similarity measure along the lines of a triangle inequality: the requirement on the means (mu) is obvious, and the requirement on the sigmas can be seen from Var(X - Y); via Chebyshev's inequality it pushes X and Y to agree in a weak-convergence sense. (In the implementation the statistics are computed with average pooling, so this acts as a local correlation constraint.)

       Another requirement is consistency of the disparities themselves under the shift. This is already implicit in requiring (left/right) paired samples, but here it is enforced explicitly:
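From the paper, the left-right disparity consistency term projects the right disparity map through the left disparities and compares the result with the left disparity map (and symmetrically for the right view):

$$C_{lr}^{l}=\frac{1}{N}\sum_{i,j}\left|d^{l}_{ij}-d^{r}_{i,\,j+d^{l}_{ij}}\right|$$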

       The author also adds a smoothness requirement on the disparity, similar to

(http://openaccess.thecvf.com/content_iccv_2013/papers/Heise_PM-Huber_PatchMatch_with_2013_ICCV_paper.pdf)

defined as follows:
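From the paper, this is an edge-aware smoothness term: gradients of the disparity are penalized, with the penalty weighted down where the image itself has strong gradients (likely depth discontinuities):

$$C_{ds}^{l}=\frac{1}{N}\sum_{i,j}\left|\partial_{x}d^{l}_{ij}\right|e^{-\left\lVert\partial_{x}I^{l}_{ij}\right\rVert}+\left|\partial_{y}d^{l}_{ij}\right|e^{-\left\lVert\partial_{y}I^{l}_{ij}\right\rVert}$$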

This definition is the tricky one to get right......

 

Some details:

  1. "Adding/subtracting" a disparity is implemented as horizontal interpolation (sampling along the width axis).
  2. The CNN encoder-decoder structure is as follows:

      For the convolutions, the usual pattern is Conv + ReLU + BatchNorm; here SELU is used instead (the activation handles normalization and the nonlinearity "in one go"), which addresses the problem of the intermediate-scale disparities not changing.
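As a side-by-side sketch of the two block styles (illustrative channel sizes, not taken from the model below):

import torch.nn as nn

# conventional block: convolution + BatchNorm + ReLU
conv_bn_relu = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

# block used here: convolution + SELU, no explicit normalization layer
conv_selu = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
    nn.SELU(),
)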

      For upsampling, the usual structure is a deconvolution (transposed convolution); here it is replaced by interpolation-based upsampling followed by a plain convolution. (A deconvolution is itself essentially interpolation of a sparsely upsampled map.) In a sense this augments the information.
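A sketch of that replacement (illustrative channel sizes; note the model below upsamples by 4 and then convolves with stride 2, which nets out to the same ×2 upsampling):

import torch.nn as nn

# conventional choice: transposed convolution, which learns its own upsampling
deconv_up = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

# replacement: nearest-neighbour upsampling followed by a stride-1 convolution
interp_up = nn.Sequential(
    nn.UpsamplingNearest2d(scale_factor=2),
    nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
    nn.SELU(),
)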

     The decoder also reuses encoder features through skip connections (similar to U-Net).

     In my implementation no sigmoid activation is applied to the disp layers.

 

     Below is an attempt at an implementation in PyTorch.

 

     Dataset download link:

http://www.cvlibs.net/datasets/kitti/raw_data.php

    The "City" sequences of the raw data are used.

Data loading: (use as many sequences as possible; the fit is sensitive to the number of object categories in the data)

from PIL import Image
import numpy as np
import glob
import os
import torch
import random

def read_single_image(path, resize = (512, 256), channel_num = 1,
                      is_cuda = False):
    img = Image.open(path)
    img = img.resize(resize)

    if channel_num == 1:
        img_array = np.transpose(np.array(img).astype(np.float32), (1, 0))
        if is_cuda:
            return torch.tensor(img_array / 255).cuda().unsqueeze(0)
        return torch.tensor(img_array / 255).unsqueeze(0)
    else:
        img_array = np.transpose(np.array(img).astype(np.float32), (2 ,1, 0))
        if is_cuda:
            return torch.tensor(img_array / 255).cuda()
        return torch.tensor(img_array / 255)

def batch_loader(batch_size = 2,
                 left_image_path_list = [r"E:\2011_09_26_drive_0001_sync\2011_09_26\2011_09_26_drive_0001_sync\image_02\data",
                                         r"E:\2011_09_26_drive_0002_sync\2011_09_26\2011_09_26_drive_0002_sync\image_02\data",
                                         r"E:\2011_09_26_drive_0005_sync\2011_09_26\2011_09_26_drive_0005_sync\image_02\data",
                                         r"E:\2011_09_26_drive_0009_sync\2011_09_26\2011_09_26_drive_0009_sync\image_02\data",
                                         r"E:\2011_09_26_drive_0011_sync\2011_09_26\2011_09_26_drive_0011_sync\image_02\data",
                                         r"E:\2011_09_26_drive_0013_sync\2011_09_26\2011_09_26_drive_0013_sync\image_02\data"],
                 channel_num = 1, single = True):

    assert channel_num in [1, 3]
    if channel_num == 1:
        left_image_path_list = list(map(lambda x: x.replace("image_02", "image_00"),left_image_path_list))



    def single_image_path_data_prepare(left_image_path):
        if channel_num == 1:
            right_image_path = left_image_path.replace(r"image_00", r"image_01")
        else:
            right_image_path = left_image_path.replace(r"image_02", r"image_03")
        assert os.path.exists(right_image_path)

        left_image_list = []
        right_image_list = []

        file_list = glob.glob(left_image_path + "\\" + "*")
        file_list = random.sample(file_list, len(file_list))

        for left_image_ele_path in file_list:
            path_tail = left_image_ele_path.split("\\")[-1]
            right_image_ele_path = right_image_path + "\\" + path_tail
            if not os.path.exists(right_image_ele_path):
                print("{} not exist !".format(right_image_ele_path))
                continue
            left_img_array = read_single_image(left_image_ele_path, channel_num = channel_num)
            right_img_array = read_single_image(right_image_ele_path, channel_num = channel_num)
            left_image_list.append(left_img_array)
            right_image_list.append(right_img_array)
            if len(left_image_list) == batch_size:
                left_tensor = torch.stack(left_image_list, dim = 0)
                right_tensor = torch.stack(right_image_list, dim = 0)
                yield left_tensor, right_tensor
                left_image_list = []
                right_image_list = []
        # the outer loop treats this exception as "this directory is exhausted"
        raise RuntimeError("one epoch end !")

    while True:
        try:
            left_image_path_list = random.sample(left_image_path_list, len(left_image_path_list))
            for left_image_path in left_image_path_list:
                print(left_image_path)
                for left_tensor, right_tensor in single_image_path_data_prepare(left_image_path):
                    yield left_tensor, right_tensor
                    if single:
                        break
                if single:
                    break
        except RuntimeError:
            # a directory ran out of images; report it as the end of an epoch
            yield None
            print("one epoch end !")
            continue

if __name__ == "__main__":
    pass
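A minimal usage sketch (assuming the KITTI paths above exist locally; with the default resize the yielded tensors are [batch, 1, 512, 256], i.e. [b, c, w, h], and the generator yields None once a directory is exhausted):

gen = batch_loader(batch_size = 2, channel_num = 1, single = False)
batch = next(gen)
if batch is not None:
    left_batch, right_batch = batch
    print(left_batch.shape, right_batch.shape)  # torch.Size([2, 1, 512, 256]) each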

Weight initialization and the interpolation (warping) function:

(the interpolation part is modelled on:

https://github.com/mrharicot/monodepth/blob/master/bilinear_sampler.py)

import torch
import torch.nn as nn

def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

def bilinear_sampler_1d_h(input_image, x_offset, channel_major = True,
                          is_cuda = False):
    # expects input_image as [b, c, h, w] and x_offset as [b, h, w] (i.e. the caller swaps
    # its own [b, c, w, h] layout first); with channel_major=True the output is [b, c, h, w]

    _num_batch ,_height, _width = map(int ,x_offset.size())

    input_image_size_list = list(map(int ,input_image.size()))
    assert input_image_size_list[0] == _num_batch and input_image_size_list[2] == _height and \
           input_image_size_list[3] == _width
    _num_channels = input_image_size_list[1]

    # transform to channel major to false format (tensorflow format)
    # [batch, height, width, channel_num]
    input_images = input_image.permute(0, 2, 3, 1)
    _height_f = torch.tensor(_height).type(torch.float)
    _width_f = torch.tensor(_width).type(torch.float)
    if is_cuda:
        _height_f = _height_f.cuda()
        _width_f = _width_f.cuda()

    def _repeat(x, n_repeat):
        rep = x.unsqueeze(1).repeat([1, n_repeat])
        return rep.reshape([-1])

    def _interpolate(im, x, y):
        _edge_size = 1

        # zero-pad the image by one pixel on every side, on the same device as the input
        h_head = torch.zeros([_num_batch, _height, 1, _num_channels], device = im.device)
        h_tail = torch.zeros([_num_batch, _height, 1, _num_channels], device = im.device)

        im = torch.cat([h_head, im, h_tail], dim = 2)

        w_head = torch.zeros([_num_batch, 1, _width + 2, _num_channels], device = im.device)
        w_tail = torch.zeros([_num_batch, 1, _width + 2, _num_channels], device = im.device)
        im = torch.cat([w_head, im, w_tail], dim = 1)

        x = x + _edge_size
        y = y + _edge_size

        # clamp x into the padded coordinate range
        x[x < 0.0] = 0.0
        x[x > _width_f - 1 + 2 * _edge_size] = _width_f - 1 + 2 * _edge_size
        x0_f = torch.floor(x)
        y0_f = torch.floor(y)
        x1_f = x0_f + 1

        x0 = x0_f.type(torch.int32)
        y0 = y0_f.type(torch.int32)

        x1 = (torch.min(x1_f, torch.ones_like(x1_f) * (_width_f - 1 + 2 * _edge_size))).type(torch.int32)

        dim2 = (_width + 2 * _edge_size)
        dim1 = (_width + 2 * _edge_size) * (_height + 2 * _edge_size)

        base = _repeat((torch.arange(0, _num_batch) * dim1).type(torch.int32), _height * _width)
        if is_cuda:
            base = base.cuda()

        base_y0 = base + y0 * dim2
        idx_l = base_y0 + x0
        idx_r = base_y0 + x1

        weight_l = (x1_f - x).unsqueeze(1)
        weight_r = (x - x0_f).unsqueeze(1)

        im_flat = im.reshape([-1, _num_channels])
        # clamp gather indices into the flattened range of the *padded* image
        _padded_size = _num_batch * (_height + 2 * _edge_size) * (_width + 2 * _edge_size)
        idx_l[idx_l < 0] = 0
        idx_l[idx_l >= _padded_size] = _padded_size - 1
        idx_r[idx_r < 0] = 0
        idx_r[idx_r >= _padded_size] = _padded_size - 1

        pix_l = torch.gather(im_flat, 0, idx_l.unsqueeze(-1).repeat([1, _num_channels]).type(torch.int64))
        pix_r = torch.gather(im_flat, 0, idx_r.unsqueeze(-1).repeat([1, _num_channels]).type(torch.int64))
        return weight_l.repeat([1, _num_channels]) * pix_l + weight_r.repeat([1, _num_channels]) * pix_r

    def meshgrid_x_flat_y_flat(batch_num ,width, height, x_offset):
        # build a flattened sampling grid, shifted horizontally by x_offset (a fraction of the width)
        width_part = torch.arange(0, width, dtype = torch.float32).unsqueeze(0).repeat(height, 1).unsqueeze(0).repeat(batch_num, 1, 1)
        height_part = torch.arange(0, height, dtype = torch.float32).unsqueeze(0).repeat(width, 1).unsqueeze(0).repeat(batch_num, 1, 1)
        height_part = height_part.permute(0, 2, 1)

        x_t_flat = width_part.reshape([-1])
        y_t_flat = height_part.reshape([-1])
        if is_cuda:
            x_t_flat = x_t_flat.cuda()
            y_t_flat = y_t_flat.cuda()

        x_t_flat = x_t_flat + x_offset.reshape([-1]) * _width_f

        return x_t_flat, y_t_flat

    def _transform(input_images, x_offset):
        # produce meshgrid
        x_t_flat, y_t_flat = meshgrid_x_flat_y_flat(_num_batch,_width,_height,x_offset)
        input_transformed = _interpolate(input_images, x_t_flat, y_t_flat)

        output = input_transformed.reshape([_num_batch, _height, _width, _num_channels])
        return output

    output = _transform(input_images, x_offset)
    if channel_major:
        output = output.permute(0, 3, 1, 2)

    return output
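A quick sanity check of the sampler convention (hypothetical shapes; with a zero offset the warp should give back the input, since the interpolation weights collapse onto the original pixels):

if __name__ == "__main__":
    img = torch.rand(2, 3, 64, 128)           # [b, c, h, w], i.e. already w/h-swapped
    offset = torch.zeros(2, 64, 128)          # per-pixel shift, expressed as a fraction of the width
    out = bilinear_sampler_1d_h(img, offset)  # -> [b, c, h, w]
    print(out.shape, torch.allclose(out, img, atol = 1e-5))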

Model: (only depth visualizations are produced here; the focal length and baseline are not parsed, so no metric depth is recovered)

import torch
import torch.nn as nn
import torch.nn.functional as F

from model_construct.model_utils_cuda import normal_init, bilinear_sampler_1d_h
from data_preprocess.data_loader import batch_loader
from torch.nn import UpsamplingNearest2d
from torch.optim import Adam

import numpy as np
from PIL import Image
import uuid
import os

torch.set_num_threads(4)

class MonoDepthCNN(nn.Module):
    def __init__(self, channel_num = 1):
        assert channel_num in [1, 3]
        super(MonoDepthCNN, self).__init__()

        # encoder part #########################
        self.conv1 = nn.Sequential(nn.Conv2d(channel_num, 32, kernel_size=7, stride=2, padding=3), nn.SELU())
        self.conv1b = nn.Sequential(nn.Conv2d(32, 32, kernel_size=7, stride=1, padding=3), nn.SELU())

        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2), nn.SELU())
        self.conv2b = nn.Sequential(nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2), nn.SELU())

        self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.SELU())
        self.conv3b = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.SELU())
        self.conv4b = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.conv5 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), nn.SELU())
        self.conv5b = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.conv6 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.SELU())
        self.conv6b = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.conv7 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.SELU())
        self.conv7b = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.SELU())

        # decoder part #################################
        self.upconv7 = nn.Sequential(UpsamplingNearest2d(scale_factor=4),
                                     nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.SELU()))
        self.iconv7 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.upconv6 = nn.Sequential(UpsamplingNearest2d(scale_factor=4),
                                     nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.SELU()))
        self.iconv6 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.upconv5 = nn.Sequential(UpsamplingNearest2d(scale_factor=4),
                                     nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=2, padding=1), nn.SELU()))
        self.iconv5 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.upconv4 = nn.Sequential(UpsamplingNearest2d(scale_factor=4),
                                     nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, stride=2, padding=1), nn.SELU()))
        self.iconv4 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), nn.SELU())
        self.disp4 = nn.Sequential(nn.Conv2d(128, 2, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.upconv3 = nn.Sequential(UpsamplingNearest2d(scale_factor=4),
                                     nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, stride=2, padding=1), nn.SELU()))
        self.iconv3 = nn.Sequential(nn.Conv2d(130, 64, kernel_size=3, stride=1, padding=1), nn.SELU())
        self.disp3 = nn.Sequential(nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.upconv2 = nn.Sequential(UpsamplingNearest2d(scale_factor=4),
                                     nn.Sequential(nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1), nn.SELU()))
        self.iconv2 = nn.Sequential(nn.Conv2d(66, 32, kernel_size=3, stride=1, padding=1), nn.SELU())
        self.disp2 = nn.Sequential(nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1), nn.SELU())

        self.upconv1 = nn.Sequential(UpsamplingNearest2d(scale_factor=4),
                                     nn.Sequential(nn.Conv2d(32, 16, kernel_size=3, stride=2, padding=1), nn.SELU()))
        self.iconv1 = nn.Sequential(nn.Conv2d(18, 16, kernel_size=3, stride=1, padding=1), nn.SELU())
        self.disp1 = nn.Sequential(nn.Conv2d(16, 2, kernel_size=3, stride=1, padding=1), nn.SELU())

        ###################################

        # ops used by the losses (local SSIM statistics)
        self.ssim_pooling = nn.AvgPool2d(kernel_size=3, stride=2)

    def encoder(self, input):
        self.conv1b_output = self.conv1b(self.conv1(input))
        self.conv2b_output = self.conv2b(self.conv2(self.conv1b_output))
        self.conv3b_output = self.conv3b(self.conv3(self.conv2b_output))
        self.conv4b_output = self.conv4b(self.conv4(self.conv3b_output))
        self.conv5b_output = self.conv5b(self.conv5(self.conv4b_output))
        self.conv6b_output = self.conv6b(self.conv6(self.conv5b_output))
        self.conv7b_output = self.conv7b(self.conv7(self.conv6b_output))
        return self.conv7b_output

    def decoder(self, input):
        self.upconv7_output = self.upconv7(input)
        self.iconv7_output = self.iconv7(torch.cat([self.upconv7_output, self.conv6b_output], dim = 1))

        self.upconv6_output = self.upconv6(self.iconv7_output)
        self.iconv6_output = self.iconv6(torch.cat([self.upconv6_output, self.conv5b_output], dim = 1))

        self.upconv5_output = self.upconv5(self.iconv6_output)
        self.iconv5_output = self.iconv5(torch.cat([self.upconv5_output, self.conv4b_output], dim = 1))

        self.upconv4_output = self.upconv4(self.iconv5_output)
        self.iconv4_output = self.iconv4(torch.cat([self.upconv4_output, self.conv3b_output], dim = 1))
        self.disp4_output = self.disp4(self.iconv4_output)

        self.upconv3_output = self.upconv3(self.iconv4_output)
        self.iconv3_output = self.iconv3(torch.cat([self.upconv3_output, self.conv2b_output,
                                                    F.upsample_nearest(self.disp4_output, scale_factor=2)], dim = 1))
        self.disp3_output = self.disp3(self.iconv3_output)

        self.upconv2_output = self.upconv2(self.iconv3_output)
        self.iconv2_output = self.iconv2(torch.cat([self.upconv2_output, self.conv1b_output,
                                                    F.upsample_nearest(self.disp3_output, scale_factor=2)], dim = 1))
        self.disp2_output = self.disp2(self.iconv2_output)

        self.upconv1_output = self.upconv1(self.iconv2_output)
        self.iconv1_output = self.iconv1(torch.cat([self.upconv1_output,
                                                    F.upsample_nearest(self.disp2_output, scale_factor=2)], dim = 1))
        self.disp1_output = self.disp1(self.iconv1_output)


    def interpolate_input(self, input, size):
        return F.upsample_bilinear(input, size=size)

    def Appearance_Matching_Loss(self, pred_input, true_input, c1 = 0.1, c2 = 0.1,
                                 alpha = 0.85):
        # pred_input true_input [b, c, w, h]
        pred_mu = self.ssim_pooling(pred_input)
        true_mu = self.ssim_pooling(true_input)
        pred_sigma_square = self.ssim_pooling(pred_input ** 2) - pred_mu ** 2
        true_sigma_square = self.ssim_pooling(true_input ** 2) - true_mu ** 2
        pred_true_mu = self.ssim_pooling(pred_input * true_input) - pred_mu * true_mu
        ssim = (2 * pred_mu * true_mu + c1) * (2 * pred_true_mu + c2) / (pred_mu ** 2 + true_mu ** 2 + c1) / (pred_sigma_square + true_sigma_square + c2)

        return alpha * torch.mean( (1 - ssim) / 2) + (1 - alpha) * torch.mean( (pred_input - true_input) ** 2)

    def Left_Right_Disparity_Consistency(self, dl, dr):
        # project each disparity through the other one, matching the warp directions used in
        # forward (the left view is reconstructed with -dl, the right view with +dr)
        drl = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(dr.unsqueeze(1)), -1 * self.swap_dim(dl))).squeeze(1)
        dlr = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(dl.unsqueeze(1)), self.swap_dim(dr))).squeeze(1)
        return torch.mean(torch.abs(dl - drl)) + torch.mean(torch.abs(dr - dlr))

    def Disparity_Smoothness_Loss(self, input, d):
        # edge-aware smoothness (the paper's C_ds): penalize disparity gradients,
        # down-weighted where the image itself has strong gradients
        d = d.unsqueeze(1)
        d_grad_w = torch.abs(d[:, :, :-1, :] - d[:, :, 1:, :])
        d_grad_h = torch.abs(d[:, :, :, :-1] - d[:, :, :, 1:])

        img_grad_w = torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]), dim = 1, keepdim = True)
        img_grad_h = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]), dim = 1, keepdim = True)

        return torch.mean(d_grad_w * torch.exp(-img_grad_w)) + \
               torch.mean(d_grad_h * torch.exp(-img_grad_h))

    def swap_dim(self, input):
        size_len = len(input.size())
        assert size_len in [3, 4]
        if size_len == 3:
            return input.permute(0, 2, 1)
        else:
            return input.permute(0, 1, 3, 2)

    def forward(self, left_input, right_input):
        # encode decode process
        self.decoder(self.encoder(left_input))

        left_input_4 = self.interpolate_input(left_input, (64, 32))
        dr4 = self.disp4_output[:, 0, :, :]
        right_pred_4 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(left_input_4), self.swap_dim(dr4)))

        left_input_3 = self.interpolate_input(left_input, (128, 64))
        dr3 = self.disp3_output[:, 0, :, :]
        right_pred_3 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(left_input_3), self.swap_dim(dr3)))

        left_input_2 = self.interpolate_input(left_input, (256, 128))
        dr2 = self.disp2_output[:, 0, :, :]
        right_pred_2 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(left_input_2), self.swap_dim(dr2)))

        left_input_1 = left_input
        dr1 = self.disp1_output[:, 0, :, :]
        right_pred_1 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(left_input_1), self.swap_dim(dr1)))

        #############################################################
        right_input_4 = self.interpolate_input(right_input, (64, 32))
        dl4 = self.disp4_output[:, 1, :, :]
        left_pred_4 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(right_input_4), -1 * self.swap_dim(dl4)))

        right_input_3 = self.interpolate_input(right_input, (128, 64))
        dl3 = self.disp3_output[:, 1, :, :]
        left_pred_3 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(right_input_3), -1 * self.swap_dim(dl3)))

        right_input_2 = self.interpolate_input(right_input, (256, 128))
        dl2 = self.disp2_output[:, 1, :, :]
        left_pred_2 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(right_input_2), -1 * self.swap_dim(dl2)))

        right_input_1 = right_input
        dl1 = self.disp1_output[:, 1, :, :]
        left_pred_1 = self.swap_dim(bilinear_sampler_1d_h(self.swap_dim(right_input_1), -1 * self.swap_dim(dl1)))

        appearance_matching_loss = self.Appearance_Matching_Loss(left_pred_4, left_input_4) + \
                                   self.Appearance_Matching_Loss(left_pred_3, left_input_3) + \
                                   self.Appearance_Matching_Loss(left_pred_2, left_input_2) + \
                                   self.Appearance_Matching_Loss(left_pred_1, left_input_1) + \
                                   self.Appearance_Matching_Loss(right_pred_4, right_input_4) + \
                                   self.Appearance_Matching_Loss(right_pred_3, right_input_3) + \
                                   self.Appearance_Matching_Loss(right_pred_2, right_input_2) + \
                                   self.Appearance_Matching_Loss(right_pred_1, right_input_1)

        left_right_disparity_consistency_loss = self.Left_Right_Disparity_Consistency(dl4, dr4) + \
                                                self.Left_Right_Disparity_Consistency(dl3, dr3) + \
                                                self.Left_Right_Disparity_Consistency(dl2, dr2) + \
                                                self.Left_Right_Disparity_Consistency(dl1, dr1)

        disparity_smoothness_loss = self.Disparity_Smoothness_Loss(left_input_4, dl4) + \
                                    self.Disparity_Smoothness_Loss(left_input_3, dl3) + \
                                    self.Disparity_Smoothness_Loss(left_input_2, dl2) + \
                                    self.Disparity_Smoothness_Loss(left_input_1, dl1) + \
                                    self.Disparity_Smoothness_Loss(right_input_4, dr4) + \
                                    self.Disparity_Smoothness_Loss(right_input_3, dr3) + \
                                    self.Disparity_Smoothness_Loss(right_input_2, dr2) + \
                                    self.Disparity_Smoothness_Loss(right_input_1, dr1)

        total_loss = appearance_matching_loss + left_right_disparity_consistency_loss + \
                     disparity_smoothness_loss

        return total_loss


    def weight_init(self, mean = 0.0, std = float(1e-2)):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

def predict_depth(model):
    # disp1_output is [b, 2, w, h]; channel 0 is the disparity used for the right reconstruction
    disp_output = model.disp1_output[:, 0, :, :]
    cp = disp_output.detach().cpu().numpy().copy()

    ################ cp rescale
    # depth ~ baseline * focal / disparity; 0.01 is an arbitrary scale for visualization,
    # and the disparity is clipped to avoid dividing by zero or negative SELU outputs
    cp = 0.01 / np.maximum(cp, 1e-6)
    cp = np.clip(cp, 0.0, 1.0)

    for sample in cp:
        req_depth =  np.transpose((((sample).squeeze() * 255).astype(np.uint8)), [1, 0])
        img = Image.fromarray(req_depth)
        img.save(r"E:\valid_new_0\{}.jpg".format(str(uuid.uuid1())))

def model_train():
    train_gen = batch_loader(single=False, channel_num=1)

    if os.path.exists(r"C:\tempCodingUsage\python\MonoDepth\mono_serlize.pkl"):
        model_ext = torch.load(r"C:\tempCodingUsage\python\MonoDepth\mono_serlize.pkl")["model_ext"]
        print("load end")
    else:
        model_ext = MonoDepthCNN()
        model_ext.weight_init()
        model_ext.train()
        print("init end")

    optimizer = Adam(model_ext.parameters(), lr = 0.0001)
    print("model_init end")

    epoch = 0
    step = 0

    while True:
        input = train_gen.__next__()
        if input is None:
            print("one epoch end !")
            epoch += 1
            continue

        left_input, right_input = input
        loss = model_ext(left_input, right_input)

        optimizer.zero_grad()
        loss.backward(retain_graph = True)
        optimizer.step()

        print("epoch : {} step : {} loss : {}".format(epoch, step, loss))
        if step % 10 == 0:
            print("call save" + "-" * 1000)
            with open(r"C:\tempCodingUsage\python\MonoDepth\mono_serlize.pkl", "wb") as f:
                torch.save({
                    "model_ext": model_ext
                }, f)

            print("call img valid" + "*" * 1000)
            print()
            predict_depth(model_ext)

        step += 1

if __name__ == "__main__":
    model_train()
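To produce depth maps from a saved checkpoint without further training, a sketch along the following lines should work (the helper name export_depth_maps is mine; it reuses the pickle path and batch_loader assumed above, and needs one forward pass so that disp1_output is populated):

def export_depth_maps():
    # load the model pickled by model_train() (path assumed above)
    model_ext = torch.load(r"C:\tempCodingUsage\python\MonoDepth\mono_serlize.pkl")["model_ext"]
    model_ext.eval()

    gen = batch_loader(batch_size = 2, channel_num = 1, single = False)
    left_input, right_input = next(gen)

    with torch.no_grad():
        model_ext(left_input, right_input)  # forward pass fills model_ext.disp1_output
    predict_depth(model_ext)                # writes inverse-disparity images to disk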

Below is an example depth image produced by training:

During training, the real-world distances in the scene have some influence: near objects converge first, and a certain richness of object categories is also needed.

 
