六、colab训练模型

损失函数、优化器、其他参数设置可根据自己的需求选择,我这里就不再赘述,也可参考该篇博文。十七、完整神经网络模型训练步骤

网络优化搭建完成、数据集也采集完成之后开始训练,本地电脑拉跨,故通过Google提供的免费服务器进行训练

一、Google云盘

①上传模型于云盘中

登录谷歌云盘
将模型压缩,上传至谷歌云盘,这里压缩成zip格数

空白处右击上传文件,选择模型压缩包,可能会有点慢,耐心耐心~~~

在这里插入图片描述

②新建一个文件夹,防止训练时网络不稳定导致GG,用于存放模型以及损失函数等信息

创建一个unet文件夹,下面再创建一个logs文件夹
在这里插入图片描述

二、colab

colab是Google提供的免费服务器,GPU有Tesla T4等显卡
登录colab
在这里插入图片描述

这里需要滤清关系,Google云盘相当于百度网盘,只不过是Google家的,colab是Google提供的免费服务器,因为这个服务器只要一关闭,或者断网,再次连接的时候里面的东西就会全部释放,没法保存,故需要结合Google云盘来使用。即先将模型上Google云盘,然后在colab中关联云盘,再将云盘中的模型给copy到colab中,解压,然后在colab中训练模型,每个epoch需要保存一下模型(至于咋样保存看自己情况而定),把模型保存结果与云盘中的一个文件夹(这里示范用的是logs文件夹)进行软连接,之后就不怕colab因网络波动而导致释放资源了。训练好的模型就可以通过Google云盘中的文件夹(logs)进行下载了。

①关联Google云盘

选择GPU
修改---笔记本设置---GPU
在这里插入图片描述
!nvidia-smi 查看GPU版本
在这里插入图片描述

关联Google云盘
在这里插入图片描述
会出现代码如下:
这里的content可以理解为colab服务器给你分配到根目录,drive也可以自定义,也是你的云盘存放的位置

from google.colab import drive
drive.mount('/content/drive')

点击该代码段,Shift+Enter执行
在这里插入图片描述
之后选择登录你的Google云盘账号即可
在这里插入图片描述
刷新一下,就可以看到Google云盘已经关联进来了
在这里插入图片描述
在这里插入图片描述

②将Google云盘中的模型copy到colab中

!pwd 查看下当前所在路径
在这里插入图片描述
cp表示copy,./表示当前路径,即/content
/content/drive/MyDrive/unet-pytorch.zip copy到./

!cp /content/drive/MyDrive/unet-pytorch.zip ./

在这里插入图片描述

③解压模型

因为是zip压缩包格式,故使用unzip命令解压,不同格式解压命令可自行百度
!pwd可知目前所处路径为/content
将模型压缩包解压到当前目录下

!unzip ./unet-pytorch.zip -d ./

在这里插入图片描述

④删除logs文件夹

进入模型中

%cd /content/unet-pytorch/

查看当前路径
!pwd
在这里插入图片描述
删除模型中得logs文件夹

!rm -rf logs

⑤关联Google云盘中的logs文件夹

这个意思是,将Google云盘的logs文件夹/content/drvie/MyDrive/logs,和当前路径下(/content/unet-pytorch)logs文件夹关联,若不存在则创建

!ln -s /content/drive/MyDrive/unet/logs logs

第一个参数路径是存在的,第二个参数路径是虚拟的,也就是第二个是第一个的映射

⑥运行训练文件

!python train.py

在这里插入图片描述

三、模型文件

每次训练关联的模型都会存放到Google云盘中,到时候可以右击下载即可
在这里插入图片描述

四、总结

在云盘中创建unet文件夹,其下面再创建logs文件夹
在这里插入图片描述
模型存放位置
在这里插入图片描述

from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/unet-pytorch.zip ./
!unzip ./unet-pytorch.zip -d ./
%cd /content/unet-pytorch/
!rm -rf logs
!ln -s /content/drive/MyDrive/unet/logs /content/unet-pytorch/logs
!python train.py

猜你喜欢

转载自blog.csdn.net/qq_41264055/article/details/127100547