基于pytorch实现的图像分类代码

基于pytorch实现的图像分类源码

这个代码是干嘛的?

这个代码是基于pytorch框架实现的深度学习图像分类,主要针对各大有图像分类需求的使用者。
当然这个代码不适合大佬使用,对于大佬我建议是直接使用mmcv或者timm。
timm是我认为目前比较顶流的图像分类框架,也有很多图像分割、目标检测的源码使用timm作为backbone。
mmcv就更不用说了,就是大佬中的大佬。
当然除了大佬,我不建议各位使用timm和mmcv,不是因为他不好用,而是因为他使用难度较高,对于代码能力一般的同学,跑通估计就已经比较吃力,就不需要说根据自己的需求进行修改代码了。
因此我花了很多时间去研究bubbliiing(我是他粉丝)、yolov5、timm…比较优秀的开源框架后,编写并整合各大优秀源码到一个代码里面。
当然也希望通过这个项目,能够提升自己的编程水平和对深度学习的进一步理解,因为代码是我自己一个人进行编写和整合,虽然经过一些测试,但是没办法进行各方位测试,如果使用者遇到bug、出现报错、有不对的地方,可以通过留言、私信、邮箱([email protected])进行联系作者,咱们可以一起讨论讨论。
也希望通过这个平台,能交到更多这行的朋友,感谢各位!
源码地址:https://github.com/z1069614715/pytorch-classifier
源码使用案例:使用pytorch实现花朵分类
源码中的损失函数代码案例:pytorch代码-图像分类损失函数
源码地址中有更详细的解释,后续也会在哔哩哔哩中上传如何使用的视频。
如果这个代码帮助了你,请在博客点个赞,请在github点个star,谢谢!

为什么推荐你使用这个代码?

  • 丰富的可视化功能
    1. 训练图像可视化.
    2. 损失函数,精度,学习率迭代图像可视化.
    3. 热力图可视化.
    4. TSNE可视化.
    5. 数据集识别情况可视化.(metrice.py文件中–visual参数,开启可以自动把识别正确和错误的文件路径,类别,概率保存到csv中,方便后续分析)
    6. 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,accuracy,f0.5,f1,f2,auc,aupr)
    7. 总体精度可视化.(kappa,precision,recll,f1,accuracy,mpa)
  • 丰富的模型库
    1. 由作者整合的丰富模型库,主流的模型基本全部支持,支持的模型个数高达50+,其全部支持ImageNet的预训练权重,详细请看Model Zoo.(变形金刚系列后续更新)
    2. 目前支持的模型都是通过作者从github和torchvision整合,因此支持修改、改进模型进行实验,并不是直接调用库创建模型.
  • 丰富的训练策略
    1. 支持断点续训,只需要设定一个参数(–resume).
    2. 支持多种常见的损失函数.(目前支持PolyLoss,CrossEntropyLoss,FocalLoss)
    3. 支持一个参数即可设置类别平衡.
    4. 支持混合精度训练.(使你的机器能支持更大的batchsize)
    5. 支持知识蒸馏.
  • 丰富的数据增强策略
    1. 支持RandAugment, AutoAugment, TrivialAugmentWide, AugMix, Mixup, CutMix, CutOut, TTA等强大的数据增强.
    2. 支持添加torchvision中的数据增强.
    3. 支持添加自定义数据增强.详细看Some explanation第十四点
  • 丰富的学习率调整策略
    本程序支持学习率预热,支持预热后的自定义学习率策略.详细看Some explanation第五点
  • 支持导出各种常用推理框架模型
    目前支持导出torchscript,onnx,tensorrt推理模型.
  • 简单的安装过程
    1. 安装好pytorch, torchvision(pytorch==1.12.0+torchvision==0.13.0+)
      可以在pytorch官网找到对应的命令进行安装.
    2. pip install -r requirements.txt
  • 人性化的设定
    1. 大部分可视化数据(混淆矩阵,tsne,每个类别的指标)都会以csv或者log的格式保存到本地,方便后期美工图像.
    2. 程序大部分输出信息使用PrettyTable进行美化输出,大大增加可观性.
  • 后续更新
    后续将会更新一些使用的图像分类的tricks到这个代码里面,例如SWA,R-Drop等等。

更新日志

Model Zoo

目前支持的模型,以下模型全部都支持基于ImageNet的预训练权重。

model model_name
resnet resnet18,resnet34,resnet50,resnet101,wide_resnet50,wide_resnet101,resnext50,resnext101
resnest50,resnest101,resnest200,resnest269
shufflenet shufflenet_v2_x0_5,shufflenet_v2_x1_0
mobilenet mobilenetv2,mobilenetv3_small,mobilenetv3_large
densenet densenet121,densenet161,densenet169,densenet201
vgg vgg11,vgg11_bn,vgg13,vgg13_bn,vgg16,vgg16_bn,vgg19,vgg19_bn
efficientnet efficientnet_b0,efficientnet_b1,efficientnet_b2,efficientnet_b3,efficientnet_b4,efficientnet_b5,efficientnet_b6,efficientnet_b7
efficientnet_v2_s,efficientnet_v2_m,efficientnet_v2_l
nasnet mnasnet0_5,mnasnet1_0
vovnet vovnet39,vovnet59
convnext convnext_tiny,convnext_small,convnext_base,convnext_large,convnext_xlarge
ghostnet ghostnet
repvgg RepVGG-A0,RepVGG-A1,RepVGG-A2,RepVGG-B0,RepVGG-B1,RepVGG-B1g2,RepVGG-B1g4
RepVGG-B2,RepVGG-B2g4,RepVGG-B3,RepVGG-B3g4,RepVGG-D2se
sequencer sequencer2d_s,sequencer2d_m,sequencer2d_l
darknet darknet53,darknetaa53
cspnet cspresnet50,cspresnext50,cspdarknet53,cs3darknet_m,cs3darknet_l,cs3darknet_x,cs3darknet_focus_m,cs3darknet_focus_l
cs3sedarknet_l,cs3sedarknet_x,cs3edgenet_x,cs3se_edgenet_x
dpn dpn68,dpn68b,dpn92,dpn98,dpn107,dpn131

如果内容对你有帮助,麻烦点个赞,谢谢!

有计算机视觉合作项目可以私信作者!

猜你喜欢

转载自blog.csdn.net/qq_37706472/article/details/127698295