仅个人使用:决策融合

 
# import os
# from PIL import Image
# import cv2
# import numpy as np
 
 
# rgb_path='/mnt/sdb1/fenghaixia/jcrhsix/rgb/'
# h_path='/mnt/sdb1/fenghaixia/jcrhsix/dsm/'
# hrgb_path='/mnt/sdb1/fenghaixia/jcrhsix/hhhrgb/'
# savepath='/mnt/sdb1/fenghaixia/jcrhsix/aver/'
# filelist = os.listdir(rgb_path)
# for f in filelist:
#     rgb_name=rgb_path+f.strip()
#     h_name=h_path+f.strip()
#     hrgb_name=hrgb_path+f.strip()
#     # print(os.path.exists(rgb_name))
#     # print(os.path.exists(h_name))
    
#     # print(os.path.exists(hrgb_name))
#     if os.path.exists(rgb_name) and os.path.exists(h_name) and os.path.exists(hrgb_name):
#         # im = Image.open(path + item) #打开图片
#         rgb_im = cv2.imread(rgb_name)
#         h_im=cv2.imread(h_name)
#         hrgb_im=cv2.imread(hrgb_name)
#         aver_im=(rgb_im+h_im+hrgb_im)/3.0
#         aver_im[aver_im>4.0]=255
#         aver_im[aver_im<=4.0]=0
#         cv2.imwrite(savepath + f, aver_im)
#         print(f)
# print('finish')
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V
 
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import random
import shutil
from matplotlib.pyplot import MultipleLocator
#从pyplot导入MultipleLocator类,这个类用于设置刻度间隔
 
 
from time import time
from PIL import Image
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_IoU
 
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
 
BATCHSIZE_PER_CARD = 16


class TTAFrame():
    """Test-time-augmentation (TTA) inference wrapper for a segmentation net.

    The network is instantiated on the GPU and wrapped in
    ``torch.nn.DataParallel`` over every visible CUDA device. Each image is
    predicted under 8 flip/rotation variants in a single batch, and the
    de-augmented outputs are summed (not averaged).
    """

    def __init__(self, net):
        """Build the model on the GPU.

        Args:
            net: zero-argument callable returning an ``nn.Module``
                (e.g. one of the imported network classes).
        """
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))

    def test_one_img_from_path(self, path, evalmode=True):
        """Predict one image with 8-way TTA.

        Args:
            path: image file path readable by ``cv2.imread``.
            evalmode: if True, switch the network to eval mode first.

        Returns:
            The summed, de-augmented prediction from ``test_one_img_from_path_1``.

        Raises:
            RuntimeError: if the batch size (visible GPU count times
                BATCHSIZE_PER_CARD) is below 8. Only the single 8-image TTA
                batch is implemented, and the old code silently returned None.
        """
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize < 8:
            raise RuntimeError(
                'TTAFrame needs a batch size of at least 8 (got %d); '
                'only the 8-image TTA batch is implemented' % batchsize)
        return self.test_one_img_from_path_1(path)

    def test_one_img_from_path_1(self, path):
        """Run all 8 TTA variants of one image as a single batch.

        The image must be square (H == W), because the original and its
        90-degree rotation are stacked into one batch. Assumes the net outputs
        a single channel, so ``squeeze()`` yields (8, H, W). TODO: confirm
        this for every network used here.

        Raises:
            OSError: if ``cv2.imread`` cannot read ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            raise OSError('cv2.imread could not read image: %s' % path)

        # 8 variants: {identity, rot90} x {no flip, flip H} x {no flip, flip W}.
        img90 = np.array(np.rot90(img))
        img1 = np.concatenate([img[None], img90[None]])
        img2 = np.array(img1)[:, ::-1]            # flip along height
        img3 = np.concatenate([img1, img2])
        img4 = np.array(img3)[:, :, ::-1]         # flip along width
        img5 = np.concatenate([img3, img4]).transpose(0, 3, 1, 2)  # NHWC -> NCHW
        # Scale pixel values from [0, 255] to [-1.6, 1.6].
        img5 = np.array(img5, np.float32) / 255.0 * 3.2 - 1.6
        img5 = torch.Tensor(img5).cuda()

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            mask = self.net(img5).squeeze().cpu().numpy()

        # Undo the augmentations in reverse order, summing the pairs.
        mask1 = mask[:4] + mask[4:, :, ::-1]                  # undo width flip
        mask2 = mask1[:2] + mask1[2:, ::-1]                   # undo height flip
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1, ::-1]     # rotate back by -90 degrees

        return mask3

    def load(self, path):
        """Load weights from ``path`` into the DataParallel-wrapped model.

        NOTE: ``self.net`` is a DataParallel wrapper, so its state-dict keys
        carry a ``module.`` prefix and the checkpoint must match (strict load).
        Earlier, since-removed experiments re-prefixed plain checkpoints with
        ``module.`` and loaded with ``strict=False``.
        """
        self.net.load_state_dict(torch.load(path))
# source = 'dataset/test/'
 
def saveList(pathName, out_path="./dataset/gt.txt"):
    """Append the extension-less name of every entry in ``pathName`` to a text file.

    Writes one name per line. The file is opened in append mode, so callers
    that want a fresh list must truncate it first. An empty ``pathName``
    leaves the file system untouched.

    Args:
        pathName: iterable of file names (e.g. from ``os.listdir``).
        out_path: text file to append to; defaults to ``./dataset/gt.txt``.
    """
    # os.path.splitext drops only the final extension, so "a.b.png" -> "a.b".
    # The previous split(".")[0] truncated it to "a".
    lines = [os.path.splitext(file_name)[0] + "\n" for file_name in pathName]
    if not lines:
        return
    # Open once, rather than once per name as before.
    with open(out_path, "a") as f:
        f.writelines(lines)
 
def savetrainList(pathName, out_path="./dataset/gt_train.txt"):
    """Append the extension-less name of every entry in ``pathName`` to the train id list.

    Writes one name per line. The file is opened in append mode, so callers
    that want a fresh list must truncate it first. An empty ``pathName``
    leaves the file system untouched.

    Args:
        pathName: iterable of file names (e.g. from ``os.listdir``).
        out_path: text file to append to; defaults to ``./dataset/gt_train.txt``.
    """
    # os.path.splitext drops only the final extension, so "a.b.png" -> "a.b".
    # The previous split(".")[0] truncated it to "a".
    lines = [os.path.splitext(file_name)[0] + "\n" for file_name in pathName]
    if not lines:
        return
    # Open once, rather than once per name as before.
    with open(out_path, "a") as f:
        f.writelines(lines)
 
def dirList(gt_dir,path_list):
    """Pass the contents of each sub-directory of ``gt_dir`` to ``saveList``.

    Args:
        gt_dir: parent directory that the entries of ``path_list`` are joined onto.
        path_list: entry names to inspect; entries that are not directories
            are ignored.
    """
    for name in path_list:
        path = os.path.join(gt_dir, name)
        # The isdir check used to sit after the loop. That examined only the
        # last entry, and an empty path_list raised UnboundLocalError.
        if os.path.isdir(path):
            saveList(os.listdir(path))
 
 
print("开始运行!")

# --- Evaluation configuration (wtn: accuracy computation) ---
# The mIoU computation below runs when miou_mode is 0 or 2.
miou_mode       = 2
# Number of classes. The template this came from used "classes + 1"
# (e.g. 2 + 1); here it equals the two entries of name_classes.
num_classes     = 2
# Class names, in the same order as in json_to_dataset.
# (VOC template: ["background", "aeroplane", ..., "tvmonitor"])
name_classes    = ["nonwater","water"]

gt_dir      = '/mnt/sdb1/fenghaixia/dddrgb/dataset/real/'
pred_dir    = '/mnt/sdb1/fenghaixia/jcrhsix/aver/'
gt_list_path = os.path.join('./dataset/', "gt.txt")

# Placeholders; only test_* are (re)assigned below, by compute_mIoU.
train_mIou=[]
train_mPA=[]
test_mIou=[]
test_mPA=[]

with open('submits/count_low_pic.log','w') as mylog:
    # Truncate the id list, because saveList()/dirList() only append to it.
    # Previously this handle was opened and never closed.
    open(gt_list_path, 'w').close()
    path_list = sorted(os.listdir(pred_dir))
    dirList(pred_dir, path_list)
    saveList(path_list)
    with open(gt_list_path, 'r') as id_file:
        image_ids = id_file.read().splitlines()

    if miou_mode == 0 or miou_mode == 2:
        print('计算测试miou')
        test_mIou,test_mPA,test_miou,test_mpa=compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes)
        # One entry per line; entries used to run together on a single line.
        mylog.write('  test_mIoU:  '+str(test_miou)+'\n')
        mylog.write('  test_mPA:  '+str(test_mpa)+'\n')
        print('  test_mIoU:  '+str(test_miou))

        # Disabled: per-image IoU count of low-IoU test pictures.
        # count=0
        # print('Computing per-image IoU on the test set')
        # count=compute_IoU(gt_dir, pred_dir, image_ids, num_classes, lower_iou,higher_iou,weight_name,count)
        # mylog.write('  low-iou test picture num:  '+str(count))
        # print(weight_name + "Get miou done.")

    mylog.write('Finish!')
    print ('Finish!')

3选2(三取二多数投票:三个模型的预测图中至少两个判为正类的像素,才在融合结果中输出为正类)

 
import os
from PIL import Image
import cv2
import numpy as np
 
 
# "2 out of 3" majority-vote fusion of three prediction maps (rgb, dsm, hhhrgb).
# A pixel/channel value >= VOTE_THRESHOLD counts as one positive vote. The
# output is 255 where at least MIN_VOTES of the 3 maps vote positive, else 0.
rgb_path='/mnt/sdb1/fenghaixia/jcrhsix/rgb/'
h_path='/mnt/sdb1/fenghaixia/jcrhsix/dsm/'
hrgb_path='/mnt/sdb1/fenghaixia/jcrhsix/hhhrgb/'
savepath='/mnt/sdb1/fenghaixia/jcrhsix/aver/'
VOTE_THRESHOLD = 4
MIN_VOTES = 2

# cv2.imwrite neither creates directories nor raises; it just returns False.
os.makedirs(savepath, exist_ok=True)

filelist = os.listdir(rgb_path)
for f in filelist:
    rgb_name=rgb_path+f.strip()
    h_name=h_path+f.strip()
    hrgb_name=hrgb_path+f.strip()
    # Only fuse images present in all three prediction folders.
    if not (os.path.exists(rgb_name) and os.path.exists(h_name) and os.path.exists(hrgb_name)):
        continue
    rgb_im = cv2.imread(rgb_name)
    h_im = cv2.imread(h_name)
    hrgb_im = cv2.imread(hrgb_name)
    # cv2.imread returns None (rather than raising) for unreadable files.
    if rgb_im is None or h_im is None or hrgb_im is None:
        print('skip (unreadable image): ' + f)
        continue
    # Per pixel/channel: how many of the three maps are >= the threshold (0..3).
    votes = ((rgb_im >= VOTE_THRESHOLD).astype(np.uint8)
             + (h_im >= VOTE_THRESHOLD)
             + (hrgb_im >= VOTE_THRESHOLD))
    # Renamed from `all`, which shadowed the builtin.
    fused = np.where(votes >= MIN_VOTES, 255, 0).astype(np.uint8)
    if not cv2.imwrite(savepath + f, fused):
        print('failed to write: ' + savepath + f)
        continue
    print(f)
print('finish')

猜你喜欢

转载自blog.csdn.net/weixin_61235989/article/details/130517415