Ordinal Depth Supervision for 3D Human Pose Estimation: Paper Notes and Implementation


Paper link: https://arxiv.org/abs/1805.04095

Paper summary:

       Heatmap-based keypoint detection can be seen as the foundation of 2D pose estimation. Beyond locating the keypoints, pose estimation also has to resolve the "occlusion" problem between different poses (an arm stretched forward may block the torso, for example), so 3D pose estimation can be viewed as the next, more demanding requirement.

       This paper addresses pose estimation when the 3D ground truth is imprecise, specifically when only the relative depth of each keypoint is available rather than absolute depth values.
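
       As a concrete example, here is a minimal sketch, mirroring the map_17_to_comb helper that appears later in this post, of how ordinal labels r(i, j) in {+1, -1, 0} can be derived from metric depths with a tolerance threshold (the function name and the threshold value are illustrative):

import numpy as np
from itertools import combinations

def ordinal_labels(z, tolerance=50.0):
    # z: [num_joints] metric depths; one label per joint pair (i, j) with i < j
    labels = []
    for i, j in combinations(range(len(z)), 2):
        diff = z[i] - z[j]
        if abs(diff) < tolerance:
            labels.append(0)    # depths within tolerance: treated as equal
        elif diff > 0:
            labels.append(-1)   # z_i is larger, so joint j is closer to the camera
        else:
            labels.append(1)    # joint i is closer to the camera
    return np.asarray(labels, dtype=np.int32)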

       The paper's overview figure (omitted here):

       As the figure shows, at the training-label level we only have the 2D ground truth of the keypoints plus their relative depths, and the goal is a 3D pose estimate that copes with occlusion.

       The authors split the problem into two steps, using a ConvNet to estimate the 2D result and the depths separately. The ConvNet is an existing 2D keypoint-detection/pose-estimation architecture, a stacked hourglass design that performs better than CPM; see the paper: https://arxiv.org/abs/1603.06937

       Since this network is the essential foundation for the later estimates, here is a brief introduction, starting with the overall architecture diagram (figure from the paper, omitted here):

       It is an hourglass design. The input in the paper is an image resized to 256x256; a convolutional stem first downsamples it to 64x64, and the downsampled features are then passed repeatedly through "hourglass" encoder-decoder modules. A detail view of one hourglass module (figure omitted):

      Each module is an encoder-decoder with skip connections. Every box is a residual block, used to change the spatial size and channel count. (In the implementation below, BN is removed from the plain conv layers but kept inside the residual blocks, and BN runs in the same mode for training and inference; these adjustments came from observing training behavior. The channel layout of the residual blocks also varies between implementations: some keep the channel count constant through the encoder-decoder, while the implementation below uses the common scheme of increasing channels on each downsample and decreasing them on each upsample.)

       With the hourglass in place, the next question is where to attach the heatmap loss. It is placed between the hourglasses, as the figure below shows (figure omitted). Attaching a loss between every pair of hourglasses is a form of multi-stage loss; the rationale is the same as the multi-stage loss in CPM, which has a similar design: it counters the vanishing gradients that come with network depth.

       Between two hourglasses, a 1x1 convolution (the blue part) extracts a tensor with as many channels as there are joints and regresses it directly against the ground-truth heatmaps, as sketched below.
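
       A minimal sketch of that intermediate supervision, i.e. a per-stage 1x1 head plus a summed loss (names are illustrative; the full version appears in the model code further down):

import tensorflow as tf

def stage_heatmap_head(features, num_joints, stage_idx):
    # a 1x1 convolution maps the hourglass features to one heatmap per joint
    return tf.layers.conv2d(inputs=features, filters=num_joints, kernel_size=(1, 1),
                            padding="SAME", name="heatmap_head_{}".format(stage_idx))

def multi_stage_heatmap_loss(stage_heatmaps, gt_heatmap):
    # every stage regresses against the same ground truth, so each
    # hourglass receives a direct gradient signal
    return tf.add_n([tf.nn.l2_loss(h - gt_heatmap) for h in stage_heatmaps])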

       That essentially settles 2D keypoint localization; we now turn to estimating relative depth.

       The paper likewise obtains a relative-depth estimate for every keypoint from the ConvNet (the estimates are only relative, i.e. they encode nothing beyond "nearer vs. farther"). The implementation below similarly "branches off" a tensor from the ConvNet's heatmap-regression part and feeds it through a feed-forward network whose output layer yields the depth "estimates" (the z below), trained with a ranking loss. Assuming the per-keypoint estimates z_i are available, the per-pair loss from the paper is:

        L(i, j) = log(1 + exp(z_i - z_j))    if r(i, j) = +1  (joint i is closer than joint j)
                = log(1 + exp(z_j - z_i))    if r(i, j) = -1  (joint j is closer than joint i)
                = (z_i - z_j)^2              if r(i, j) =  0  (the two depths are roughly equal)

       At the level of logistic regression, the meaning of the first two cases is plain (a RankNet-style pairwise objective). The reason for the l2 term on (approximately) equal z values shows up during training: if the threshold used when generating the ordinal labels is strict, so that few pairs are labeled r(i, j) = 0, the learned depths become unstable along the depth axis and the figure can look like it "cannot stand up", i.e. the depth scale drifts. Tuning the l2 term and the corresponding labels is therefore important.
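
       As a sanity check, here is a small numpy version of that per-pair loss (illustrative only; the TensorFlow version used for training is batch_z_loss_construct below):

import numpy as np

def pair_rank_loss(z_i, z_j, r):
    # r = +1: joint i should be closer (smaller z); push z_i below z_j
    if r == 1:
        return np.log(1.0 + np.exp(z_i - z_j))
    # r = -1: joint j should be closer
    if r == -1:
        return np.log(1.0 + np.exp(z_j - z_i))
    # r = 0: the depths should agree; penalize any gap quadratically
    return (z_i - z_j) ** 2

# e.g. a pair labeled "i closer" that the model currently gets wrong vs. right:
print(pair_rank_loss(2.0, -1.0, 1))   # large loss, ~3.05
print(pair_rank_loss(-1.0, 2.0, 1))   # small loss, ~0.05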

       This yields the relative-depth estimates z. Merging the two losses, total_loss = rank_loss + lambda * heatmap_loss (lambda = 10 in the code below), gives the basic pose-estimation result. (Section 3 of the paper, "Volumetric prediction for 3D pose", is not covered here.)

       Below is an attempt to implement the model. The data is the Human3.6M Dataset, in the downloadable form provided by this project: https://github.com/geopavlakos/ordinal-pose3d (the project only ships model inference, not training).

       Because this version of the data is used, the image-cropping step has to reproduce the same pipeline; the cropping logic below can be read as a numpy port of that project's Torch7 code. (If you copy code from there, just remember that Lua indices start at 1.)

       The implementation follows:

 

       First, the image-cropping and heatmap-label-generation functions. (Profiling showed this part dominates runtime, so it was first written in numpy and then compiled and called via Cython; after minor changes and compilation, heatmap generation for a single 1000x1200 image drops from about 1.5s to 0.5s.)

import h5py
import numpy as np
from numpy.linalg import inv
from PIL import Image
import math

# NOTE: the typed signature below is Cython syntax; it is valid in a .pyx file
# but not in plain Python (drop the C type declarations to run it uncompiled)
def CenterGaussianHeatMap_cc(int img_height,int img_width,int c_x,int c_y,float variance):
    gaussian_map = np.zeros((img_height, img_width))
    for x_p in range(img_width):
        for y_p in range(img_height):
            gaussian_map[y_p, x_p] = math.exp(-1 * (((x_p - c_x) * (x_p - c_x) +
                                                     (y_p - c_y) * (y_p - c_y)) / 2.0 / variance / variance))
    return gaussian_map

def getTransform(center, scale, rot, res):
    h = 200 * scale
    t = np.eye(3)

    t[0][0] = res / h
    t[1][1] = res / h

    t[0][2] = res * (-center[0] / h + 0.5)
    t[1][2] = res * (-center[1] / h + 0.5)

    if rot != 0:
        rot = -rot
        r = np.eye(3)
        ang = rot * np.pi / 180
        s = np.sin(ang)
        c = np.cos(ang)
        r[0][0] = c
        r[0][1] = -s
        r[1][0] = s
        r[1][1] = c

        t_ = np.eye(3)
        t_[0][2] = -res / 2
        t_[1][2] = -res / 2
        t_inv = np.eye(3)
        t_inv[0][2] = res / 2
        t_inv[1][2] = res / 2
        t = t_inv.dot(r).dot(t_).dot(t)  # matrix products; elementwise * was a bug here

    return t


def transform(pt, center, scale, rot, res, invert):
    pt_ = np.ones(3)
    pt_[0], pt_[1] = pt[0] - 1, pt[1] - 1

    t = getTransform(center, scale, rot, res)
    if invert:
        t = inv(t)
    new_point = np.dot(t,pt_[:, np.newaxis])[:2, :] + 1e-4
    new_point = np.squeeze(new_point)

    return new_point.astype(np.int32) + 1


def crop(img, center, scale, rot = 0, res = 256):
    assert isinstance(img, np.ndarray)

    ul = transform([1, 1], center, scale, 0, res, True)
    br = transform([res + 1, res + 1], center, scale, 0, res, True)

    def norm(input_array):
        return np.sqrt(np.sum(input_array * input_array))

    # keep pad an integer: it is used as a slice index below
    pad = int(np.floor(norm((ul - br).astype(np.float32)) / 2 - (br[0] - ul[0]) / 2))

    if rot != 0:
        ul = ul - pad
        br = br + pad

    newDim, newImg, ht, wd = [None] * 4

    if len(img.shape) > 2:
        newDim = np.array([img.shape[0], br[1] - ul[1], br[0] - ul[0]]).astype(np.int32)
        newImg = np.zeros([newDim[0], newDim[1], newDim[2]])
        ht = img.shape[1]
        wd = img.shape[2]
    else:
        newDim = np.array([br[1] - ul[1], br[0] - ul[0]]).astype(np.int32)
        newImg = np.zeros([newDim[0], newDim[1]])
        ht = img.shape[0]
        wd = img.shape[1]

    newX = np.array([np.max([1, -ul[0] + 2]), np.min([br[0], wd + 1]) - ul[0]])
    newY = np.array([np.max([1, -ul[1] + 2]), np.min([br[1], ht + 1]) - ul[1]])
    oldX = np.array([np.max([1, ul[0]]), np.min([br[0], wd + 1]) - 1])
    oldY = np.array([np.max([1, ul[1]]), np.min([br[1], ht + 1]) - 1])

    newX -= 1
    newY -= 1
    oldX -= 1
    oldY -= 1

    if newDim.shape[0] > 2:
        newImg[: newDim[0] + 1, newY[0]: newY[1] + 1, newX[0] : newX[1] + 1] = \
            img[: newDim[0] + 1, oldY[0]: oldY[1] + 1, oldX[0] : oldX[1] + 1]
    else:
        newImg[newY[0]: newY[1] + 1, newX[0] : newX[1] + 1] = \
            img[oldY[0]: oldY[1] + 1, oldX[0] : oldX[1] + 1]


    if rot != 0:
        if newDim.shape[0] > 2:
            req_img = np.transpose(newImg, [1, 2, 0])
            rot_img = Image.fromarray(req_img.astype(np.uint8)).rotate(rot)  # PIL rotate takes degrees
            rot_img = np.transpose(np.array(rot_img), [2, 0, 1])
            newImg = rot_img[:newDim[0] + 1, pad : newDim[1] - pad + 1, pad : newDim[2] - pad + 1]
        else:
            rot_img = Image.fromarray(newImg.astype(np.uint8)).rotate(rot)  # PIL rotate takes degrees

            newImg = np.array(rot_img)[pad : newDim[0] - pad + 1, pad : newDim[1] - pad + 1]  # 2D case: newDim is (height, width)


    if newDim.shape[0] > 2:  # test the number of dimensions, not the size of the first one
        req_img = np.transpose(newImg, [1, 2, 0])
        rescale_img = Image.fromarray(req_img.astype(np.uint8)).resize((res, res))
        rescale_img = np.transpose(np.array(rescale_img), [2, 0, 1])
    else:
        rescale_img = Image.fromarray(newImg.astype(np.uint8)).resize((res, res))
    return np.array(rescale_img)


def load_dataset(file_type = "train",
                 file_path_format = r"E:\Temp\h36m_annot\h36m\annot\{}.h5",
                 image_file_format = r"E:\Temp\h36m_annot\h36m\annot\{}_images.txt"):
    assert file_type in ["train", "valid"]

    file_path = file_path_format.format(file_type)
    file = h5py.File(file_path, "r")
    annot = dict()
    tags = ["center", "scale"]
    for tag in tags:
        annot[tag] = np.array(file[tag])
    annot["nsamples"] = annot["center"].shape[0]
    annot["part"] = np.array(file["part"])
    annot["S"] = np.array(file["S"])
    file.close()

    image_file = image_file_format.format(file_type)
    annot["images"] = []
    toIdx = dict()
    idx = 0
    with open(image_file, "r") as f:
        while True:
            line = f.readline().strip()
            if not line:
                break
            annot["images"].append(line)
            if not toIdx.get(line):
                toIdx[line] = []
            toIdx[line].append(idx)
            idx += 1
    annot["imageToIdxs"] = toIdx

    # each image is expected to map to a single annotation index (no duplicates)

    print("toIdx count :")
    print(np.unique(list(map(len ,toIdx.values())), return_counts=True))

    print("n_samples : {}".format(annot["nsamples"]))
    print("images: num : {}".format(len(annot["images"])))

    return annot

        Compile the code above with Cython to produce data_loader_util.
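
       For reference, here is a minimal setup.py sketch that could drive such a build, assuming the code above is saved as data_loader_util.pyx (file names and options are illustrative):

from setuptools import setup
from Cython.Build import cythonize
import numpy as np

setup(
    name="data_loader_util",
    ext_modules=cythonize("data_loader_util.pyx", language_level=3),
    include_dirs=[np.get_include()],  # numpy headers for the compiled extension
)

# build in place with:
#   python setup.py build_ext --inplace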

       Even with the compiled speedup per image, data generation is still a bottleneck, so the compiled data_loader_util.pyd is driven in parallel with joblib:

from joblib import Parallel, delayed
import pickle
import glob
from random import sample
from sklearn.utils import shuffle
from itertools import combinations
from data_loader_util import *

def main_data_sample(data_type = "valid" ,debug = False):
    assert data_type in ["train", "valid"]
    train_annot = load_dataset(data_type)
    nsamples = train_annot["nsamples"]

    print("dataset and counts :")
    print(np.unique(list(map(lambda x: x.split("_")[0], train_annot["images"])), return_counts=True))

    def produce_single_img_and_heatmap(i):
        image_path = train_annot["images"][i]
        s_head = image_path.split("_")[0]
        im = Image.open(r"E:\Temp\{}\{}~\{}".format(s_head, s_head, image_path))
        im_mask = Image.fromarray(np.zeros_like(np.array(im)[:, :, 0]))
        center = train_annot["center"][i]
        scale = train_annot["scale"][i]
        im = np.array(np.transpose(im, [2, 0, 1]))

        # point category idx [1, ..., num_joints]
        im_mask = np.array(im_mask)
        idx = train_annot["imageToIdxs"][image_path][0]
        # [17, 2]
        xy_17 = train_annot["part"][idx]

        im_mask_list = []
        for point_idx, xy in enumerate(xy_17.tolist()):
            gaussian_mask = CenterGaussianHeatMap_cc(im_mask.shape[0], im_mask.shape[1], xy[0], xy[1],
                                                     10.0)
            im_mask_list.append(gaussian_mask)

        inp = crop(im, center, scale, 0, 256)
        def im_mask_process(im_mask):
            im_mask = im_mask[np.newaxis, :, :]
            im_mask = np.concatenate([im_mask, im_mask, im_mask], axis=0)

            im_add_mask = im_mask * 100.0

            inp_mask = crop(im_add_mask, center, scale, 0, 256)
            inp_mask = np.transpose(inp_mask, [1, 2, 0]).astype(np.uint8)
            return np.array(Image.fromarray(inp_mask[:, :, 0]).resize((64, 64)))[:, :, np.newaxis]

        im_mask_conclusion = np.concatenate(list(map(im_mask_process, im_mask_list)), axis=-1)
        input_img = np.transpose(inp, [1, 2, 0]).astype(np.uint8)

        if debug:
            im_mask_sum = np.sum(im_mask_conclusion, axis=-1)
            Image.fromarray(im_mask_sum).show()
            Image.fromarray(input_img).show()
            print(im_mask_conclusion.shape)
            print(input_img.shape)

        return (im_mask_conclusion ,input_img)

    def result_gen(i_range):
        i_list = list(i_range)
        im_mask_conclusion_list, input_img_list = [], []
        for i in i_list:
            im_mask_conclusion ,input_img = produce_single_img_and_heatmap(i)
            im_mask_conclusion_list.append(im_mask_conclusion)
            input_img_list.append(input_img)

        # [len(i_list), 17]
        input_deepth = train_annot["S"][i_list, :, -1]
        return (np.stack(im_mask_conclusion_list, axis=0), np.stack(input_img_list, axis=0), input_deepth)

    nest_i_list = []
    inner_list_size = 10
    for i in range(nsamples):
        if i % inner_list_size == 0:
            nest_i_list.append([])
        nest_i_list[-1].append(i)

    gap = 30
    gap_list = [i for i in range(0, len(nest_i_list), gap)]
    for i in range(len(gap_list) - 1):
        start = gap_list[i]
        end = gap_list[i + 1]
        result = Parallel(n_jobs=12)(delayed(result_gen)(i,) for i in map(lambda x: nest_i_list[x], range(start, end)))

        with open(r"E:\Temp\pkl_files\annot_img_{}.pkl".format(i), "wb") as f:
            pickle.dump(result, f)
        print("dump {} end".format(i))


def batch_data_loader(type = "train", batch_num = 30, is_shuffle = True,
                      num_joints = 17, equal_tolerance = 50):
    assert type in ["train", "valid"]

    ordered_comb_idx_list = sorted(list(map(list ,combinations(range(num_joints), 2))))
    def map_17_to_comb(input_17, equal_tolerance = equal_tolerance):
        # input_17 [17,]
        req = []
        for i, j in ordered_comb_idx_list:
            zi, zj = input_17[i], input_17[j]
            zi_sub_zj = zi - zj
            if np.abs(zi_sub_zj) < equal_tolerance:
                req.append(0)
            elif zi_sub_zj > 0:
                req.append(-1)
            else:
                req.append(1)
        # [comb_num,]
        return np.asarray(req, dtype=np.int32)

    def read_pkl_file(file_path):
        with open(file_path, "rb") as f:
            return pickle.load(f)

    all_pkl_files = glob.glob(r"E:\Temp\{}_pkl_files\*".format(type))
    if is_shuffle:
        all_pkl_files = sample(all_pkl_files, len(all_pkl_files))

    input_img = np.zeros(shape=[batch_num, 256, 256, 3], dtype=np.float32)
    heat_map = np.zeros(shape=[batch_num, 64, 64, 17], dtype=np.float32)
    comb_condition_input = np.zeros(shape=[batch_num, len(ordered_comb_idx_list)], dtype=np.int32)
    start_idx = 0
    now_file_list = None
    now_file = None
    now_file_0, now_file_1, now_file_2 = [None] * 3

    while True:
        if now_file_list is None:
            if all_pkl_files:
                req_file = all_pkl_files.pop()
                print("load file :{}".format(req_file))
                now_file_list = read_pkl_file(req_file)
                if is_shuffle:
                    now_file_list = sample(now_file_list, len(now_file_list))
            else:
                print("all {} file read end, will return".format(type))
                yield None
                return
        if now_file is None:
            now_file = now_file_list.pop()
        if now_file_0 is None:
            now_file_0 = now_file[0]
        if now_file_1 is None:
            now_file_1 = now_file[1]
        if now_file_2 is None:
            # [17,]
            now_file_2 = now_file[2]
            now_file = None

        if isinstance(now_file_0, np.ndarray):
            now_file_0 = now_file_0.tolist()
        if isinstance(now_file_1, np.ndarray):
            now_file_1 = now_file_1.tolist()
        if isinstance(now_file_2, np.ndarray):
            now_file_2 = now_file_2.tolist()

        input_img[start_idx] = np.asarray(now_file_1.pop(), dtype=np.float32)
        heat_map[start_idx] = np.asarray(now_file_0.pop(), dtype=np.float32)
        comb_condition_input[start_idx] = map_17_to_comb(now_file_2.pop())

        start_idx += 1
        if not now_file_list:
            now_file_list = None
        if not now_file:
            now_file = None
        if not now_file_0:
            now_file_0 = None
        if not now_file_1:
            now_file_1 = None
        if not now_file_2:
            now_file_2 = None

        if start_idx == batch_num:
            if is_shuffle:
                input_img, heat_map, comb_condition_input = shuffle(input_img, heat_map, comb_condition_input)

            heat_map = heat_map / 99.0
            yield (input_img, heat_map, comb_condition_input)
            input_img = np.zeros(shape=[batch_num, 256, 256, 3], dtype=np.float32)
            heat_map = np.zeros(shape=[batch_num, 64, 64, 17], dtype=np.float32)
            comb_condition_input = np.zeros(shape=[batch_num, len(ordered_comb_idx_list)], dtype=np.int32)
            start_idx = 0

if __name__ == "__main__":
    pass

       main_data_sample generates the data in parallel; batch_data_loader serves it back out as batches.
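
       Usage is a plain generator loop; a minimal sketch (the training loop below does the same thing, re-creating the generator at each epoch end):

gen = batch_data_loader(type="train", batch_num=30)
while True:
    batch = next(gen)
    if batch is None:  # the loader yields None once every pickle file has been consumed
        break
    input_img, heat_map, comb_condition_input = batch
    # feed the batch to the model here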

       Model definition and training code:

import tensorflow as tf

def conv2d(inputs, filters, kernel_size, strides = (2, 2), name = None,
           add_max_pooling = True, is_training = tf.constant(True)):
    output = tf.layers.conv2d(inputs=inputs, filters = filters, kernel_size=kernel_size,
                              strides=strides, padding="SAME", name = name,
                              )
    if add_max_pooling:
        output = tf.layers.max_pooling2d(inputs=output, strides=strides, padding="SAME",
                                         pool_size=kernel_size, name="{}_max_pool_2d".format(name))

    return tf.nn.leaky_relu(output)


def residual(inputs, out_channels, name = None, is_training = tf.constant(True),
             add_max_pooling = False):
    conv2d_output = conv2d(inputs = inputs, filters=out_channels, kernel_size=(3, 3),
                           strides=(1, 1), name ="{}_conv".format(name),
                           is_training=is_training, add_max_pooling=False)
    identity_output = tf.layers.conv2d(inputs=inputs, filters = out_channels, kernel_size=(1, 1),
                                       strides=(1, 1), padding="SAME", name = "{}_identity".format(name),
                                       )
    output = conv2d_output + identity_output
    if add_max_pooling:
        output = tf.layers.max_pooling2d(inputs=output, strides=(2, 2), padding="SAME",
                                         pool_size=(3, 3), name="{}_max_pool_2d".format(name))

    # BN deliberately stays in training mode for both training and inference (see the note in the text above)
    output = tf.nn.leaky_relu(tf.layers.batch_normalization(output, training=tf.constant(True),
                                                            name="{}_batch_normalization".format(name)))

    return tf.nn.dropout(output, keep_prob= 1.0 - tf.cast(is_training, tf.float32) * 0.05)


if __name__ == "__main__":
    pass
# ---- second module: model definition and training; the utilities above are
# imported as model.model_utils_final and model.data_loader_final ----
import tensorflow as tf
from model.model_utils_final import conv2d, residual
from functools import reduce
from model.data_loader_final import batch_data_loader
import os
from itertools import combinations
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from PIL import Image
from uuid import uuid1

class OD3DPose(object):
    def __init__(self, num_channels = 3, height = 256, width = 256,
                 num_joints = 17, heatmap_size = 64,
                 stage_num = 3, batch_size = 2, lambda_val = 10.0):
        self.lambda_val = lambda_val
        self.num_joints = num_joints
        self.batch_size = batch_size
        # stage_num indicate the total hourglass num . construct heatmap loss by
        # for loop
        self.stage_num = stage_num

        self.input_img = tf.placeholder(tf.float32, [None, height, width, num_channels])
        self.input_heatmap = tf.placeholder(tf.float32, [None, heatmap_size,
                                                         heatmap_size, num_joints])

        self.is_training = tf.placeholder(tf.bool, [])

        ######### init comb indices
        self.ordered_comb_idx_list = sorted(list(map(list ,combinations(range(num_joints), 2))))
        self.comb_condition_input = tf.placeholder(tf.int32, [None, len(self.ordered_comb_idx_list)])

        self.heatmap_producer_list = []
        self.batch_z_list = []
        self.model_construct()
        self.opt_construct()
        print("opt_construct end")

    def conv_downsample_256_to_64(self, output_channels = 64):
        input_down_to_64_conv = conv2d(inputs=self.input_img, filters=output_channels, strides=(2, 2),
                                       name="conv_256_to_64", add_max_pooling=True, kernel_size=(7, 7),
                                       is_training=self.is_training)
        return input_down_to_64_conv

    def single_hourglass_layer(self, input, layer_num = 0):
        # input [batch ,64, 64, 64] output [batch, 64, 64, 64]
        with tf.variable_scope("hourglass_layer_{}".format(layer_num)):
            # [batch, 64, 64, 64]
            input_middle = residual(inputs=input, out_channels = 64, name = "input_middle",
                                    add_max_pooling=False, is_training=self.is_training)

            # [batch, 32, 32, 128]
            residual_1 = residual(inputs=input, out_channels = 128, name= "residual_1"
                                  ,add_max_pooling=True, is_training=self.is_training)
            # [batch, 32, 32, 128]
            residual_1_middle = residual(inputs=residual_1, out_channels=128, name="residual_1_middle",
                                         add_max_pooling=False, is_training=self.is_training)

            # [batch, 16, 16, 256]
            residual_2 = residual(inputs=residual_1, out_channels = 256, name= "residual_2"
                                  ,add_max_pooling=True, is_training=self.is_training)
            # [batch, 16, 16, 256]
            residual_2_middle = residual(inputs=residual_2, out_channels=256, name="residual_2_middle",
                                         add_max_pooling=False, is_training=self.is_training)

            # [batch, 8, 8, 512]
            residual_3 = residual(inputs=residual_2, out_channels = 512, name= "residual_3"
                                  ,add_max_pooling=True, is_training=self.is_training)
            # [batch, 8, 8, 512]
            residual_3_middle = residual(inputs=residual_3, out_channels=512, name="residual_3_middle",
                                         add_max_pooling=False, is_training=self.is_training)

            # this 3 layers shape [batch_size ,4, 4, 1028]
            residual_4 = residual(inputs=residual_3, out_channels = 1028, name= "residual_4"
                                  ,add_max_pooling=True, is_training=self.is_training)
            residual_5 = residual(inputs=residual_4, out_channels = 1028, name= "residual_5"
                                  ,add_max_pooling=False, is_training=self.is_training)
            residual_6 = residual(inputs=residual_5, out_channels = 1028, name= "residual_6"
                                  ,add_max_pooling=False, is_training=self.is_training)

            # [batch, 8, 8, 1028]
            up_1 = tf.image.resize_nearest_neighbor(images = residual_6, size = (8, 8),
                                                    name="up_1")
            # [batch, 8, 8, 512]
            residual_up_1 = residual(inputs=up_1, out_channels=512, name="residual_up_1",
                                     add_max_pooling=False, is_training=self.is_training)
            before_up_2 = tf.add(residual_3_middle, residual_up_1, name="before_up2")

            # [batch, 16, 16, 512]
            up_2 = tf.image.resize_nearest_neighbor(images = before_up_2, size = (16, 16),
                                                    name="up_2")
            # [batch, 16, 16, 256]
            residual_up_2 = residual(inputs=up_2, out_channels=256, name="residual_up_2",
                                     add_max_pooling=False, is_training=self.is_training)
            before_up3 = tf.add(residual_2_middle, residual_up_2, name="before_up3")

            # [batch, 32, 32, 256]
            up_3 = tf.image.resize_nearest_neighbor(images = before_up3, size = (32, 32),
                                                    name="up_3")
            # [batch, 32, 32, 128]
            residual_up_3 = residual(inputs=up_3, out_channels=128, name="residual_up_3",
                                     add_max_pooling=False, is_training=self.is_training)
            before_up_4 = tf.add(residual_1_middle, residual_up_3, name="before_up_4")

            # [batch, 64, 64, 128]
            up_4 = tf.image.resize_nearest_neighbor(images = before_up_4, size = (64, 64),
                                                    name="up_4")
            # [batch, 64, 64, 64]
            residual_up_4 = residual(inputs=up_4, out_channels=64, name="residual_up_4",
                                     add_max_pooling=False, is_training=self.is_training)
            output = tf.add(input_middle, residual_up_4, name="hourglass_output")
            return output

    def produce_single_heatmap_layer(self, input, layer_num = 0):
        # input [batch_num ,64, 64, 64]
        with tf.variable_scope("single_heatmap_layer_{}".format(layer_num)):
            # [batch_num, 64, 64, 64]
            hourglass_output = self.single_hourglass_layer(input, layer_num=layer_num)
            # [batch, 64, 64, 64]
            res_1 = residual(inputs=hourglass_output, out_channels=64, name="res_1",
                             add_max_pooling=False, is_training=self.is_training)
            # [batch, 64, 64, num_joints]
            heatmap_producer = conv2d(inputs=res_1, filters = self.num_joints, kernel_size = (1, 1), strides = (1, 1), name =
            "heatmap_producer_{}".format(layer_num),
                                      add_max_pooling = False, is_training = self.is_training)

            res_1_flatten = tf.reshape(res_1, [-1, 64 * 64 * 64], name="res_1_flatten")
            batch_z_1 = tf.layers.dense(inputs=res_1_flatten, units=100, name="batch_z_1",
                                        activation=tf.nn.sigmoid)
            batch_z = tf.layers.dense(inputs=batch_z_1, units=self.num_joints, name="batch_z")

            # [batch, 64, 64, 64]
            res_2 = residual(inputs=res_1, out_channels=64, name="res_2",
                             add_max_pooling=False, is_training=self.is_training)
            # [batch, 64, 64, 64]
            res_3 = residual(inputs=heatmap_producer, out_channels=64, name="res_3",
                             add_max_pooling=False, is_training=self.is_training)

            # [batch, 64, 64, 64]
            output = tf.add_n([input, res_2, res_3], name="output")
            heatmap_producer = tf.nn.sigmoid(heatmap_producer)
            return output, heatmap_producer, batch_z

    def transform_batch_z_to_comb_3_tensor(self, batch_z):
        # [cn2, 2]
        ordered_comb_idx_array = np.asarray(self.ordered_comb_idx_list).astype(np.int32)
        batch_idx_range = range(self.batch_size)

        def retrieve_idxes(comb_idx_array):
            req_idx = []
            for batch_idx in batch_idx_range:
                req_idx.append([])
                for comb_idx in comb_idx_array:
                    req_idx[-1].append([batch_idx, comb_idx])
            # [batch, comb_idx, 2]
            return np.asarray(req_idx).astype(np.int32)

        # [batch, comb_idx]
        first_part = tf.gather_nd(batch_z ,retrieve_idxes(ordered_comb_idx_array[:, 0]))
        second_part = tf.gather_nd(batch_z ,retrieve_idxes(ordered_comb_idx_array[:, 1]))

        # [batch, comb_idx, 3]
        return tf.concat([tf.expand_dims(first_part, -1), tf.expand_dims(second_part, -1), tf.cast(tf.expand_dims(self.comb_condition_input, -1), tf.float32)], axis=-1)

    def batch_z_loss_construct(self, batch_comb_3):
        # batch_comb_3 [batch, comb_idx, 3]
        # [batch * comb_idx, 3]
        flatten_3_input = tf.reshape(batch_comb_3 ,[-1, 3])

        def z_loss_construct(t3):
            # t3 [3]
            category_label = t3[2]
            zi, zj = t3[0], t3[1]
            z_loss = tf.cond(tf.equal(category_label, tf.constant(1.0)),
                             true_fn=lambda : tf.log(1 + tf.exp(zi - zj)),
                             false_fn=lambda : tf.cond(tf.equal(category_label, tf.constant(-1.0)),
                                                       true_fn=lambda : tf.log(1 + tf.exp(zj - zi)),
                                                       false_fn=lambda : tf.reduce_sum(tf.pow(zi - zj, 2))
                                                       )
                             )
            return z_loss

        return tf.reduce_sum(tf.map_fn(z_loss_construct, flatten_3_input), name="z_loss")

    # pred corr by convnet as Stacked Hourglass Networks for Human Pose Estimation
    def model_construct(self):
        # [batch, 64, 64, 64]
        img_before_hourglass = self.conv_downsample_256_to_64()
        assert int(img_before_hourglass.get_shape()[1]) == int(img_before_hourglass.get_shape()[2]) \
               == 64
        input = img_before_hourglass
        for stage_idx in range(self.stage_num):
            output, heatmap_producer, batch_z = self.produce_single_heatmap_layer(input, layer_num=stage_idx)
            self.heatmap_producer_list.append(heatmap_producer)
            self.batch_z_list.append(batch_z)
            # chain the stages: each hourglass consumes the previous stage's output
            input = output

        # read heatmaps from heatmap_producer_list, since tf.nn.sigmoid has already rescaled them
        self.total_heatmap_loss = reduce(lambda x, y: x + y, map(lambda heatmap_producer: tf.nn.l2_loss(heatmap_producer - self.input_heatmap), self.heatmap_producer_list))
        self.total_rank_loss = reduce(lambda x, y: x + y, map(lambda batch_z: self.batch_z_loss_construct(self.transform_batch_z_to_comb_3_tensor(batch_z)), self.batch_z_list))
        self.total_loss = self.total_rank_loss + self.lambda_val * self.total_heatmap_loss

    def opt_construct(self):
        self.train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(self.total_loss)

    @staticmethod
    def retrieve_t3(heatmap_producer, batch_z):
        # batch_z [batch, 17] heatmap_producer [batch, 64, 64, 17] np.ndarray ext
        def max64x64_xy(input_64x64):
            input_64x64 = np.asarray(input_64x64)
            flatten_input = input_64x64.reshape([-1])
            max_idx = flatten_input.argmax()
            row, col = divmod(max_idx, 64)  # position of the peak in the 64x64 map
            return [col, row]  # (x, y) order

        # [batch * 17, 2]
        flatten_idx_conclusion = list(map(max64x64_xy,  np.transpose(heatmap_producer, [0, 3, 1, 2]).reshape([-1, 64, 64]).tolist()))
        # [batch, 17, 2]
        indice_conclusion = np.asarray(flatten_idx_conclusion, dtype=np.int32).reshape([-1, 17, 2])
        # [batch, 17, 3]
        batch_t3 = np.concatenate([indice_conclusion, batch_z[:, :, np.newaxis]], axis=-1)
        return batch_t3

    @staticmethod
    def visualize_conclusion(img_input , heatmap_input , pic_file_name, show = False):
        # img_input heatmap_input [64, 64, 3]
        # visualize single conclusion save to local pics
        pic_file_name_list = pic_file_name.split("\\")

        def draw_img(img_input, head = "img_"):
            assert head in ["img_", "heat_"]
            if head == "heat_":
                img_input = np.sum(img_input, axis=-1, keepdims=True) * 100.0

            if head == "heat_":
                img_input = np.squeeze(img_input)

            img = Image.fromarray(img_input.astype(np.uint8)).resize((256,  256))
            img.save("{}".format("\\".join(pic_file_name_list[:-1] + [head+ pic_file_name_list[-1]])))
            if show:
                img.show()

        draw_img(img_input, "img_")
        draw_img(heatmap_input, "heat_")

    @staticmethod
    def visualize_t3(t3, pic_file_name, show = False):
        pic_file_name_list = pic_file_name.split("\\")
        def draw_3d(t3):
            # t3 [17, 3] width, height , depth
            x = t3[:, 0]
            y = t3[:, 1]
            z = t3[:, 2]

            req_x, req_y, req_z = x, z ,y
            req_z = req_z.max() - req_z
            # new a figure and set it into 3d
            fig = plt.figure()
            ax = Axes3D(fig)

            # set figure information
            ax.set_title("3D_Curve")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            ax.set_zlabel("z")

            # draw the skeleton limbs, one color per limb chain
            figure1 = ax.plot(req_x[[13, 12, 11]], req_y[[13, 12, 11]], req_z[[13, 12, 11]], c='b')
            figure2 = ax.plot(req_x[[14, 15, 16]], req_y[[14, 15, 16]], req_z[[14, 15, 16]], c='g')
            figure3 = ax.plot(req_x[[1, 2, 3]], req_y[[1, 2, 3]], req_z[[1, 2, 3]], c='r')
            figure4 = ax.plot(req_x[[4, 5, 6]], req_y[[4, 5, 6]], req_z[[4, 5, 6]], c='c')
            figure5 = ax.plot(req_x[[10, 9, 8, 7, 0]], req_y[[10, 9, 8, 7, 0]], req_z[[10, 9, 8, 7, 0]], c='m')

            ax.view_init(0, -90)
            plt.savefig("{}".format("\\".join(pic_file_name_list[:-1] + ["3d_" + pic_file_name_list[-1]])))
            if show:
                plt.show()

        draw_3d(t3)


    @staticmethod
    def train():
        batch_size = 3

        model = OD3DPose(batch_size=batch_size)
        saver = tf.train.Saver()
        train_gen = batch_data_loader(type="train", batch_num=batch_size)
        valid_gen = batch_data_loader(type="valid", batch_num=batch_size)

        step = 0
        epoch = 0

        with tf.Session() as sess:
            if os.path.exists(r"C:\Coding\Python\OrdinalDepthPose\od3d.meta"):
                saver.restore(sess, save_path=r"C:\Coding\Python\OrdinalDepthPose\od3d")
                print("load exist model")
            else:
                sess.run(tf.global_variables_initializer())
                print("init new model")

            while True:
                train_data = train_gen.__next__()
                if train_data is None:
                    print("one epoch end")
                    epoch += 1
                    train_gen = batch_data_loader(type="train", batch_num=3)
                    train_data = train_gen.__next__()

                input_img, heat_map, comb_condition_input = train_data
                if step % 50 != 0:
                    _, loss, total_heatmap_loss, total_rank_loss = sess.run([model.train_op, model.total_loss,
                                                                             model.total_heatmap_loss, model.total_rank_loss,
                                                                             ],
                                                                            feed_dict={
                                                                                model.input_img: input_img,
                                                                                model.input_heatmap: heat_map,
                                                                                model.is_training: True,
                                                                                model.comb_condition_input: comb_condition_input,
                                                                            })
                else:
                    _, loss, total_heatmap_loss, total_rank_loss, \
                    heatmap_producer_list, batch_z_list = sess.run([model.train_op, model.total_loss,
                                                                    model.total_heatmap_loss, model.total_rank_loss,
                                                                    model.heatmap_producer_list,
                                                                    model.batch_z_list,
                                                                    ],
                                                                   feed_dict={
                                                                       model.input_img: input_img,
                                                                       model.input_heatmap: heat_map,
                                                                       model.is_training: True,
                                                                       model.comb_condition_input: comb_condition_input,
                                                                   })
                    print("train epoch :{} total_heatmap_loss loss : {} total_rank_loss : {}".format(epoch ,total_heatmap_loss, total_rank_loss))
                    batch_idx = 0
                    OD3DPose.visualize_conclusion(input_img[batch_idx], heat_map[batch_idx], r"E:\Temp\train pics\{}.jpg".format(uuid1()))
                    for stage_num in range(3):
                        heatmap_producer, batch_z = heatmap_producer_list[stage_num], batch_z_list[stage_num]
                        batch_t3 = OD3DPose.retrieve_t3(heatmap_producer, batch_z)
                        OD3DPose.visualize_t3(batch_t3[batch_idx], r"E:\Temp\train pics\{}.jpg".format(uuid1()))

                if step % 50 == 0:
                    valid_data = valid_gen.__next__()
                    if valid_data is None:
                        print("valid epoch end")
                        valid_gen = batch_data_loader(type="valid", batch_num=batch_size)
                        valid_data = valid_gen.__next__()

                    input_img, heat_map, comb_condition_input = valid_data

                    loss, total_heatmap_loss, total_rank_loss, \
                    heatmap_producer_list, batch_z_list = sess.run([model.total_loss,
                                                                    model.total_heatmap_loss, model.total_rank_loss,
                                                                    model.heatmap_producer_list,
                                                                    model.batch_z_list,],
                                                                   feed_dict={
                                                                       model.input_img: input_img,
                                                                       model.input_heatmap: heat_map,
                                                                       model.is_training: False,
                                                                       model.comb_condition_input: comb_condition_input,
                                                                   })
                    print("valid epoch :{} total_heatmap_loss loss : {} total_rank_loss : {}".format(epoch ,total_heatmap_loss, total_rank_loss) + "-" * 100)
                    batch_idx = 0
                    OD3DPose.visualize_conclusion(input_img[batch_idx], heat_map[batch_idx], r"E:\Temp\valid pics\{}.jpg".format(uuid1()))
                    for stage_num in range(3):
                        heatmap_producer, batch_z = heatmap_producer_list[stage_num], batch_z_list[stage_num]
                        batch_t3 = OD3DPose.retrieve_t3(heatmap_producer, batch_z)
                        OD3DPose.visualize_t3(batch_t3[batch_idx], r"E:\Temp\valid pics\{}.jpg".format(uuid1()))

                if step % 300 == 0:
                    saver.save(sess, save_path=r"C:\Coding\Python\OrdinalDepthPose\od3d")
                    print("have save")

                step += 1

if __name__ == "__main__":
    OD3DPose.train()

        Training results are visualized with matplotlib's 3D plotting, viewed from an angle that projects the pose onto 2D.

       Below are some examples from the valid set: the 3D result with its corresponding heatmap and original image (images omitted here; the colors differ from the paper's).

