从零搭建运行Pytorch版pointNet++模型全流程及个性化数据集训练测试可视化
本次采用的是Pytorch版的pointNet++模型
服务器环境是Ubuntu18/python3.7/cuda11.4/cudnn8.2/torch==1.12.1+cu113
Pytorch版pointNet++参考:PointNet 和 PointNet++ pytorch版本 复现 modelnet40和【三维目标分类 】PointNet详解(一)
Tensorflow版pointNet++可以参考从零搭建运行Tensorflow版pointNet++模型全流程及常见问题解决
一、Ubuntu18系统安装及初始化
参考:Ubuntu18系统安装及初始化(SSH服务、网络配置)
如果安装的是Ubuntu16系统,可以执行以下命令升级到Ubuntu18:
sudo apt update
sudo apt upgrade
sudo apt dist-upgrade
sudo apt autoremove
sudo do-release-upgrade
二、源码和数据集下载
1.pointNet++源码(Pytorch版)
下载地址:https://gitcode.net/mirrors/yanx27/pointnet_pointnet2_pytorch
将下载pointnet2_pytorch-master.zip文件拷贝到服务器,然后执行unzip pointnet2_pytorch-master.zip
2.ModelNet40数据集(XYZ and normal from mesh, 10k points)
下载地址:modelnet40_normal_resampled.zip
将下载的数据集文件拷贝到pointnet2_pytorch-master程序中的data目录下(需要新建data目录),执行unzip modelnet40_normal_resampled.zip
命令解压数据集
三、搭建pointNet++所需环境(Anaconda、Cuda、cuDNN、Pytorch、Python)
**结合自身显卡硬件,根据下图进行显卡驱动、cuda、cudnn的搭配
1.显卡驱动下载安装
可以参考:Ubuntu物理机显卡驱动安装的几种方式
(1)查看适合本显卡的驱动:ubuntu-drivers devices
(2)添加驱动源:sudo add-apt-repository ppa:graphics-drivers/ppa
(3)更新软件源:sudo apt-get update
(4)安装系统推荐的显卡驱动:sudo apt-get install nvidia-driver-470
(5)安装 nvidia-cuda-toolkit 工具:sudo apt-get install nvidia-cuda-toolkit
(6)测试显卡驱动是否安装成功:nvidia-smi
2.Anaconda、Cuda、cuDNN的安装配置
Anaconda和Cuda安装配置可以参考:Ubuntu搭建Pytorch环境(Anaconda、Cuda、cuDNN、Pytorch、Python、Pycharm、Jupyter)
根据自身显卡配置,安装Cuda和cuDNN的对应版本即可,我用的cuda11.4、cudnn8.2
3.python环境及tensorflow依赖库的安装配置
(1)激活默认的虚拟环境(base环境):source activate
(2)基于python3.7创建名为torch的虚拟环境:conda create -n pytorch python=3.7
(3)切换到创建的torch虚拟环境:conda activate pytorch
(4)安装torch、torchvision和torchaudio:pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
如果安装出现问题,可以参考:Anaconda下pytorch的cpu版和gpu版安装
(5)安装tqdm:pip install tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple --timeout=120
四、运行pointNet++(Pytorch版)
1.训练PointNet++
python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg
通过nvidia-smi
查看,大概占用4.5G显存
2.测试PointNet++
python test_classification.py --log_dir pointnet2_cls_ssg
五、pointNet++代码详解
参考PointNet系列代码复现详解(1)—PointNet分类部分