多光谱遥感分类:使用CNN1(一)

版权声明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/nima1994/article/details/82313076

代码源于很久以前练手的一个Demo,时间一长,许多魔改版本都已遗失,目前只剩下这个简陋的版本。读者如有相关需求,可从中择取有用的片段参考。由于代码较为混乱、基础,不再上传至GitHub。

所用数据为多光谱遥感影像(.tif),以及在影像上标注(抠图)所得的样本点矢量文件(.shp)。

工具篇

根据shp文件(样本点),以每个样本点为中心,对栅格图像的第3、2、1波段切出64×64像素的图块,并保存到对应标签的文件夹下。注意:shp与tif的投影坐标系必须一致。

from osgeo import gdal
import numpy as np
import shapefile
import cv2
import os

# Tile geometry: every sample is a size x size pixel window. `bands` records the
# channel count; NOTE(review): it appears unused in this snippet, where the band
# indices (3, 2, 1) are hard-coded in the reader.
size, bands = 64, 3

# Inputs. The shapefile points are converted to pixels with the raster's
# geotransform, so both files must use the same projected coordinate system.
dataset = gdal.Open(r"E:\数据2\test_tif_peizhun_subset_proj_.tif")
rer = shapefile.Reader(r'E:\shps\test.shp')

def __createDir(path):
    """Create directory *path* (and any missing parents); abort the script on failure.

    Calling it for a directory that already exists is a no-op. If *path* exists
    but is not a directory, that is reported as a failure instead of being
    silently accepted, because the later image writes into it would fail.
    """
    import sys

    try:
        # exist_ok=True removes the race between an exists() check and makedirs().
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        print("创建文件夹失败: %s (%s)" % (path, err))
        sys.exit(1)

def __getACell(geo, pos):
    """Cut a size x size tile of bands 3, 2, 1 centred on a map coordinate.

    Args:
        geo: GDAL geotransform of the module-level ``dataset``.
        pos: (x, y) map coordinate in the raster's projection.

    Returns:
        A (size, size, 3) uint8 array with bands in the order 3, 2, 1, or None
        if the window would leave the raster or reading fails.
    """
    try:
        # Map coordinate -> pixel column/row. Only geo[1]/geo[5] are used, so this
        # assumes a north-up raster without rotation terms.
        xoffset = int((pos[0] - geo[0]) / geo[1])
        yoffset = int((pos[1] - geo[3]) / geo[5])
        print("pixels: x= %d,y= %d" % (xoffset, yoffset))

        # Top-left corner of the window. The bounds test does not depend on the
        # band, so it is done once rather than once per band.
        x0 = int(xoffset - size / 2)
        y0 = int(yoffset - size / 2)
        if (x0 < 0 or y0 < 0
                or x0 + size > dataset.RasterXSize
                or y0 + size > dataset.RasterYSize):
            return None

        output = [dataset.GetRasterBand(i).ReadAsArray(x0, y0, size, size)
                  for i in [3, 2, 1]]
        # (bands, rows, cols) -> (rows, cols, bands). Values are cast, not rescaled,
        # to uint8. NOTE(review): data above 255 (e.g. 16-bit imagery) would wrap;
        # confirm the input raster is 8-bit.
        img = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
    except Exception:
        # Best effort: any read/conversion failure just skips this point.
        return None
    return img

def getShpDataForNum():
    """Cut one tile per shapefile point and save it as data/org/<label>/<label>.<index>.jpg.

    The label of each point is the first attribute field of its record. Points
    without geometry, or whose window falls outside the raster, are skipped.
    """
    labels = [record[0] for record in rer.records()]
    for label in set(labels):
        __createDir("data/org/%s" % (label,))

    # The geotransform is the same for every point: fetch it once.
    geo = dataset.GetGeoTransform()
    for i in range(rer.numRecords):
        print("deal %d: " % (i + 1))
        points = rer.shape(i).points
        if not points:
            # Null geometry: points[0] would raise IndexError and abort the whole run.
            print("point %d has no geometry, skipped." % i)
            continue
        img = __getACell(geo, points[0])
        if img is None:
            print("the area of points %d is out range." % i)
            continue
        label = labels[i]
        out_path = "data/org/%s/%s.%d.jpg" % (label, label, i)
        # cv2.imwrite signals failure only through its return value.
        if not cv2.imwrite(out_path, img):
            print("failed to write %s" % out_path)
            continue
        print(out_path)
    print("deal finish,to numpy array.")


getShpDataForNum()

以下代码将上述所得图片按类别随机拆分为训练集和测试集(约75%用于训练,25%用于测试)。

import os
import shutil
import random

def createDir(path):
    """Create directory *path* (and any missing parents); abort the script on failure.

    Calling it for a directory that already exists is a no-op. If *path* exists
    but is not a directory, that is reported as a failure instead of being
    silently accepted, because the later copies into it would fail.
    """
    import sys

    try:
        # exist_ok=True removes the race between an exists() check and makedirs().
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        print("创建文件夹失败: %s (%s)" % (path, err))
        sys.exit(1)

createDir("data/train/")
createDir("data/test/")

# Fraction of each class's images that goes to the test set.
test_ratio = 0.25

org_dir = 'data/org/'
for dir_item in sorted(os.listdir(org_dir)):
    src = os.path.join(org_dir, dir_item)
    if not os.path.isdir(src):
        # Only class sub-folders are split; stray files at this level are ignored.
        continue

    createDir("data/train/" + dir_item)
    createDir("data/test/" + dir_item)

    org_data = os.listdir(src)
    random.shuffle(org_data)
    # Use a positive split index. With a negative one, a test count that rounds
    # down to 0 makes org_data[:-0] empty and org_data[-0:] the whole list, which
    # would send every file to the test set and none to training.
    split = len(org_data) - int(len(org_data) * test_ratio)

    print(org_dir + dir_item + " start.")
    for d in org_data[:split]:
        shutil.copyfile(os.path.join(src, d), "data/train/" + dir_item + "/" + d)
    for d in org_data[split:]:
        shutil.copyfile(os.path.join(src, d), "data/test/" + dir_item + "/" + d)
    print(org_dir + dir_item + " finished")

以下代码以条形图显示指定文件夹下各子文件夹(即各类别)中的文件数目。

import os
import seaborn as sns
import matplotlib.pyplot as plt
def show(path, title):
    """Bar-plot the number of files in each immediate sub-folder of *path*.

    Args:
        path: folder whose sub-folders (one per class) are counted.
        title: plot title.
    """
    # Only directories are classes; os.listdir() on a plain file would raise.
    # Sorting gives a stable bar order across runs.
    d = sorted(name for name in os.listdir(path)
               if os.path.isdir(os.path.join(path, name)))
    d_len = [len(os.listdir(os.path.join(path, i))) for i in d]

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render the Chinese labels
    plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with SimHei
    # x/y must be keywords: seaborn >= 0.12 rejects them as positional arguments.
    sns.barplot(x=d, y=d_len)
    plt.xlabel("样本类型")
    plt.ylabel("数量")
    plt.title(title)

    # Write each count just above its bar (bar i sits at x position i).
    for i, count in enumerate(d_len):
        plt.text(i, count + 2, "%d" % count, ha="center", va="bottom")
    plt.show()


show(r"data/1_train", "训练集源数据采样集")

由于其他原因,数据有所更改。以下代码改为根据shp样本点对应的像素坐标采集图集,此时样本坐标分别存放在train pos.txt、test pos.txt等文件中。

from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2,shutil


class Tiff:
    """Cut square image tiles from a raster at pixel positions listed in text files.

    On construction the position file and a feature file are joined column-wise
    into ``contact_src`` (reused if it already exists) and loaded as ``self.fea``;
    :meth:`get_cells` then writes one JPEG tile per row of that table.
    """

    def createDir(self, path):
        """Create *path* and any missing parents; exit the process if that fails."""
        import sys

        try:
            # exist_ok=True removes the race between an exists() check and makedirs().
            os.makedirs(path, exist_ok=True)
        except OSError as err:
            print("创建文件夹失败: %s (%s)" % (path, err))
            sys.exit(1)

    def __init__(self, pos_src, other_feather, contact_src, size=128, bands=(3, 2, 1),
                 tif_src=r"D:/lishihang/jiangxia_simple/ZY3_GS_jiangxia1.tif"):
        """Open the raster and build/load the combined position+feature table.

        Args:
            pos_src: space-separated sample-position file without a header.
            other_feather: tab-separated feature file without a header, paired
                row by row with ``pos_src``.
            contact_src: CSV path for the combined table; reused if it exists.
            size: side length, in pixels, of each square tile.
            bands: 1-based raster band indices, in output channel order.
            tif_src: path of the source raster.

        Raises:
            OSError: if GDAL cannot open ``tif_src``.
        """
        self.dataset = gdal.Open(tif_src)  # source raster
        if self.dataset is None:
            # Without gdal.UseExceptions(), Open returns None instead of raising;
            # every later get_cell() would then fail silently and save nothing.
            raise OSError("cannot open raster: %s" % tif_src)
        self.size = size  # sampling window size in pixels
        # Copy, so instances never share (or alias the caller's) band sequence.
        self.bands = list(bands)
        self.contact_pos_feather(pos_src, other_feather, contact_src)
        self.fea = pd.read_csv(contact_src, header=None)

    def get_cell(self, pos_x, pos_y):
        """Return a (size, size, len(bands)) uint8 tile centred on pixel (pos_x, pos_y).

        Returns None if the window does not lie entirely inside the raster, or
        if the position is invalid (e.g. NaN) or reading fails.
        """
        try:
            x0 = int(pos_x - self.size / 2)
            y0 = int(pos_y - self.size / 2)
            # Explicit bounds check instead of relying on a failed read raising later.
            if (x0 < 0 or y0 < 0
                    or x0 + self.size > self.dataset.RasterXSize
                    or y0 + self.size > self.dataset.RasterYSize):
                return None
            output = [self.dataset.GetRasterBand(i).ReadAsArray(x0, y0, self.size, self.size)
                      for i in self.bands]
            # (bands, rows, cols) -> (rows, cols, bands); values are cast, not rescaled.
            return np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
        except Exception:
            # Best effort: an unreadable position just yields no tile.
            return None

    def get_cells(self, target_src):
        """Save one tile per row of ``self.fea`` as ``<target_src>/<label>/<label>.<row>.jpg``.

        Column 1 is passed as pos_x and column 0 as pos_y to :meth:`get_cell`; the
        second-to-last column is the class label. Rows whose window leaves the
        raster are skipped.
        """
        fea_len = len(self.fea)

        self.createDir(target_src)
        for label in set(self.fea.iloc[:, -2]):
            self.createDir("%s/%s" % (target_src, label))

        print("fea length: %d" % fea_len)

        # itertuples keeps each column's own dtype. iloc[i, :].values upcast a mixed
        # int/float row to float, turning label 1 into "1.0", which pointed imwrite
        # at a directory that was never created.
        for i, row in enumerate(self.fea.itertuples(index=False, name=None)):
            img = self.get_cell(row[1], row[0])
            if img is None:
                continue
            out_path = "%s/%s/%s.%d.jpg" % (target_src, row[-2], row[-2], i)
            # cv2.imwrite signals failure only through its return value.
            if not cv2.imwrite(out_path, img):
                print("failed to write %s" % out_path)
            if i % 1000 == 0:
                print("%d/%d hava finsh save." % (i, fea_len))

    def contact_pos_feather(self, pos_src, other_feather, target):
        """Join positions and features column-wise and save them to *target* as CSV.

        Rows are paired by position (both tables get a default 0..n-1 index); if
        the row counts differ, pandas pads the shorter side with NaN, so a warning
        is printed. Does nothing if *target* already exists.
        """
        if os.path.exists(target):
            print("文件已存在")
            return
        pos = pd.read_csv(pos_src, header=None, sep=' ')
        feather = pd.read_csv(other_feather, header=None, sep='\t')
        fea = pd.concat([pos, feather], axis=1)
        print("pos Length=%d,feather Length=%d,fea Length=%d" % (len(pos), len(feather), len(fea)))
        if len(pos) != len(feather):
            print("warning: position and feature files have different row counts; "
                  "missing values are filled with NaN")
        fea.to_csv(target, index=False, header=False)



if __name__ == '__main__':
    # Build (or reuse, if tr_1.txt already exists) the combined training table.
    tiff = Tiff(pos_src=r"D:/tr_sample_1.txt", other_feather=r"D:/train1.txt", contact_src=r"tr_1.txt")
    # For the test split, construct it with the test files instead:
    # tiff = Tiff(r"D:/te_sample_1.txt", r"D:/test1.txt", r"te_1.txt")
    # and then cut the tiles into per-label folders:
    # tiff.get_cells("data/1_test")

猜你喜欢

转载自blog.csdn.net/nima1994/article/details/82313076