遥感分类的一种采样方法

版权声明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/nima1994/article/details/83148087

如深度学习,输入要求为一小邻域(下文称邻域块)代表中心像素类型。现有栅格图像,以及抠图面文件(.shp)。以下主要集中与arcgis操作。阅读本文前,建议阅读多光谱遥感分类:使用CNN1(一)

一种方法是使用随机点,但是就本任务目标其弊端明显(邻域块重叠相关,可通过设置随机点间隔解决,但会使样本大大减少)。具体参考
在这里插入图片描述

本文将描述的方法基于渔网。通过创建渔网(设置像元间隔)->叠加分析.相交。可以得到近可能多的点。
在这里插入图片描述

本文相关代码如下,读者有必要自行取舍。(包括将点shp文件导出坐标文件topos,根据坐标文件采样到各个文件夹)。

"""
@file: dataCreate.py
@time: 2018/10/15
"""
import os

import shapefile
import gdal
import pandas as pd
import matplotlib.pyplot as plt
from pyecharts import Bar
import numpy as np
import cv2
import shutil
import sys

dataset = gdal.Open(r"E:\Experiment\Mine\dataformat.tif")
rer=shapefile.Reader(r"E:\Experiment\Mine\相交.shp")
size=64

def topos():
    res=[]
    geo=dataset.GetGeoTransform()
    for i in range(rer.numRecords):#rer.numRecords
        pos=rer.shape(i).points[0]

        label=rer.record(i)[1]
        xoffset = int((pos[0] - geo[0]) / geo[1])
        yoffset = int((pos[1] - geo[3]) / geo[5])
        res.append([xoffset,yoffset,label])

    res=pd.DataFrame(res)
    print(res.head())
    res.to_csv("../output/pos.csv",header=None,index=None)

def get_cell(pos_x, pos_y):
    try:
        output = []
        for i in [1,2,3]:
            band = dataset.GetRasterBand(i)
            if (int(pos_x - size / 2) < 0 or int(pos_y - size / 2) < 0
                    or int(pos_x - size / 2) + size > dataset.RasterXSize
                    or int(pos_y - size / 2) + size > dataset.RasterYSize):
                return None
            t = band.ReadAsArray(int(pos_x - size / 2), int(pos_y - size / 2), size, size)
            output.append(t)
        img = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
    except:
        return None
    return img

def createdoc(dic):
    d="../output/img"
    if os.path.exists(d):
        shutil.rmtree(d)
        os.makedirs(d)
    for i in dic.values():
        os.makedirs(os.path.join(d,i))

def toImg():
    data=pd.read_csv("../output/pos.csv",header=None)
    labeldic=num2label()

    createdoc(labeldic)
    for line,row in data.iterrows():

        img=get_cell(row[0],row[1])
        if img is None:
            continue
        label=labeldic.get(row[2])

        # cv2.imwrite("../output/img/%s/%d.png" % (label,line),img)
        cv2.imencode('.png', img)[1].tofile("../output/img/%s/%d.png" % (label,line))

def num2label():
    eo=pd.read_csv(r"E:\Experiment\Mine\Export_Output.txt",index_col=0)
    dic=dict()
    for _,row in eo.iterrows():
        dic[row[1]]=row[2]
    return dic


if __name__ == '__main__':

    # toImg()
    # info()
    pass

采样结果如下:
在这里插入图片描述

[[‘城乡居民建设用地_红白顶’ ‘排土场’ ‘未利用土地_裸土地’ ‘水体’ ‘排土场’ ‘排土场’]
[‘城乡居民建设用地_灰白顶’ ‘采场’ ‘林地_灰’ ‘采场’ ‘采场’ ‘排土场’]
[‘耕地_旱地_绿色’ ‘水体’ ‘耕地_旱地_绿色’ ‘林地_红’ ‘排土场’ ‘选矿场’]
[‘选矿场’ ‘耕地_旱地_灰色’ ‘水体’ ‘选矿场’ ‘耕地_旱地_灰色’ ‘林地_红’]
[‘选矿场’ ‘耕地_旱地_灰色’ ‘采场’ ‘采场’ ‘选矿场’ ‘林地_红’]
[‘选矿场’ ‘采场’ ‘采场’ ‘选矿场’ ‘林地_黑’ ‘排土场’]]

其可视化代码如下:

"""
@file: tongji.py
@time: 2018/10/16
"""
import os
import sys
import re
import numpy as np

import torch
import torchvision
from pyecharts import Bar
import random
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import datasets

from dataDeal.datacreate import num2label
from PIL import Image

plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号

def show():
    """
    显示文件夹子文件夹下图片统计表
    :return:
    """
    path=r"../output/img"
    d=os.listdir(path)
    d_len=[len(os.listdir(os.path.join(path,i))) for i in d]

    line = Bar(path)
    line.add("图片数量", d, d_len, mark_point=["average", "max", "min"],xaxis_rotate=50)
    line.render(path="../output/图片数量.html")

def info():
    path="../output/pos.csv"
    res=pd.read_csv(path,header=None)

    xy=res.iloc[:, 2].value_counts()
    label=[num2label().get(i) for i in xy.index.values]

    # print(xy.index.values)
    line = Bar(path)
    line.add("点数量", label,xy.values, is_smooth=True, mark_line=["max", "average","min"],xaxis_rotate=50)
    line.render("../output/点数量.html")

def imageshow():

    # 所有路径、标签
    path = r"../output/img"

    # imgall=[]
    # labelall=[]
    # for root, dirs, files in os.walk(path):  # 目录
    #     for f in files:
    #         p = os.path.join(root, f)
    #         label=re.split("/|\\\\",p)[-2]
    #         imgall.append(p)
    #         labelall.append(label)
    # print(len(imgall),len(labelall))
    #
    # row=4
    # col=4
    # pos=random.sample(range(len(labelall)),row*col)
    # # print(pos)
    # plt.figure(figsize=(10,8))
    # for i,value in enumerate(pos):
    #     plt.subplot(row,col,i+1)
    #     plt.imshow(Image.open(imgall[value]))
    #     plt.title(labelall[value])
    #     plt.xticks([])
    #     plt.yticks([])
    #
    # plt.show()

    image_datasets =datasets.ImageFolder(os.path.join(path),transform=torchvision.transforms.ToTensor())
    dataloaders =torch.utils.data.DataLoader(image_datasets, batch_size=36,shuffle=True)

    inputs, classes = next(iter(dataloaders))
    labels=[image_datasets.classes[i] for i in classes.numpy()]
    print(np.array(labels).reshape(-1,6))
    out = torchvision.utils.make_grid(inputs,6,0)
    inp = out.numpy().transpose((1, 2, 0))
    plt.imshow(inp)
    plt.show()

if __name__ == '__main__':

    # show()
    imageshow()

    pass

猜你喜欢

转载自blog.csdn.net/nima1994/article/details/83148087
今日推荐