【三维几何学习】网格上低分辨率的分割结果到高分辨率的投影与可视化

网格上低分辨率的分割结果到高分辨率的投影与可视化

引言

三角网格的结构特性决定了其仅用少量三角形即可表示一个完整的3D模型。增加其分辨率可以展示更多模型的形状细节。对于网格分割来说,并不需要很多模型细节,只需要知晓其数据元素所属部分(类别)即可。
(图:低分辨率分割结果、高分辨率投影结果与 Ground truth 对比)

  • 上图分别为低分辨率分割结果、高分辨率投影结果以及Ground truth

在简化网格上进行预测,然后将结果投影到高分辨率网格上,是一个可行的方案。例如:

MeshWalker[1] 中使用的边界平滑;
A Spectral Segmentation Method for Large Meshes[2] 中的 feature-aware 网格简化。

一、到高分辨率的投影

1.1 准确率

以面标签版本的COSEG外星人数据集为例,可参考三角网格(Triangular Mesh)分割数据集
[图]
简化网格上的准确率为 96.94%,投影到高分辨率网格后的准确率为 95.53%。
同时速度也会快很多,因为在高分辨率网格上计算输入特征较为费时。

1.2 主要代码

部分代码来自 MeshCNN[3]。
TriTransNet是对简化三角网格进行分割的网络,可替换为其它神经网络

import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPoints


def is_mesh_file(filename):
    """Return True if ``filename`` ends with a mesh extension (``.obj`` or ``.off``).

    The check is case-sensitive. The extension is matched including its dot,
    so names that merely end in the letters "off" (e.g. ``takeoff``) are rejected.
    """
    return filename.endswith(('.obj', '.off'))


def fix_vertices(vs):
    """Swap the y/z axes of ``vs`` and rescale it into the unit box, in place.

    Columns 1 and 2 are swapped, every axis is shifted so its minimum becomes 0,
    and all coordinates are divided by the largest per-axis extent. The longest
    side therefore spans [0, 1] while the aspect ratio is preserved.

    Args:
        vs: float array of shape (num_vertices, 3). It is modified in place.

    Returns:
        The same array object ``vs`` after normalization.
    """
    # Fancy indexing on the right-hand side makes a copy, so the swap is safe.
    vs[:, [1, 2]] = vs[:, [2, 1]]
    mins = vs.min(axis=0)
    max_range = float((vs.max(axis=0) - mins).max())
    vs -= mins
    # A degenerate mesh (all vertices coincident) has zero extent; dividing by
    # it would turn every coordinate into NaN, so leave it at the origin.
    if max_range > 0:
        vs /= max_range
    return vs


def get_seg_files(paths, seg_dir, seg_ext='.eseg'):
    """Return the label-file path for each mesh path, in the same order.

    The label file of ``.../name.ext`` is expected at ``seg_dir/name<seg_ext>``.

    Raises:
        FileNotFoundError: if any expected label file does not exist. This is an
            explicit check rather than ``assert``, so it also runs under ``python -O``.
    """
    segs = []
    for path in paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        segfile = os.path.join(seg_dir, stem + seg_ext)
        if not os.path.isfile(segfile):
            raise FileNotFoundError('segmentation file not found: %s' % segfile)
        segs.append(segfile)
    return segs


def make_dataset(path):
    """Collect every mesh file under ``path`` (recursively), in deterministic order.

    Directories are visited in sorted order, and the file names inside each
    directory are sorted as well. Sorting only the ``os.walk`` tuples would leave
    the per-directory order to ``os.listdir``, which is arbitrary. The scripts in
    this file pair two datasets by list index, so a stable order is required.

    Raises:
        NotADirectoryError: if ``path`` is not a directory.
    """
    if not os.path.isdir(path):
        raise NotADirectoryError('%s is not a valid directory' % path)
    meshes = []
    for root, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_mesh_file(fname):
                meshes.append(os.path.join(root, fname))
    return meshes


if __name__ == '__main__':
    import warnings

    # Simplified (low-resolution) meshes
    sim_root = '../../../datasets/face_label/coseg_aliens'
    sim_paths = make_dataset(os.path.join(sim_root, 'test'))
    # Original (high-resolution) meshes and their per-face ground-truth labels
    org_root = '../../../datasets/aliens'  # alternative: '../../datasets/vases'
    org_paths = make_dataset(os.path.join(org_root, 'test'))   # shapes  or seg
    org_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')
    if not sim_paths:
        raise SystemExit('no test meshes found under %s' % sim_root)
    # Meshes are paired by list index below, so both lists must describe the same shapes.
    if len(sim_paths) != len(org_paths):
        warnings.warn('%d simplified meshes but %d original meshes; index pairing may be wrong'
                      % (len(sim_paths), len(org_paths)))

    # Load the segmentation network (runs on CPU here)
    cfg = Config()
    cfg.class_n = 4
    cfg.mode = 'seg'
    net = TriTransNet(cfg)
    # map_location='cpu' so weights saved from a GPU run also load on a CPU-only machine.
    state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth',
                            map_location='cpu')  # latest_xyz_net 95.53432
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    net.load_state_dict(state_dict)
    net.eval()

    # Accuracy accumulators
    all_acc = 0.0  # projected accuracy on the high-resolution meshes
    sim_acc = 0.0  # accuracy on the simplified meshes
    for i, sim_name in enumerate(sim_paths):
        filename, _ = os.path.splitext(sim_name)
        prefix = os.path.basename(filename)
        org_prefix = os.path.splitext(os.path.basename(org_paths[i]))[0]
        if org_prefix != prefix:
            warnings.warn('pairing %s with %s by index' % (sim_name, org_paths[i]))

        # Read precomputed mesh data from the cache instead of recomputing it.
        # NOTE: pickle is only safe on trusted, locally produced cache files.
        cache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')
        with open(cache, 'rb') as f:
            meta = pickle.load(f)

        sim_mesh = meta['mesh']
        sim_label = meta['label']
        vs = fix_vertices(sim_mesh.vs)

        # Predict per-face labels on the simplified mesh
        with torch.no_grad():
            face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis=0)  # alternative: sim_mesh.hks[0:3]
            face_features = torch.from_numpy(face_features).float().unsqueeze(0)
            out = net(face_features, [sim_mesh])
            label = out.data.max(1)[1]
            sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_num
            sim_acc += float(sim_correct)

        t = time.time()
        # Labelled sample points near the edges of every simplified face
        label = label.numpy().reshape(-1)
        BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(vs, sim_mesh.faces, label, border_k=0.01, border_num=10)
        # border_k=0.01, border_num=10 -> 95.53432
        # border_k=0.5,  border_num=1  -> degenerates to plain nearest neighbour, 94.02
        kdt = cKDTree(BorderPoints_xyz)

        # High-resolution mesh; its .seg labels are shifted by -1 (presumably 1-based on disk)
        org_vs, org_faces = pp3d.read_mesh(org_paths[i])
        org_vs = fix_vertices(org_vs)
        org_label = np.loadtxt(org_labels[i], dtype='float64') - 1
        # Face centroids of the high-resolution mesh
        mean_vs = org_vs[org_faces].sum(axis=1) / 3.0

        # Each original face takes the label of its nearest border point
        _, indices = kdt.query(mean_vs, workers=-1)
        org_prolabels = BorderPoints_label[indices].reshape(-1)
        pro_acc = np.equal(org_prolabels, org_label).sum() / len(org_label)
        all_acc += pro_acc
        print(filename, ':', pro_acc, ' time:', time.time()-t)
    print(all_acc / len(sim_paths))
    print(sim_acc / len(sim_paths))

1.3 投影核心代码

def get_faces_BorderPoints(vs, faces, labels, border_k=0.1, border_num=1):
    """Sample labelled points just inside the edges of every (simplified) face.

    Projection idea: assume simplification does not badly damage segmentation
    boundaries, so the simplified mesh stays roughly aligned with the original.

    1. Simplified faces are larger. For each face, sample points near its
       three edges and give them that face's label.
    2. Each original face can then take the label of the sample point nearest
       to its centroid.

    For edge i of a face (p1 = corner i, p2 = corner i+1, p = opposite corner
    i+2), ``border_num`` points are placed along the edge at
    ``p1 + (p2 - p1) / (border_num + 1) * (j + 1)``, j = 0..border_num-1.
    When ``border_num == 1``, the edge midpoint ``(p1 + p2) / 2`` is used.
    Each point is then moved toward p by the fraction ``border_k``
    (0 = on the edge, 1 = on the opposite vertex).

    Args:
        vs: vertex coordinates, shape (V, 3).
        faces: vertex indices per face, shape (F, 3).
        labels: one label per face; at least F entries.
        border_k: how far to pull each point toward the opposite vertex.
        border_num: number of sample points per edge; must be >= 1.

    Returns:
        (xyz, label): a float64 array of shape (F * 3 * border_num, 3) and an
        int32 array of shape (F * 3 * border_num, 1). Points are ordered by
        face, then edge, then position along the edge.

    Raises:
        ValueError: if ``border_num`` < 1 or there are fewer labels than faces.
    """
    if border_num < 1:
        raise ValueError('border_num must be >= 1, got %r' % (border_num,))
    vs = np.asarray(vs, dtype=np.float64)
    faces = np.asarray(faces, dtype=np.intp).reshape(-1, 3)
    num_faces = len(faces)
    labels = np.asarray(labels).reshape(-1)
    if len(labels) < num_faces:
        raise ValueError('got %d labels for %d faces' % (len(labels), num_faces))

    # Corner coordinates for all three edges at once, each of shape (F, 3, 3):
    # edge i runs corner i -> corner i+1, and corner i+2 is opposite.
    p1 = vs[faces]
    p2 = vs[faces[:, [1, 2, 0]]]
    p = vs[faces[:, [2, 0, 1]]]

    if border_num > 1:
        steps = np.arange(1, border_num + 1, dtype=np.float64)  # the (j + 1) factors
        # Broadcast to (F, 3, K, 3), keeping the per-point formula's operation order.
        center = p1[:, :, None, :] + (p2 - p1)[:, :, None, :] / (border_num + 1) * steps[None, None, :, None]
    else:
        center = ((p1 + p2) / 2)[:, :, None, :]
    border = center + (p[:, :, None, :] - center) * border_k

    xyz = border.reshape(-1, 3)
    # Every face contributes 3 * border_num points, all carrying the face's label.
    label = np.repeat(labels[:num_faces], 3 * border_num).astype(np.int32).reshape(-1, 1)
    return xyz, label

二、可视化代码

减小可视化时网格边的线宽(linewidths=0.1),以便查看模型细节:
[图]

import potpourri3d as pp3d
import numpy as np
import os
import pickle
from scipy.spatial import cKDTree
import time
import pylab as pl
import torch
from config.config import Config
from network.TriTransNet import TriTransNet
from postprocessing.mesh_project import get_faces_BorderPoints
import mpl_toolkits.mplot3d as a3
import matplotlib.colors as colors
from scipy import linalg


def rot_vs_axis_z(vs, radian, scale):
    """Scale the vertices and rotate them by ``radian`` about a z-parallel axis.

    The pivot is the scalar ``np.mean(vs)``, i.e. the mean of *all* coordinates
    rather than the per-axis centroid. It is subtracted from x, y and z alike,
    so the rotation axis passes through (m, m) in the xy-plane. The input
    array is not modified.

    Args:
        vs: vertex coordinates with 3 columns, shape (N, 3).
        radian: rotation angle (counter-clockwise when viewed from +z).
        scale: uniform scale factor applied about the pivot.

    Returns:
        A new (N, 3) array of transformed vertices.
    """
    pivot = np.mean(vs)
    centered = (vs - pivot) * scale
    z_axis = [0, 0, 1]
    unit_axis = z_axis / linalg.norm(z_axis)
    # The matrix exponential of the skew-symmetric cross-product matrix is the rotation.
    rotation = linalg.expm(np.cross(np.eye(3), unit_axis * radian))
    rotated = np.dot(rotation, centered.T)
    return rotated.T + pivot


def init_ax(ax):
    """Hide the background panes, axis lines and ticks of a 3D Matplotlib axes.

    Uses ``ax.xaxis``/``ax.yaxis``/``ax.zaxis``. The older ``w_xaxis``/``w_yaxis``/
    ``w_zaxis`` aliases were deprecated and then removed in Matplotlib 3.8.
    Approach from
    https://stackoverflow.com/questions/29041326/3d-plot-with-matplotlib-hide-axes-but-keep-axis-labels/

    Returns:
        The same ``ax``, for chaining.
    """
    transparent_white = (1.0, 1.0, 1.0, 0.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(transparent_white)  # hide the background pane
        axis.line.set_color(transparent_white)  # hide the axis spine
    # Remove the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    return ax


def is_mesh_file(filename):
    """Return True if ``filename`` ends with a mesh extension (``.obj`` or ``.off``).

    The check is case-sensitive. The extension is matched including its dot,
    so names that merely end in the letters "off" (e.g. ``takeoff``) are rejected.
    """
    return filename.endswith(('.obj', '.off'))


def fix_vertices(vs):
    """Swap the y/z axes of ``vs`` and rescale it into the unit box, in place.

    Columns 1 and 2 are swapped, every axis is shifted so its minimum becomes 0,
    and all coordinates are divided by the largest per-axis extent. The longest
    side therefore spans [0, 1] while the aspect ratio is preserved.

    Args:
        vs: float array of shape (num_vertices, 3). It is modified in place.

    Returns:
        The same array object ``vs`` after normalization.
    """
    # Fancy indexing on the right-hand side makes a copy, so the swap is safe.
    vs[:, [1, 2]] = vs[:, [2, 1]]
    mins = vs.min(axis=0)
    max_range = float((vs.max(axis=0) - mins).max())
    vs -= mins
    # A degenerate mesh (all vertices coincident) has zero extent; dividing by
    # it would turn every coordinate into NaN, so leave it at the origin.
    if max_range > 0:
        vs /= max_range
    return vs


def get_seg_files(paths, seg_dir, seg_ext='.eseg'):
    """Return the label-file path for each mesh path, in the same order.

    The label file of ``.../name.ext`` is expected at ``seg_dir/name<seg_ext>``.

    Raises:
        FileNotFoundError: if any expected label file does not exist. This is an
            explicit check rather than ``assert``, so it also runs under ``python -O``.
    """
    segs = []
    for path in paths:
        stem = os.path.splitext(os.path.basename(path))[0]
        segfile = os.path.join(seg_dir, stem + seg_ext)
        if not os.path.isfile(segfile):
            raise FileNotFoundError('segmentation file not found: %s' % segfile)
        segs.append(segfile)
    return segs


def make_dataset(path):
    """Collect every mesh file under ``path`` (recursively), in deterministic order.

    Directories are visited in sorted order, and the file names inside each
    directory are sorted as well. Sorting only the ``os.walk`` tuples would leave
    the per-directory order to ``os.listdir``, which is arbitrary. The scripts in
    this file pair two datasets by list index, so a stable order is required.

    Raises:
        NotADirectoryError: if ``path`` is not a directory.
    """
    if not os.path.isdir(path):
        raise NotADirectoryError('%s is not a valid directory' % path)
    meshes = []
    for root, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_mesh_file(fname):
                meshes.append(os.path.join(root, fname))
    return meshes


if __name__ == '__main__':
    import warnings

    VIS_INDEX = 3  # index (in sorted order) of the single test mesh to visualise

    # Simplified (low-resolution) meshes
    sim_root = '../../../datasets/face_label/coseg_aliens'
    sim_paths = make_dataset(os.path.join(sim_root, 'test'))
    # Original (high-resolution) meshes and their per-face ground-truth labels
    org_root = '../../../datasets/aliens'  # alternative: '../../datasets/vases'
    org_paths = make_dataset(os.path.join(org_root, 'test'))   # shapes  or seg
    org_labels = get_seg_files(org_paths, seg_dir=os.path.join(org_root, 'seg'), seg_ext='.seg')
    if VIS_INDEX >= min(len(sim_paths), len(org_paths)):
        raise SystemExit('VIS_INDEX=%d is out of range (%d simplified / %d original meshes)'
                         % (VIS_INDEX, len(sim_paths), len(org_paths)))

    # Load the segmentation network (runs on CPU here)
    cfg = Config()
    cfg.class_n = 4
    cfg.mode = 'seg'
    net = TriTransNet(cfg)
    # map_location='cpu' so weights saved from a GPU run also load on a CPU-only machine.
    state_dict = torch.load('../../../results/aliens_1500/model/latest_xyz_net.pth',
                            map_location='cpu')  # latest_xyz_net 95.53432
    if hasattr(state_dict, '_metadata'):
        del state_dict._metadata
    net.load_state_dict(state_dict)
    net.eval()

    # Accuracy accumulators, averaged over the meshes actually processed
    all_acc = 0.0  # projected accuracy on the high-resolution meshes
    sim_acc = 0.0  # accuracy on the simplified meshes
    processed = 0
    for i, sim_name in enumerate(sim_paths):
        filename, _ = os.path.splitext(sim_name)
        prefix = os.path.basename(filename)

        # Visualise a single mesh (alternatively select by name, e.g. prefix == '132')
        if i != VIS_INDEX:
            continue
        org_prefix = os.path.splitext(os.path.basename(org_paths[i]))[0]
        if org_prefix != prefix:
            warnings.warn('pairing %s with %s by index' % (sim_name, org_paths[i]))

        # Read precomputed mesh data from the cache instead of recomputing it.
        # NOTE: pickle is only safe on trusted, locally produced cache files.
        cache = os.path.join('../../../results/aliens_1500/cache/', prefix + '.pkl')
        with open(cache, 'rb') as f:
            meta = pickle.load(f)

        sim_mesh = meta['mesh']
        sim_label = meta['label']
        vs = fix_vertices(sim_mesh.vs)

        # Predict per-face labels on the simplified mesh
        with torch.no_grad():
            face_features = np.concatenate([sim_mesh.face_features, sim_mesh.xyz], axis=0)  # alternative: sim_mesh.hks[0:3]
            face_features = torch.from_numpy(face_features).float().unsqueeze(0)
            out = net(face_features, [sim_mesh])
            label = out.data.max(1)[1]
            sim_correct = label.eq(torch.from_numpy(sim_label).long()).sum().float() / sim_mesh.faces_num
            sim_acc += float(sim_correct)

        t = time.time()
        # Labelled sample points near the edges of every simplified face
        label = label.numpy().reshape(-1)
        BorderPoints_xyz, BorderPoints_label = get_faces_BorderPoints(vs, sim_mesh.faces, label, border_k=0.01, border_num=10)
        # border_k=0.01, border_num=10 -> 95.53432
        # border_k=0.5,  border_num=1  -> degenerates to plain nearest neighbour, 94.02
        kdt = cKDTree(BorderPoints_xyz)

        # High-resolution mesh; its .seg labels are shifted by -1 (presumably 1-based on disk)
        org_vs, org_faces = pp3d.read_mesh(org_paths[i])
        org_vs = fix_vertices(org_vs)
        org_label = np.loadtxt(org_labels[i], dtype='float64') - 1
        # Face centroids of the high-resolution mesh
        mean_vs = org_vs[org_faces].sum(axis=1) / 3.0

        # Each original face takes the label of its nearest border point
        _, indices = kdt.query(mean_vs, workers=-1)
        org_prolabels = BorderPoints_label[indices].reshape(-1)
        pro_acc = np.equal(org_prolabels, org_label).sum() / len(org_label)
        all_acc += pro_acc
        processed += 1
        print(filename, ':', pro_acc, ' time:', time.time()-t)

        # Visualisation, left to right:
        # simplified prediction | projected prediction | ground truth
        f = pl.figure()
        ax = init_ax(f.add_subplot(1, 1, 1, projection='3d'))

        def r2h(rgb):
            """Convert a 0-255 RGB triple to a hex colour string."""
            return colors.rgb2hex(tuple(c / 255. for c in rgb))

        f_colors = [r2h((0, 0, 255)), r2h((0, 255, 255)), r2h((255, 0, 255)), r2h((0, 255, 0))]
        # All labels here are 0-based, so f_colors[l - 1] maps label 0 to the last
        # colour. Every panel uses the same mapping, so colours stay comparable.
        edge_color = r2h((0, 0, 0))
        vis_bias = 0.3  # horizontal gap between panels

        # 1) Simplified mesh with predicted labels
        vs = rot_vs_axis_z(vs, 0.95, 1)
        tri = a3.art3d.Poly3DCollection(vs[sim_mesh.faces],
                                        facecolors=[f_colors[l - 1] for l in label],
                                        edgecolors=edge_color,
                                        linewidths=0.1,
                                        alpha=1)
        ax.add_collection3d(tri)

        # 2) High-resolution mesh with projected labels, shifted right
        shift = vs[:, 0].max() / 2 + vis_bias
        org_vs = rot_vs_axis_z(org_vs, 0.95, 1)
        org_vs[:, 0] += shift
        tri1 = a3.art3d.Poly3DCollection(org_vs[org_faces],
                                         facecolors=[f_colors[l - 1] for l in org_prolabels.astype(int)],
                                         edgecolors=edge_color,
                                         linewidths=0.1,
                                         alpha=1)
        ax.add_collection3d(tri1)

        # 3) High-resolution ground truth, shifted right again
        #    (fancy indexing above copied the vertices, so tri1 is unaffected)
        org_vs[:, 0] += shift
        tri2 = a3.art3d.Poly3DCollection(org_vs[org_faces],
                                         facecolors=[f_colors[l - 1] for l in org_label.astype(int)],
                                         edgecolors=edge_color,
                                         linewidths=0.1,
                                         alpha=1)
        ax.add_collection3d(tri2)
        max_x = org_vs[:, 0].max()

        ax.auto_scale_xyz([0, max_x], [0, max_x], [0, max_x])
        ax.view_init(0, -90)
        pl.tight_layout()
        pl.savefig('corr.png', dpi=1000)
        pl.show()
        break
    if processed:
        print(all_acc / processed)
        print(sim_acc / processed)


  1. MeshWalker: Deep Mesh Understanding by Random Walks ↩︎

  2. A Spectral Segmentation Method for Large Meshes ↩︎

  3. MeshCNN ↩︎

猜你喜欢

转载自blog.csdn.net/qq_38204686/article/details/129430624