Saliency Detection with Deep Learning for Ground Object Extraction from Remote Sensing Imagery (MINet)

Like the previous two posts, saliency detection apparently cannot solve my problem, so I am going to try a different direction. Even though I did not reach my goal, the results here are actually decent, so if you need this, it is worth tuning carefully.
Repository: https://github.com/lartpang/MINet

Original image
Label
Prediction result
Evaluation results:

acc:  0.9055214352077908
acc_cls:  0.8682510382904347
iou:  [0.88870665 0.61525859]
miou:  0.7519826202767053
fwavacc:  0.8376228680494308
class_accuracy:  0.7143424443012731
class_recall:  0.7998325458213021
accuracy:  0.9007926079195722
f1_score:  0.7546741156614227

Note that these numbers come from a run with the default parameters; the foreground IoU is already above 0.6, which looks fairly good, although training is a bit slow.
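
For reference, here is a minimal sketch of how segmentation-style metrics like acc, acc_cls, iou, miou and fwavacc can be computed from a confusion matrix between the binarized prediction and the label. This is only an illustration of the formulas, not the evaluation script I used:

import numpy as np

def binary_confusion(pred, label, num_classes=2):
    """Accumulate a num_classes x num_classes confusion matrix from 0/1 arrays."""
    k = (label >= 0) & (label < num_classes)
    return np.bincount(num_classes * label[k].astype(int) + pred[k].astype(int),
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)

def metrics_from_confusion(hist):
    acc = np.diag(hist).sum() / hist.sum()                   # overall pixel accuracy
    acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))   # mean per-class accuracy
    iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    miou = np.nanmean(iou)                                    # mean IoU over classes
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()          # frequency-weighted IoU
    return acc, acc_cls, iou, miou, fwavacc

# toy usage with random binary maps
pred = (np.random.rand(256, 256) > 0.5).astype(int)
label = (np.random.rand(256, 256) > 0.5).astype(int)
print(metrics_from_confusion(binary_confusion(pred, label)))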

1. Data preparation
Data preparation is simple, just the usual way of storing the files.
Top-level directory (e.g. Dataset/BUILD/Train and Dataset/BUILD/Test, matching the paths set in config.py)
Second-level directory (the Image and Mask folders inside each of those)
It is best to keep these folder names the same as mine, because the code pieces the paths together from them; beyond that, an image and its label just need to share the same file name. A quick sanity check is sketched below.
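
Since the code builds paths from those folder names and from the shared file names, a quick check like the following sketch can save some debugging (the Dataset/BUILD/Train path is only an example, adjust it to your own layout):

import os

root = "./Dataset/BUILD/Train"  # example path
img_stems = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, "Image"))}
mask_stems = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(root, "Mask"))}
print("images without a mask:", sorted(img_stems - mask_stems))
print("masks without an image:", sorted(mask_stems - img_stems))
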
2. Data loading
What needs to change here is the test-time data loading. The original training pipeline also included testing/validation; I removed the validation step during training and added a loader that does not require masks for testing.

# -*- coding: utf-8 -*-
# @Time    : 2020/7/22
# @Author  : Lart Pang
# @Email   : [email protected]
# @File    : dataloader.py
# @Project : code
# @GitHub  : https://github.com/lartpang
import os
import random
from functools import partial

import torch
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

from config import arg_config
from utils.joint_transforms import Compose, JointResize, RandomHorizontallyFlip, RandomRotate
from utils.misc import construct_print


def _get_suffix(path_list):
    ext_list = list(set([os.path.splitext(p)[1] for p in path_list]))
    if len(ext_list) != 1:
        if ".png" in ext_list:
            ext = ".png"
        elif ".jpg" in ext_list:
            ext = ".jpg"
        elif ".bmp" in ext_list:
            ext = ".bmp"
        else:
            raise NotImplementedError
        construct_print(f"数据文件夹中包含多种扩展名,这里仅使用{ext}")
    else:
        ext = ext_list[0]
    return ext


def _make_dataset(root):
    img_path = os.path.join(root, "Image")
    mask_path = os.path.join(root, "Mask")

    img_list = os.listdir(img_path)
    mask_list = os.listdir(mask_path)

    img_suffix = _get_suffix(img_list)
    mask_suffix = _get_suffix(mask_list)

    img_list = [os.path.splitext(f)[0] for f in mask_list if f.endswith(mask_suffix)]
    return [
        (
            os.path.join(img_path, img_name + img_suffix),
            os.path.join(mask_path, img_name + mask_suffix),
        )
        for img_name in img_list
    ]

def _make_dataset2(root):
    img_path = os.path.join(root, "Image")
    # mask_path = os.path.join(root, "Mask")

    img_list = os.listdir(img_path)
    # mask_list = os.listdir(mask_path)

    img_suffix = _get_suffix(img_list)
    # mask_suffix = _get_suffix(mask_list)

    # img_list = [os.path.splitext(f)[0] for f in mask_list if f.endswith(mask_suffix)]
    return [
        (
            os.path.join(img_path, img_name),
            # os.path.join(mask_path, img_name + mask_suffix),
        )
        for img_name in img_list
    ]

def _read_list_from_file(list_filepath):
    img_list = []
    with open(list_filepath, mode="r", encoding="utf-8") as openedfile:
        line = openedfile.readline()
        while line:
            img_list.append(line.split()[0])
            line = openedfile.readline()
    return img_list


def _make_dataset_from_list(list_filepath, prefix=(".png", ".png")):
    img_list = _read_list_from_file(list_filepath)
    return [
        (
            os.path.join(
                os.path.join(os.path.dirname(img_path), "Image"),   # this is where the path is pieced together
                os.path.basename(img_path) + prefix[0],
            ),
            os.path.join(
                os.path.join(os.path.dirname(img_path), "Mask"),    # this is where the path is pieced together
                os.path.basename(img_path) + prefix[1],
            ),
        )
        for img_path in img_list
    ]

def _make_dataset_from_list2(list_filepath, prefix=(".png", ".png")):  # for test-time loading: no masks needed, since masks are often unavailable at test time
    img_list = _read_list_from_file(list_filepath)
    return [
        (
            os.path.join(
                os.path.join(os.path.dirname(img_path), "Image"),  # this is where the path is pieced together
                os.path.basename(img_path) + prefix[0],
            ),
            # os.path.join(
            #     os.path.join(os.path.dirname(img_path), "Mask"),
            #     os.path.basename(img_path) + prefix[1],
            # ),
        )
        for img_path in img_list
    ]


class ImageFolder(Dataset):
    def __init__(self, root, in_size, training, prefix, use_bigt=False):
        self.training = training
        self.use_bigt = use_bigt

        if os.path.isdir(root):
            construct_print(f"{root} is an image folder, we will test on it.")
            self.imgs = _make_dataset(root)
        elif os.path.isfile(root):
            construct_print(
                f"{root} is a list of images, we will use these paths to read the "
                f"corresponding image"
            )
            self.imgs = _make_dataset_from_list(root, prefix=prefix)
        else:
            raise NotImplementedError

        if self.training:
            self.joint_transform = Compose(
                [JointResize(in_size), RandomHorizontallyFlip(), RandomRotate(10)]
            )
            img_transform = [transforms.ColorJitter(0.1, 0.1, 0.1)]
            self.mask_transform = transforms.ToTensor()
        else:
            # if the input is a tuple, resize to that size; if it is a single number, scale so that the short side equals it
            img_transform = [
                transforms.Resize((in_size, in_size), interpolation=Image.BILINEAR),
            ]
        self.img_transform = transforms.Compose(
            [
                *img_transform,
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                # transforms.Normalize([0.341414, 0.357437, 0.298912], [0.143317, 0.112520, 0.113972]),
            ]
        )

    def __getitem__(self, index):
        img_path, mask_path = self.imgs[index]
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        img = Image.open(img_path).convert("RGB")
        if self.training:
            mask = Image.open(mask_path).convert("L")
            img, mask = self.joint_transform(img, mask)
            img = self.img_transform(img)
            mask = self.mask_transform(mask)
            if self.use_bigt:
                mask = mask.ge(0.5).float()  # binarize
            return img, mask, img_name
        else:
            # todo: When evaluating, the mask path may not exist. But our code defaults to its existence, which makes
            #  it impossible to use dataloader to generate a prediction without a mask path.
            img = self.img_transform(img)
            # img = img / 255.0
            return img, mask_path, img_name

    def __len__(self):
        return len(self.imgs)

class ImageFolder2(Dataset):  # added: test-time dataset without masks
    def __init__(self, root, in_size, training, prefix, use_bigt=False):
        self.training = training
        self.use_bigt = use_bigt

        if os.path.isdir(root):
            construct_print(f"{root} is an image folder, we will test on it.")
            self.imgs = _make_dataset2(root)
        elif os.path.isfile(root):
            construct_print(
                f"{root} is a list of images, we will use these paths to read the "
                f"corresponding image"
            )
            self.imgs = _make_dataset_from_list2(root, prefix=prefix)
        else:
            raise NotImplementedError

        # if the input is a tuple, resize to that size; if it is a single number, scale so that the short side equals it
        img_transform = [
            transforms.Resize((in_size, in_size), interpolation=Image.BILINEAR),
        ]
        self.img_transform = transforms.Compose(
            [
                *img_transform,
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                # transforms.Normalize([0.341414, 0.357437, 0.298912], [0.143317, 0.112520, 0.113972]),
            ]
        )

    def __getitem__(self, index):
        # print(self.imgs[index][0])
        img_path = self.imgs[index][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        img = Image.open(img_path).convert("RGB")
        
        img = self.img_transform(img)
        return img, img_name

    def __len__(self):
        return len(self.imgs)

class DataLoaderX(DataLoader):
    def __iter__(self):
        return BackgroundGenerator(super(DataLoaderX, self).__iter__())


def _collate_fn(batch, size_list):
    size = random.choice(size_list)
    img, mask, image_name = [list(item) for item in zip(*batch)]
    img = torch.stack(img, dim=0)
    img = interpolate(img, size=(size, size), mode="bilinear", align_corners=False)
    mask = torch.stack(mask, dim=0)
    mask = interpolate(mask, size=(size, size), mode="nearest")
    return img, mask, image_name


def _mask_loader(dataset, shuffle, drop_last, size_list):
    assert float(torch.__version__[:3]) >= 1.2, (
        "If you want to use the pytorch < 1.2, you need to "
        "comment out the line `collate_fn=...` when you set the `size_list` to `None`."
    )
    return DataLoaderX(
        dataset=dataset,
        collate_fn=partial(_collate_fn, size_list=size_list) if size_list else None,
        batch_size=arg_config["batch_size"],
        num_workers=arg_config["num_workers"],
        shuffle=shuffle,
        drop_last=drop_last,
        pin_memory=True,
    )


def create_loader(data_path, training, size_list=None, prefix=(".jpg", ".png"), get_length=False):
    if training:
        construct_print(f"Training on: {data_path}")
        imageset = ImageFolder(
            data_path,
            in_size=arg_config["input_size"],
            prefix=prefix,
            use_bigt=arg_config["use_bigt"],
            training=True,
        )
        loader = _mask_loader(imageset, shuffle=True, drop_last=True, size_list=size_list)
    else:
        construct_print(f"Testing on: {data_path}")
        imageset = ImageFolder2(
            data_path, in_size=arg_config["input_size"], prefix=prefix, training=False,
        )
        loader = _mask_loader(imageset, shuffle=False, drop_last=False, size_list=None)

    if get_length:
        length_of_dataset = len(imageset)
        return loader, length_of_dataset
    else:
        return loader


if __name__ == "__main__":
    loader = create_loader(
        data_path=arg_config["rgb_data"]["tr_data_path"],
        training=True,
        get_length=False,
        size_list=arg_config["size_list"],
    )

    for idx, train_data in enumerate(loader):
        train_inputs, train_masks, *train_other_data = train_data
        print(f"" f"batch: {idx} ", train_inputs.size(), train_masks.size())

3. Training
This codebase is driven mainly by the configuration file, so let's go through that first.
config.py

import os

__all__ = ["proj_root", "arg_config"]

from collections import OrderedDict

proj_root = os.path.dirname(__file__)
datasets_root = "./Dataset/"

# the original author's paths
# ecssd_path = os.path.join(datasets_root, "Saliency/RGBSOD", "ECSSD")
# dutomron_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUT-OMRON")
# hkuis_path = os.path.join(datasets_root, "Saliency/RGBSOD", "HKU-IS")
# pascals_path = os.path.join(datasets_root, "Saliency/RGBSOD", "PASCAL-S")
# soc_path = os.path.join(datasets_root, "Saliency/RGBSOD", "SOC/Test")
# dutstr_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUTS/Train")
# dutste_path = os.path.join(datasets_root, "Saliency/RGBSOD", "DUTS/Test")
# the paths I used for my own tests
# dutstr_path = os.path.join(datasets_root, "ECSSD/Train")
ecssdte_path = os.path.join(datasets_root, "ECSSD/Test")
modelte_path = os.path.join(datasets_root, "TEST")
rivertr_path = os.path.join(datasets_root, "RIVER/Train")
riverte_path = os.path.join(datasets_root, "RIVER/Test")
buildtr_path = os.path.join(datasets_root, "BUILD/Train")
buildte_path = os.path.join(datasets_root, "BUILD/Test")

arg_config = {
    "model": "MINet_VGG16",  # 实际使用的模型,需要在`network/__init__.py`中导入
    "info": "",  # 关于本次实验的额外信息说明,这个会附加到本次试验的exp_name的结尾,如果为空,则不会附加内容。
    "use_amp": False,  # 是否使用amp加速训练
    "resume_mode": "inference",  # the mode for resume parameters: ['train', 'test', 'inference', '']   #这里注意了,由于我改过的缘故,训练选'',测试选inference
    "use_aux_loss": False,  # 是否使用辅助损失, 这个可以设置多个损失函数,需要在solver.py文件里的self.loss_funcs参数里增加
    "save_pre": True,  # 是否保留最终的预测结果
    "epoch_num": 60,  # 训练周期, 0: directly test model
    "lr": 0.001,  # 微调时缩小100倍
    "xlsx_name": "result.xlsx",  # the name of the record file
    # 数据集设置
    "rgb_data": {
    
    
        "tr_data_path": buildtr_path,   #训练路径
        "te_data_list": OrderedDict(
            {
    
    
                # "pascal-s": pascals_path,
                # "ecssd": ecssdte_path,
                # "hku-is": hkuis_path,
                # "duts": dutste_path,
                # "dut-omron": dutomron_path,
                # "soc": soc_path,
                # "river": riverte_path,
                "modelte": buildte_path, #测试路径
            },
        ),
    },
    # monitoring information during training
    "tb_update": 50,  # >0 enables tensorboard
    "print_freq": 50,  # >0: log information during iterations
    # img_prefix, gt_prefix: the extensions used when reading from an index file
    "prefix": (".jpg", ".png"),
    # if you don't use multi-scale training, you can set 'size_list': None
    # "size_list": [224, 256, 288, 320, 352],
    "size_list": None,  # no multi-scale training
    "reduction": "mean",  # how the loss is reduced, either 'mean' or 'sum'
    # optimizer and learning-rate decay
    "optim": "adam",  # which optimizer to use (see make_optimizer in pipeline_ops.py)
    "weight_decay": 5e-4,  # 微调时设置为0.0001
    "momentum": 0.9,
    "nesterov": False,
    "sche_usebatch": False,
    "lr_type": "poly",
    "warmup_epoch": 1,  # depond on the special lr_type, only lr_type has 'warmup', when set it to 1, it means no warmup.
    "lr_decay": 0.9,  # poly
    "use_bigt": True,  # 训练时是否对真值二值化(阈值为0.5)
    "batch_size": 4,  # 要是继续训练, 最好使用相同的batchsize
    "num_workers": 0,  # 不要太大, 不然运行多个程序同时训练的时候, 会造成数据读入速度受影响
    "input_size": 512,  #图像大小,里面会有resize 大小,和原本图像不一致会自动帮你resize
}
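
To run this on your own data, the dataset block above is the main thing to change. As a sketch, a hypothetical MYDATA folder would be wired in by editing config.py along these lines (mirroring the buildtr_path/buildte_path lines above):

mydatatr_path = os.path.join(datasets_root, "MYDATA/Train")  # hypothetical path
mydatate_path = os.path.join(datasets_root, "MYDATA/Test")   # hypothetical path
# then in arg_config: "tr_data_path": mydatatr_path,
# and: "te_data_list": OrderedDict({"mydata": mydatate_path}),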

main.py
In this file I added the inference option, which corresponds to the resume_mode value in the config file.

import shutil
from datetime import datetime

from config import arg_config, proj_root
from utils.misc import construct_exp_name, construct_path, construct_print, pre_mkdir, set_seed
from utils.solver import Solver

construct_print(f"{datetime.now()}: Initializing...")
construct_print(f"Project Root: {proj_root}")
init_start = datetime.now()

exp_name = construct_exp_name(arg_config)
path_config = construct_path(
    proj_root=proj_root, exp_name=exp_name, xlsx_name=arg_config["xlsx_name"],
)
pre_mkdir(path_config)
set_seed(seed=0, use_cudnn_benchmark=arg_config["size_list"] != None)

solver = Solver(exp_name, arg_config, path_config)
construct_print(f"Total initialization time:{datetime.now() - init_start}")

shutil.copy(f"{proj_root}/config.py", path_config["cfg_log"])
shutil.copy(f"{proj_root}/utils/solver.py", path_config["trainer_log"])

construct_print(f"{datetime.now()}: Start...")
if arg_config["resume_mode"] == "test":
    solver.test()
elif arg_config["resume_mode"] == "inference":    #增加了这里
	solver.inference()
else:
    solver.train()
construct_print(f"{datetime.now()}: End...")

solver.py

import os
from pprint import pprint

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import network as network_lib
from loss.CEL import CEL
from loss.focal_loss import FocalLoss  # the loss functions below are ones I added; they will be packaged and shared later
from loss.dice_loss import DiceLoss
from loss.iou_loss import IoULoss
from utils.dataloader import create_loader
from utils.metric import cal_maxf, cal_pr_mae_meanf
from utils.misc import (
    AvgMeter,
    construct_print,
    write_data_to_file,
)
from utils.pipeline_ops import (
    get_total_loss,
    make_optimizer,
    make_scheduler,
    resume_checkpoint,
    save_checkpoint,
)
from utils.recorder import TBRecorder, Timer, XLSXRecoder


class Solver:
    def __init__(self, exp_name: str, arg_dict: dict, path_dict: dict):
        super(Solver, self).__init__()
        self.exp_name = exp_name
        self.arg_dict = arg_dict
        self.path_dict = path_dict

        self.dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.to_pil = transforms.ToPILImage()

        self.tr_data_path = self.arg_dict["rgb_data"]["tr_data_path"]
        self.te_data_list = self.arg_dict["rgb_data"]["te_data_list"]

        self.save_path = self.path_dict["save"]
        self.save_pre = self.arg_dict["save_pre"]

        if self.arg_dict["tb_update"] > 0:
            self.tb_recorder = TBRecorder(tb_path=self.path_dict["tb"])
        if self.arg_dict["xlsx_name"]:
            self.xlsx_recorder = XLSXRecoder(xlsx_path=self.path_dict["xlsx"])

        # attributes that depend on the ones above
        self.tr_loader = create_loader(
            data_path=self.tr_data_path,
            training=True,
            size_list=self.arg_dict["size_list"],
            prefix=self.arg_dict["prefix"],
            get_length=False,
        )
        self.end_epoch = self.arg_dict["epoch_num"]
        self.iter_num = self.end_epoch * len(self.tr_loader)

        if hasattr(network_lib, self.arg_dict["model"]):
            self.net = getattr(network_lib, self.arg_dict["model"])().to(self.dev)
        else:
            raise AttributeError
        pprint(self.arg_dict)

        if self.arg_dict["resume_mode"] == "test":
            # resume model only to test model.
            # self.start_epoch is useless
            resume_checkpoint(
                model=self.net, load_path=self.path_dict["final_state_net"], mode="onlynet",
            )
            return
        # since inference was newly added, the corresponding branch is added here as well
        if self.arg_dict["resume_mode"] == "inference":
            # resume model only to test model.
            # self.start_epoch is useless
            resume_checkpoint(
                model=self.net, load_path=self.path_dict["final_state_net"], mode="onlynet",
            )
            return
		
        # multiple losses can be used; remember to set the corresponding option in config.py to True
        self.loss_funcs = [
            # torch.nn.BCEWithLogitsLoss(reduction=self.arg_dict["reduction"]).to(self.dev)
            # FocalLoss()
            IoULoss()
        ]
        if self.arg_dict["use_aux_loss"]:
            self.loss_funcs.append(CEL().to(self.dev))

        self.opti = make_optimizer(
            model=self.net,
            optimizer_type=self.arg_dict["optim"],
            optimizer_info=dict(
                lr=self.arg_dict["lr"],
                momentum=self.arg_dict["momentum"],
                weight_decay=self.arg_dict["weight_decay"],
                nesterov=self.arg_dict["nesterov"],
            ),
        )
        self.sche = make_scheduler(
            optimizer=self.opti,
            total_num=self.iter_num if self.arg_dict["sche_usebatch"] else self.end_epoch,
            scheduler_type=self.arg_dict["lr_type"],
            scheduler_info=dict(
                lr_decay=self.arg_dict["lr_decay"], warmup_epoch=self.arg_dict["warmup_epoch"]
            ),
        )

        # AMP
        if self.arg_dict["use_amp"]:
            construct_print("Now, we will use the amp to accelerate training!")
            from apex import amp

            self.amp = amp
            self.net, self.opti = self.amp.initialize(self.net, self.opti, opt_level="O1")
        else:
            self.amp = None

        if self.arg_dict["resume_mode"] == "train":
            # resume model to train the model
            self.start_epoch = resume_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                amp=self.amp,
                exp_name=self.exp_name,
                load_path=self.path_dict["final_full_net"],
                mode="all",
            )
        else:
            # only train a new model.
            self.start_epoch = 0

    def train(self):
        for curr_epoch in range(self.start_epoch, self.end_epoch):
            train_loss_record = AvgMeter()
            self._train_per_epoch(curr_epoch, train_loss_record)

            # adjust the learning rate per epoch
            if not self.arg_dict["sche_usebatch"]:
                self.sche.step()

            # save after every epoch; what is saved are the parameters for epoch curr_epoch + 1
            save_checkpoint(
                model=self.net,
                optimizer=self.opti,
                scheduler=self.sche,
                amp=self.amp,
                exp_name=self.exp_name,
                current_epoch=curr_epoch + 1,
                full_net_path=self.path_dict["final_full_net"],
                state_net_path=self.path_dict["final_state_net"],
            )  # save the parameters
		
        # I commented this out; to use it, replace ImageFolder2 with ImageFolder in the create_loader function in dataloader.py
        # if self.arg_dict["use_amp"]:
        #     # https://github.com/NVIDIA/apex/issues/567
        #     with self.amp.disable_casts():
        #         construct_print("When evaluating, we wish to evaluate in pure fp32.")
        #         self.test()
        # else:
        #     self.test()

    @Timer
    def _train_per_epoch(self, curr_epoch, train_loss_record):
        for curr_iter_in_epoch, train_data in enumerate(self.tr_loader):
            num_iter_per_epoch = len(self.tr_loader)
            curr_iter = curr_epoch * num_iter_per_epoch + curr_iter_in_epoch

            self.opti.zero_grad()

            train_inputs, train_masks, _ = train_data
            train_inputs = train_inputs.to(self.dev, non_blocking=True)
            train_masks = train_masks.to(self.dev, non_blocking=True)
            train_preds = self.net(train_inputs)

            train_loss, loss_item_list = get_total_loss(train_preds, train_masks, self.loss_funcs)
            if self.amp:
                with self.amp.scale_loss(train_loss, self.opti) as scaled_loss:
                    scaled_loss.backward()
            else:
                train_loss.backward()
            self.opti.step()

            if self.arg_dict["sche_usebatch"]:
                self.sche.step()

            # only use item() to fetch the value when accumulating
            train_iter_loss = train_loss.item()
            train_batch_size = train_inputs.size(0)
            train_loss_record.update(train_iter_loss, train_batch_size)

            # tensorboard logging
            if (
                self.arg_dict["tb_update"] > 0
                and (curr_iter + 1) % self.arg_dict["tb_update"] == 0
            ):
                self.tb_recorder.record_curve("trloss_avg", train_loss_record.avg, curr_iter)
                self.tb_recorder.record_curve("trloss_iter", train_iter_loss, curr_iter)
                self.tb_recorder.record_curve("lr", self.opti.param_groups, curr_iter)
                self.tb_recorder.record_image("trmasks", train_masks, curr_iter)
                self.tb_recorder.record_image("trsodout", train_preds.sigmoid(), curr_iter)
                self.tb_recorder.record_image("trsodin", train_inputs, curr_iter)
            # log the data of every iteration
            if (
                self.arg_dict["print_freq"] > 0
                and (curr_iter + 1) % self.arg_dict["print_freq"] == 0
            ):
                lr_str = ",".join(
                    [f"{param_groups['lr']:.7f}" for param_groups in self.opti.param_groups]
                )
                log = (
                    f"{curr_iter_in_epoch}:{num_iter_per_epoch}/"
                    f"{curr_iter}:{self.iter_num}/"
                    f"{curr_epoch}:{self.end_epoch} "
                    f"{self.exp_name}\n"
                    f"Lr:{lr_str} "
                    f"M:{train_loss_record.avg:.5f} C:{train_iter_loss:.5f} "
                    f"{loss_item_list}"
                )
                print(log)
                write_data_to_file(log, self.path_dict["tr_log"])

    def test(self):
        self.net.eval()

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader = create_loader(
                data_path=data_path,
                training=False,
                prefix=self.arg_dict["prefix"],
                get_length=False,
            )
            self.save_path = os.path.join(self.path_dict["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(f"{self.save_path} do not exist. Let's create it.")
                os.makedirs(self.save_path)
            results = self._test_process(save_pre=self.save_pre)
            msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
            construct_print(msg)
            write_data_to_file(msg, self.path_dict["te_log"])

            total_results[data_name] = results

        self.net.train()

        if self.arg_dict["xlsx_name"]:
            # save result into xlsx file.
            self.xlsx_recorder.write_xlsx(self.exp_name, total_results)

    def _test_process(self, save_pre):
        loader = self.te_loader

        pres = [AvgMeter() for _ in range(256)]
        recs = [AvgMeter() for _ in range(256)]
        meanfs = AvgMeter()
        maes = AvgMeter()

        tqdm_iter = tqdm(enumerate(loader), total=len(loader), leave=False)
        for test_batch_id, test_data in tqdm_iter:
            tqdm_iter.set_description(f"{self.exp_name}: te=>{test_batch_id + 1}")
            with torch.no_grad():
                in_imgs, in_mask_paths, in_names = test_data
                in_imgs = in_imgs.to(self.dev, non_blocking=True)
                outputs = self.net(in_imgs)

            outputs_np = outputs.sigmoid().cpu().detach()

            for item_id, out_item in enumerate(outputs_np):
                gimg_path = os.path.join(in_mask_paths[item_id])
                gt_img = Image.open(gimg_path).convert("L")
                out_img = self.to_pil(out_item).resize(gt_img.size, resample=Image.NEAREST)

                if save_pre:
                    oimg_path = os.path.join(self.save_path, in_names[item_id] + ".png")
                    out_img.save(oimg_path)

                gt_img = np.array(gt_img)
                out_img = np.array(out_img)
                ps, rs, mae, meanf = cal_pr_mae_meanf(out_img, gt_img)
                for pidx, pdata in enumerate(zip(ps, rs)):
                    p, r = pdata
                    pres[pidx].update(p)
                    recs[pidx].update(r)
                maes.update(mae)
                meanfs.update(meanf)
        maxf = cal_maxf([pre.avg for pre in pres], [rec.avg for rec in recs])
        results = {"MAXF": maxf, "MEANF": meanfs.avg, "MAE": maes.avg}
        return results

    # this part is what I added
    def inference(self):
        self.net.eval()

        total_results = {}
        for data_name, data_path in self.te_data_list.items():
            construct_print(f"Testing with testset: {data_name}")
            self.te_loader = create_loader(
                data_path=data_path,
                training=False,
                prefix=self.arg_dict["prefix"],
                get_length=False,
            )
            self.save_path = os.path.join(self.path_dict["save"], data_name)
            if not os.path.exists(self.save_path):
                construct_print(f"{self.save_path} do not exist. Let's create it.")
                os.makedirs(self.save_path)
            self._inference_process(save_pre=self.save_pre)
            # msg = f"Results on the testset({data_name}:'{data_path}'): {results}"
            # construct_print(msg)
            # write_data_to_file(msg, self.path_dict["te_log"])

            # total_results[data_name] = results

        # self.net.train()

        # if self.arg_dict["xlsx_name"]:
        #     # save result into xlsx file.
        #     self.xlsx_recorder.write_xlsx(self.exp_name, total_results)

    def _inference_process(self, save_pre):
        loader = self.te_loader
        tqdm_iter = tqdm(enumerate(loader), total=len(loader), leave=False)
        for test_batch_id, test_data in tqdm_iter:
            tqdm_iter.set_description(f"{self.exp_name}: te=>{test_batch_id + 1}")
            with torch.no_grad():
                in_imgs, in_names= test_data
                # print(in_imgs.shape)
                in_imgs = in_imgs.to(self.dev, non_blocking=True)
                outputs = self.net(in_imgs)

            outputs_np = outputs.sigmoid().cpu().detach()

            for item_id, out_item in enumerate(outputs_np):
                out_img = self.to_pil(out_item).resize((256,256), resample=Image.NEAREST)
                if save_pre:
                    oimg_path = os.path.join(self.save_path, in_names[item_id] + ".png")
                    out_img.save(oimg_path)
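
Note that the saved results are grayscale saliency maps in [0, 255], resized to 256x256 here. To compare them against binary labels (for example, to reproduce IoU/F1 numbers like the ones quoted at the top), they have to be thresholded first; a minimal sketch, where the threshold of 127 is just an assumption you may want to tune:

import numpy as np
from PIL import Image

pred = np.array(Image.open("some_prediction.png").convert("L"))  # illustrative path
binary = (pred > 127).astype(np.uint8) * 255
Image.fromarray(binary).save("some_prediction_binary.png")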

pipeline_ops.py
Here I changed get_total_loss, the function that collects the losses; my own loss raised an error with the original version, and after this change it works.

import os

import torch
import torch.nn as nn
import torch.optim.optimizer as optim
import torch.optim.lr_scheduler as sche
import numpy as np
from torch.optim import Adam, SGD

from utils.misc import construct_print


def get_total_loss(
    train_preds: torch.Tensor, train_masks: torch.Tensor, loss_funcs: list
) -> (float, list):
    """
    return the sum of the list of loss functions with train_preds and train_masks
    
    Args:
        train_preds (torch.Tensor): predictions
        train_masks (torch.Tensor): masks
        loss_funcs (list): the list of loss functions

    Returns: the sum of all losses and the list of result strings

    """
    loss_list = []
    loss_item_list = []

    assert len(loss_funcs) != 0, "Please specify the loss functions in `loss_funcs`"
    for loss in loss_funcs:
        loss_out = loss(train_preds, train_masks)
        try:
            loss_list.append(loss_out)
            loss_item_list.append(f"{loss_out.item():.5f}")
        except:
            loss_list.append(loss_out)
            loss_item_list.append(f"{loss_out:.5f}")

    train_loss = sum(loss_list)
    return train_loss, loss_item_list


def save_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    current_epoch: int = 1,
    full_net_path: str = "",
    state_net_path: str = "",
):
    """
    Save both the full checkpoint (large) and the state-dict-only checkpoint (small).

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        current_epoch (int): the epoch at which the model will be trained next
        full_net_path (str): the path for saving the full model parameters
        state_net_path (str): the path for saving the state dict.
    """

    state_dict = {
        "arch": exp_name,
        "epoch": current_epoch,
        "net_state": model.state_dict(),
        "opti_state": optimizer.state_dict(),
        "sche_state": scheduler.state_dict(),
        "amp_state": amp.state_dict() if amp else None,
    }
    torch.save(state_dict, full_net_path)
    torch.save(model.state_dict(), state_net_path)


def resume_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    load_path: str = "",
    mode: str = "all",
):
    """
    Resume the model from a saved checkpoint.

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp (): apex.amp
        exp_name (str): exp_name
        load_path (str): path of the saved model
        mode (str): which resume mode to use:
            - 'all': restore the full model, including the training state;
            - 'onlynet': restore only the model's weight parameters

    Returns: start_epoch if mode is 'all'; None if mode is 'onlynet'
    """
    if os.path.exists(load_path) and os.path.isfile(load_path):
        construct_print(f"Loading checkpoint '{load_path}'")
        checkpoint = torch.load(load_path)
        if mode == "all":
            if exp_name and exp_name != checkpoint["arch"]:
                # if exp_name is given, it must match the corresponding checkpoint["arch"]; otherwise no check is made
                raise Exception(f"We can not match {exp_name} with {load_path}.")

            start_epoch = checkpoint["epoch"]
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint["net_state"])
            else:
                model.load_state_dict(checkpoint["net_state"])
            optimizer.load_state_dict(checkpoint["opti_state"])
            scheduler.load_state_dict(checkpoint["sche_state"])
            if checkpoint.get("amp_state", None):
                if amp:
                    amp.load_state_dict(checkpoint["amp_state"])
                else:
                    construct_print("You are not using amp.")
            else:
                construct_print("The state_dict of amp is None.")
            construct_print(
                f"Loaded '{load_path}' " f"(will train at epoch" f" {checkpoint['epoch']})"
            )
            return start_epoch
        elif mode == "onlynet":
            if hasattr(model, "module"):
                model.module.load_state_dict(checkpoint)
            else:
                model.load_state_dict(checkpoint)
            construct_print(
                f"Loaded checkpoint '{load_path}' " f"(only has the model's weight params)"
            )
        else:
            raise NotImplementedError
    else:
        raise Exception(f"{load_path}路径不正常,请检查")


def make_scheduler(
    optimizer: optim.Optimizer, total_num: int, scheduler_type: str, scheduler_info: dict
) -> sche._LRScheduler:
    def get_lr_coefficient(curr_epoch):
        nonlocal total_num
        # curr_epoch start from 0
        # total_num = iter_num if args["sche_usebatch"] else end_epoch
        if scheduler_type == "poly":
            coefficient = pow((1 - float(curr_epoch) / total_num), scheduler_info["lr_decay"])
        elif scheduler_type == "poly_warmup":
            turning_epoch = scheduler_info["warmup_epoch"]
            if curr_epoch < turning_epoch:
                # 0,1,2,...,turning_epoch-1
                coefficient = 1 / turning_epoch * (1 + curr_epoch)
            else:
                # turning_epoch,...,end_epoch
                curr_epoch -= turning_epoch - 1
                total_num -= turning_epoch - 1
                coefficient = pow((1 - float(curr_epoch) / total_num), scheduler_info["lr_decay"])
        elif scheduler_type == "cosine_warmup":
            turning_epoch = scheduler_info["warmup_epoch"]
            if curr_epoch < turning_epoch:
                # 0,1,2,...,turning_epoch-1
                coefficient = 1 / turning_epoch * (1 + curr_epoch)
            else:
                # turning_epoch,...,end_epoch
                curr_epoch -= turning_epoch - 1
                total_num -= turning_epoch - 1
                coefficient = (1 + np.cos(np.pi * curr_epoch / total_num)) / 2
        elif scheduler_type == "f3_sche":
            coefficient = 1 - abs((curr_epoch + 1) / (total_num + 1) * 2 - 1)
        else:
            raise NotImplementedError
        return coefficient

    scheduler = sche.LambdaLR(optimizer, lr_lambda=get_lr_coefficient)
    return scheduler


def make_optimizer(model: nn.Module, optimizer_type: str, optimizer_info: dict) -> optim.Optimizer:
    if optimizer_type == "sgd_trick":
        # https://github.com/implus/PytorchInsight/blob/master/classification/imagenet_tricks.py
        params = [
            {
                "params": [
                    p for name, p in model.named_parameters() if ("bias" in name or "bn" in name)
                ],
                "weight_decay": 0,
            },
            {
                "params": [
                    p
                    for name, p in model.named_parameters()
                    if ("bias" not in name and "bn" not in name)
                ]
            },
        ]
        optimizer = SGD(
            params,
            lr=optimizer_info["lr"],
            momentum=optimizer_info["momentum"],
            weight_decay=optimizer_info["weight_decay"],
            nesterov=optimizer_info["nesterov"],
        )
    elif optimizer_type == "sgd_r3":
        params = [
            # do not apply weight decay to the bias parameters; weight decay mainly reduces
            # overfitting by constraining the layer parameters (both weight and bias) with
            # L2 regularization, which keeps them smoother
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] == "bias"
                ],
                "lr": 2 * optimizer_info["lr"],
            },
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] != "bias"
                ],
                "lr": optimizer_info["lr"],
                "weight_decay": optimizer_info["weight_decay"],
            },
        ]
        optimizer = SGD(params, momentum=optimizer_info["momentum"])
    elif optimizer_type == "sgd_all":
        optimizer = SGD(
            model.parameters(),
            lr=optimizer_info["lr"],
            weight_decay=optimizer_info["weight_decay"],
            momentum=optimizer_info["momentum"],
        )
    elif optimizer_type == "adam":
        optimizer = Adam(
            model.parameters(),
            lr=optimizer_info["lr"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=optimizer_info["weight_decay"],
        )
    elif optimizer_type == "f3_trick":
        backbone, head = [], []
        for name, params_tensor in model.named_parameters():
            if name.startswith("div_2"):
                pass
            elif name.startswith("div"):
                backbone.append(params_tensor)
            else:
                head.append(params_tensor)
        params = [
            {"params": backbone, "lr": 0.1 * optimizer_info["lr"]},
            {"params": head, "lr": optimizer_info["lr"]},
        ]
        optimizer = SGD(
            params=params,
            momentum=optimizer_info["momentum"],
            weight_decay=optimizer_info["weight_decay"],
            nesterov=optimizer_info["nesterov"],
        )
    else:
        raise NotImplementedError

    print("optimizer = ", optimizer)
    return optimizer


if __name__ == "__main__":
    a = torch.rand((3, 3)).bool()
    print(isinstance(a, torch.FloatTensor), a.type())
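
The custom loss files imported in solver.py (focal_loss.py, dice_loss.py, iou_loss.py) are not reproduced in this post; they will be shared in the packaged files below. As a rough sketch of what a soft IoU loss compatible with get_total_loss might look like (assuming the network outputs logits and the mask is a 0/1 tensor; this is not necessarily the exact implementation I used):

import torch
import torch.nn as nn

class SoftIoULoss(nn.Module):
    """Sketch of a soft IoU loss: 1 - intersection / union on sigmoid probabilities."""
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, masks):
        probs = torch.sigmoid(logits)
        inter = (probs * masks).sum(dim=(1, 2, 3))
        union = (probs + masks - probs * masks).sum(dim=(1, 2, 3))
        return (1.0 - (inter + self.eps) / (union + self.eps)).mean()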

4. Prediction
After training, an output folder is generated automatically; once config.py is fully set up, this folder and most of its contents are created for you. Remember to set "resume_mode": "inference" for testing; the predictions are stored in the pre folder inside output.
Below is the reference material I uploaded to Baidu Netdisk; the data was provided in an earlier post, so it is not included here.
Link: https://pan.baidu.com/s/1n1gfGEIm9kibVAwv8ifKNA
Extraction code: 7477


Reposted from blog.csdn.net/qq_20373723/article/details/112739902