基于python脚本的遥感影像样例模板批量生产

前言

在实际生产中，训练深度学习模型需要大量样本，而遥感影像通常尺寸很大，无法直接作为模型输入。为了自动化地批量生产样本切片，本文介绍一种方法：以点矢量（锚点）为中心，从原始影像中裁剪固定大小的影像块，并为每个影像块生成一个对应的矢量模板（线或面 shapefile），供后续标注使用。

数据概览

原始影像

[图：原始遥感影像（即代码中的 tif 参数）]

锚点

[图：锚点矢量（点 shapefile，即代码中的 shp 参数；每个点作为一个切片的中心）]

样例模板影像

[图：生成的样例模板——每个锚点对应一个文件夹，内含影像切片 .tif 以及同范围的线/面 shapefile]

代码

# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from osgeo import ogr
import os, sys
import numpy as np
import cv2
import numpy
import gdal
import time
import glob
from osgeo import osr


def del_file(path):
    """Recursively delete every non-directory entry under *path*.

    The directory tree itself is kept: sub-directories are descended into and
    emptied, but not removed.

    Symbolic links are removed as links and never followed. A link pointing
    outside *path* therefore cannot cause files elsewhere to be deleted, and a
    dangling link does not make the walk crash.
    """
    for entry in os.listdir(path):
        entry_path = os.path.join(path, entry)
        # islink() first: isdir()/isfile() would follow the link to its target.
        if os.path.islink(entry_path) or not os.path.isdir(entry_path):
            os.remove(entry_path)
        else:
            del_file(entry_path)


def _tile_name(n):
    """Return the 7-digit zero-padded base name used for sample number *n*."""
    return "{:07d}".format(n)


def _sample_kind(sampletype):
    """Normalise *sampletype* to 'LINE' or 'POLY'; raise ValueError otherwise.

    The check is case-insensitive. 'polyline' is treated as a line type.
    """
    kind = str(sampletype).upper()
    if "LINE" in kind:
        return "LINE"
    if "POLY" in kind:
        return "POLY"
    raise ValueError(
        "sampletype must contain 'line' or 'poly', got {0!r}".format(sampletype))


def _gdal_type_for(dtype_name):
    """Map a numpy dtype name to the GDAL pixel type used for the output tile."""
    mapping = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # kept from the original (GDT_Int8 needs GDAL >= 3.7)
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    return mapping.get(dtype_name, gdal.GDT_Float32)


def _write_tile(driver, path, data, geotransform, projection):
    """Write *data*, shaped (bands, rows, cols), to a new GeoTIFF at *path*."""
    bands, rows, cols = data.shape
    out_ds = driver.Create(path, cols, rows, bands, _gdal_type_for(data.dtype.name))
    if out_ds is None:
        raise RuntimeError("Could not create " + path)
    out_ds.SetProjection(projection)
    out_ds.SetGeoTransform(geotransform)
    for index in range(bands):
        out_ds.GetRasterBand(index + 1).WriteArray(data[index])
    out_ds.FlushCache()
    out_ds = None  # dereferencing closes the GDAL dataset


def _write_template(driver, path, kind, srs, field_name, left, top, right, bottom):
    """Create a shapefile at *path* holding the tile frame as a line or a polygon.

    The layer gets one empty string field *field_name* (width 50) for later
    labelling, plus a single feature tracing the tile outline.
    """
    if os.path.exists(path):
        # Start from a clean file when re-running over an existing output directory.
        driver.DeleteDataSource(path)
    out_ds = driver.CreateDataSource(path)
    if out_ds is None:
        raise RuntimeError("Could not create " + path)
    geom_type = ogr.wkbLineString if kind == "LINE" else ogr.wkbPolygon
    out_layer = out_ds.CreateLayer("TestPolygon", srs, geom_type, [])
    if out_layer is None:
        raise RuntimeError("Could not create layer in " + path)

    field = ogr.FieldDefn(field_name, ogr.OFTString)
    field.SetWidth(50)
    out_layer.CreateField(field, 1)

    # Closed ring: top-left, top-right, bottom-right, bottom-left, top-left.
    ring = "{0} {1},{2} {1},{2} {3},{0} {3},{0} {1}".format(left, top, right, bottom)
    # The geometry must match the layer type (a polygon cannot go into a line layer).
    wkt = ("LINESTRING ({0})" if kind == "LINE" else "POLYGON (({0}))").format(ring)
    out_feature = ogr.Feature(out_layer.GetLayerDefn())
    out_feature.SetGeometry(ogr.CreateGeometryFromWkt(wkt))
    out_layer.CreateFeature(out_feature)
    # Release in dependency order so the shapefile is flushed and closed here.
    out_feature = None
    out_layer = None
    out_ds = None


def sampleClip(shp, tif, outputdir, sampletype, size, fieldName='cls', n=None):
    """Cut a square image tile and a matching vector template around each point.

    For every feature of the point shapefile *shp*, a ``size`` x ``size`` pixel
    window centred on the point is read from the raster *tif*. It is written to
    ``<outputdir>/<NNNNNNN>_V1/<NNNNNNN>.tif``. Next to it, the shapefile
    ``<NNNNNNN>_V1_LINE.shp`` or ``<NNNNNNN>_V1_POLY.shp`` is created. It contains
    the tile frame and an empty string field *fieldName*.

    Features without geometry, or whose window would extend past the raster
    border, are reported and skipped. They do not consume a sample number.

    Args:
        shp: Path of the point shapefile holding the tile centres.
        tif: Path of the source raster.
        outputdir: Existing directory that receives one sub-directory per tile.
        sampletype: 'line' or 'poly' (case-insensitive; 'polyline' counts as line).
        size: Tile edge length in pixels (anything ``int()`` accepts).
        fieldName: Name of the string attribute field added to each shapefile.
        n: First sample number; defaults to 1.

    Returns:
        The next unused sample number, so later calls can continue the numbering.

    Raises:
        ValueError: *sampletype* names neither a line nor a polygon type.
        RuntimeError: the raster cannot be opened or an output cannot be created.
        SystemExit: *shp* cannot be opened (kept from the original behaviour).
    """
    kind = _sample_kind(sampletype)
    size = int(size)
    if n is None:
        n = 1
    start = time.perf_counter()  # time.clock() was removed in Python 3.8

    gdal.AllRegister()
    ogr.RegisterAll()
    # Set once for every file written: the Python 3 bindings pass paths to GDAL
    # as UTF-8. (The original switched this to NO inside the loop, so only the
    # first tile used UTF-8 path handling.)
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    gdal.SetConfigOption("SHAPE_ENCODING", "")

    src = gdal.Open(tif)
    if src is None:
        raise RuntimeError("Could not open " + tif)
    width, height = src.RasterXSize, src.RasterYSize
    gt = src.GetGeoTransform()
    band_count = src.RasterCount
    projection = src.GetProjection()
    print(width, height)

    vec_driver = ogr.GetDriverByName('ESRI Shapefile')
    if vec_driver is None:
        raise RuntimeError("ESRI Shapefile driver is not available")
    points_ds = vec_driver.Open(shp, 0)
    if points_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    points = points_ds.GetLayer()
    srs = points.GetSpatialRef()  # read once instead of reopening shp per tile
    print("tif_bands:{0},samples_nums:{1},sample_type:{2},sample_size:{3}*{3}".format(
        band_count, points.GetFeatureCount(), sampletype, size))

    tif_driver = gdal.GetDriverByName("GTiff")
    for feature in points:
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print("feature {0} has no geometry, skipped".format(feature.GetFID()))
            continue
        x, y = geometry.GetX(), geometry.GetY()

        # Pixel offset of the window whose centre is the point. floor() rather
        # than int() so fractional offsets round consistently toward -inf.
        xoff = int(np.floor((x - gt[0]) / gt[1] - size / 2.0))
        yoff = int(np.floor((y - gt[3]) / gt[5] - size / 2.0))
        if xoff < 0 or yoff < 0 or xoff + size > width or yoff + size > height:
            print("point ({0}, {1}) is too close to the image border, skipped".format(x, y))
            continue

        tile = src.ReadAsArray(xoff, yoff, size, size)
        if tile.ndim == 2:
            # Single-band rasters come back as (rows, cols); add the band axis.
            tile = tile[np.newaxis, ...]

        name = _tile_name(n)
        dirpath = os.path.join(outputdir, name + "_V1")
        if not os.path.exists(dirpath):
            os.mkdir(dirpath)

        # Derive the tile origin from the integer offsets actually read, so the
        # georeference matches the pixels. NOTE: assumes a north-up raster
        # (gt[2] == gt[4] == 0), as the original code did.
        left = gt[0] + xoff * gt[1]
        top = gt[3] + yoff * gt[5]
        right = left + size * gt[1]
        bottom = top + size * gt[5]
        tile_gt = (left, gt[1], gt[2], top, gt[4], gt[5])

        _write_tile(tif_driver, os.path.join(dirpath, name + ".tif"),
                    tile, tile_gt, projection)
        _write_template(vec_driver,
                        os.path.join(dirpath, name + "_V1_" + kind + ".shp"),
                        kind, srs, fieldName, left, top, right, bottom)
        print("{0} ok".format(name))
        n += 1

    print('Process Running time: %s min' % ((time.perf_counter() - start) / 60))
    return n


def mkdir(path):
    """Create the single directory *path* unless something already exists there.

    Parent directories are not created. An existing file or directory at
    *path* is left untouched.
    """
    if os.path.exists(path):
        return
    os.mkdir(path)


if __name__ == "__main__":
    # Example run; the unused `from shutil import copyfile` was removed.
    outputdir = './plough'  # output directory (one sub-directory per sample)
    mkdir(outputdir)
    sampletype = "line"  # sample type: "line" or "poly"
    size = 1000  # tile edge length in pixels
    n = 1  # first sample number
    fieldName = 'cls'  # attribute field added to each template shapefile
    tif = './Level18/cq.tif'  # source raster
    shp = 'train.shp'  # point shapefile with the tile centres (anchors)
    n = sampleClip(shp, tif, outputdir, sampletype, size, fieldName, n)
    print(n)

猜你喜欢

转载自blog.csdn.net/weixin_42990464/article/details/111187260