Deep-Learning-Based Saliency Detection for Ground Feature Extraction from Remote Sensing Imagery (CPD)

This one is also easy to get running, but unfortunately it is still not what I was looking for, so I am just recording it here; my search direction must be off. The next post reproduces another saliency detection method with fairly good results, so feel free to skip ahead to it first. The main point of this post is that the reproduction is straightforward, and the code should be easy to modify yourself.
Repository: https://github.com/wuzhe71/CPD
The results are noticeably worse than PoolNet from the previous post; both were run casually without any parameter tuning.
(Figures: original image, label, and prediction result)
Accuracy evaluation:

acc:  0.7580808851453993
acc_cls:  0.8396931347334586
iou:  [0.70454195 0.42825706]
miou:  0.566399504319334
fwavacc:  0.6529282073695303
class_accuracy:  0.4340081922720089
class_recall:  0.9699864979626694
accuracy:  0.7580808851453993
f1_score:  0.5996918498867934

The evaluation scores are not great. As I recall, I had swapped in a different loss function; if you are interested, switch it back and try again.
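For reference, metrics like the ones above can be derived from a confusion matrix between the binarized prediction and the label. The function below is only a sketch of the standard formulas (acc, acc_cls, iou, miou, fwavacc), not the exact evaluation script I used:

import numpy as np

def evaluate(pred, label, num_classes=2):
    # pred and label: integer class maps of the same shape (0 = background, 1 = target)
    mask = (label >= 0) & (label < num_classes)
    hist = np.bincount(num_classes * label[mask].astype(int) + pred[mask].astype(int),
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)
    acc = np.diag(hist).sum() / hist.sum()                      # overall pixel accuracy
    acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))      # mean per-class accuracy
    iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    miou = np.nanmean(iou)                                      # mean IoU over classes
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iou[freq > 0]).sum()            # frequency-weighted IoU
    return acc, acc_cls, iou, miou, fwavacc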

1. Data structure
The directory structure is very simple: just an image folder and a label folder, and each image must have a label file with the same name. Also, this project does not run validation during training, so no separate valid set is needed; you can merge it into the training set. A possible layout is sketched below.
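The folder names here are my own and only need to match the paths set in train.py and test.py below:

Data/build/
    train/
        images/    # training images, e.g. 0001.png
        labels/    # masks with identical file names, e.g. 0001.png
    test/
        images/    # test images
        pre/       # predictions written here by test.py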
2. Training
Not much needs to change in the training script: mainly the data paths, plus the choice of encoder backbone. The author provides both ResNet and VGG versions, switched by the is_ResNet argument. Before running, remember to download the pretrained backbone weights from the GitHub page; they are also included in the code package at the end (I renamed them, but you can still tell which is which). The save directory is created automatically and checkpoints go under models/.
train.py

import torch
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
import pdb, os, argparse
from datetime import datetime

from model.CPD_models import CPD_VGG
from model.CPD_ResNet_models import CPD_ResNet
from data import get_loader
from utils import clip_gradient, adjust_lr
from loss.focal_loss import FocalLoss
from loss.BCE_Dice_loss import DiceBCELoss  # note: I put several loss functions in the loss folder; use whichever you like

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument('--batchsize', type=int, default=8, help='training batch size')
parser.add_argument('--trainsize', type=int, default=512, help='training dataset size')
parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
parser.add_argument('--is_ResNet', type=bool, default=True, help='VGG or ResNet backbone')  # note: argparse handles type=bool poorly on the command line; change the default here instead
parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int, default=50, help='every n epochs decay learning rate')
opt = parser.parse_args()

print('Learning Rate: {} ResNet: {}'.format(opt.lr, opt.is_ResNet))
# build models
if opt.is_ResNet:
    model = CPD_ResNet()
else:
    model = CPD_VGG()

model.cuda()
params = model.parameters()
optimizer = torch.optim.Adam(params, opt.lr)

# image_root = './data/RIVER/Train/Image/'
# gt_root = './data/RIVER/Train/Mask/'
image_root = 'D:/wcs/PoolNet-master/Data/build/train/images/'
gt_root = 'D:/wcs/PoolNet-master/Data/build/train/labels/'
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
total_step = len(train_loader)

CE = torch.nn.BCEWithLogitsLoss()
# CE = DiceBCELoss()

def train(train_loader, model, optimizer, epoch):
    model.train()
    # model.load_state_dict(torch.load('./CPD_vgg16.pth'))
    for i, pack in enumerate(train_loader, start=1):
        optimizer.zero_grad()
        images, gts = pack
        images = Variable(images)
        gts = Variable(gts)
        images = images.cuda()
        gts = gts.cuda()

        atts, dets = model(images)
        loss1 = CE(atts, gts)
        loss2 = CE(dets, gts)
        loss = loss1 + loss2
        loss.backward()

        clip_gradient(optimizer, opt.clip)
        optimizer.step()

        if i % 400 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss2.data))

    if opt.is_ResNet:
        save_path = 'models/CPD_Resnet/'
    else:
        save_path = 'models/CPD_VGG/'

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if (epoch+1) % 5 == 0:
        torch.save(model.state_dict(), save_path + 'CPD_BCE.pth' + '.%d' % epoch)

print("Let's go!")
for epoch in range(1, opt.epoch + 1):  # run exactly opt.epoch epochs
    adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
    train(train_loader, model, optimizer, epoch)
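The commented-out CE = DiceBCELoss() line above refers to the combined BCE + Dice loss in the loss folder. I am not reproducing that file here, but a typical implementation of such a loss looks roughly like this (a sketch only; the class name matches the import, the internals are assumed):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    # combined BCE-with-logits + soft Dice loss for binary segmentation (sketch, not the original file)
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        probs = torch.sigmoid(logits).view(-1)
        targets = targets.view(-1)
        intersection = (probs * targets).sum()
        dice = 1 - (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return bce + dice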

3. Prediction
The prediction code needed more changes. The original script expects a ground-truth label alongside each image, but sometimes we just want to run inference and have no labels to provide, maybe none were ever drawn.
test.py

import torch
import torch.nn.functional as F

import numpy as np
import pdb, os, argparse
# from scipy import misc  # misc.imsave was removed from newer SciPy versions, so cv2 is used instead
import cv2

from model.CPD_models import CPD_VGG
from model.CPD_ResNet_models import CPD_ResNet
from data import test_dataset, test_dataset2  # note: I added a dedicated test-only class (test_dataset2) in data.py; it is very simple

parser = argparse.ArgumentParser()
parser.add_argument('--testsize', type=int, default=352, help='testing size')
parser.add_argument('--is_ResNet', type=bool, default=True, help='VGG or ResNet backbone')
opt = parser.parse_args()

dataset_path = './Data/build/test'  # data root; note it is joined with sub-folder names below

if opt.is_ResNet:
    model = CPD_ResNet()
    model.load_state_dict(torch.load('./models/CPD_Resnet/CPD_BCE.pth.64'))  # path to the trained checkpoint
else:
    model = CPD_VGG()
    model.load_state_dict(torch.load('./models/CPD_VGG/CPD.pth.74'))

model.cuda()
model.eval()

image_root = dataset_path + '/images/'
save_path = dataset_path + '/pre/'
# test_loader = test_dataset(image_root, gt_root, opt.testsize)
test_loader = test_dataset2(image_root, opt.testsize)
for i in range(test_loader.size):
    image, name = test_loader.load_data()
    image = image.cuda()
    _, res = model(image)
    res = F.interpolate(res, size=256, mode='bilinear', align_corners=False)  # note: I hard-coded the output size here; this is the output resolution, don't forget to change it (or just use the original image size; I forgot to)
    res = res.sigmoid().data.cpu().numpy().squeeze()
    res = (res - res.min()) / (res.max() - res.min() + 1e-8)
    # misc.imsave(save_path+name, res)
    res = res * 255
    cv2.imwrite(save_path+name, res)
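As the comment above says, the hard-coded size=256 only gives correct output if every test tile happens to be 256x256. One way to write predictions at each image's own resolution (a sketch, assuming you are willing to modify test_dataset2 and the loop above accordingly) is to return the original size from load_data and pass it to the interpolation call:

# in data.py, test_dataset2.load_data could also return the original size (sketch):
def load_data(self):
    image = self.rgb_loader(self.images[self.index])
    orig_size = image.size  # PIL gives (width, height) before the resize transform
    image = self.transform(image).unsqueeze(0)
    name = self.images[self.index].split('/')[-1]
    if name.endswith('.jpg'):
        name = name.split('.jpg')[0] + '.png'
    self.index += 1
    return image, name, orig_size

# then in test.py (sketch):
# image, name, (w, h) = test_loader.load_data()
# res = F.interpolate(res, size=(h, w), mode='bilinear', align_corners=False)  # F.interpolate expects (H, W)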

data.py
I added a test_dataset2 class here; nothing else was changed.

import os
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms


class SalObjDataset(data.Dataset):
    def __init__(self, image_root, gt_root, trainsize):
        self.trainsize = trainsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
        self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        self.filter_files()
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        return image, gt

    def filter_files(self):
        assert len(self.images) == len(self.gts)
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            img = Image.open(img_path)
            gt = Image.open(gt_path)
            if img.size == gt.size:
                images.append(img_path)
                gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            # return img.convert('1')
            return img.convert('L')

    def resize(self, img, gt):
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        return self.size


def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=0, pin_memory=True):

    dataset = SalObjDataset(image_root, gt_root, trainsize)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batchsize,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory)
    return data_loader


class test_dataset:
    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
        self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') or f.endswith('.png')]
        self.images = sorted(self.images)
        self.gts = sorted(self.gts)
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        image = self.rgb_loader(self.images[self.index])
        image = self.transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        name = self.images[self.index].split('/')[-1]
        if name.endswith('.jpg'):
            name = name.split('.jpg')[0] + '.png'
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')


class test_dataset2:
    def __init__(self, image_root, testsize):
        self.testsize = testsize
        self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
        self.images = sorted(self.images)
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        image = self.rgb_loader(self.images[self.index])
        image = self.transform(image).unsqueeze(0)
        name = self.images[self.index].split('/')[-1]
        if name.endswith('.jpg'):
            name = name.split('.jpg')[0] + '.png'
        self.index += 1
        return image, name

    def rgb_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

Below is the code download link; the building dataset was provided in the previous post, so download them together.
链接:https://pan.baidu.com/s/18wpUAyO65IiONqHuyG94OA
提取码:8eyg


Reposted from blog.csdn.net/qq_20373723/article/details/112723687