yolov3的pytorch版本保存自定义数据集训练好的权重,并载入自己的模型

多次试验终于测出来了!!很高兴,结果截图:
在这里插入图片描述
数据集是来自网上的,代码原型是github一个大概五千多star的pytorch-yolov3,但原代码并没有载入自己的模型进行训测试阶段,然后parser参数一直不明白,导致试了多次。

其中的要点:
1.初始化权重的修改
2。载入保存好的权重。(模式选择要正确)
3。格式要正确(比如什么地方加-- 什么地方加/)
4。保存训练的模型

传入参数部分:

 parser = argparse.ArgumentParser()
    parser.add_argument("--image_folder", type=str, default="data/samples", help="path to dataset")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")
    parser.add_argument("--conf_thres", type=float, default=0.8, help="object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.4, help="iou thresshold for non-maximum suppression")
    parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
    parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_model", type=str, help="path to checkpoint model")
    opt = parser.parse_args()
    print(opt)

其中("–weights_path"才是载入自己的模型部分,一开始以为是"–checkpoint_model",导致怎么都不对。
记录如图:
在这里插入图片描述可以看到根本没有输出标签,即并没有传自己的模型,而用的默认模型,默认模型不是自己训练的数据当然不可能识别出来。

正确传参的代码:

python3 detect.py --image_folder data/custom/dd --model_def config/yolov3-custom.cfg --class_path data/custom/classes.names --checkpoint_model checkpoints/yolov3_ckpt_99.pth --weights_path checkpoints/yolov3_ckpt_99.pth

在这里插入图片描述
此时标签已经输出。

其中detect.py中载入模型的语句:

    # Set up model
    model = Darknet(opt.model_def, img_size=opt.img_size).to(device)

    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights
        model.load_state_dict(torch.load(opt.weights_path))

    model.eval()  # Set in evaluation mode

if opt.weights_path.endswith(".weights"):
决定了是weights_path
train.py中保存模型:
if epoch % opt.checkpoint_interval == 0:
torch.save(model.state_dict(), f"checkpoints/yolov3_ckpt_%d.pth" % epoch)

之前用resnet18训练过一个模型,但是从参数看感觉yolov3网络结构比它复杂
下一步的学习:
1。如何输出目标中心坐标点(如果有多个坐标点如何迭代)
2。调整哪些参数可以得到更精确的测试结果
3。上面的代码每轮epoch都保存了模型,结果一共保存了99个,但实际运用的话需要保存map最好的那个,所以这个代码需要优化
4。怎么从摄像头读取图片让yolov3检测
5。可不可以预处理图片后增加图片数量。

猜你喜欢

转载自blog.csdn.net/qq_41358574/article/details/114816697