1. 前言
本博文介绍的脚本,能够较为方便地在指定区域内,批量将遥感影像裁剪成固定大小的切片。
2. 样本准备
遥感影像(.tif),以及与影像放在同一目录下的点矢量文件(.shp);矢量中的每个点即为一个切片的中心位置。
3. 基于gdal的裁剪代码
# -*- coding: utf-8 -*-
from osgeo import ogr
import os, sys
import numpy as np
import cv2
import numpy
import gdal
import time
import glob
from osgeo import osr
def del_file(path):
    """Delete every file below *path*, keeping the directory tree itself.

    Real sub-directories are descended into recursively and left in place
    (empty).  Every other entry is unlinked.

    Symbolic links are removed rather than followed: following a link to a
    directory would delete files that live outside *path*, and a dangling
    link is neither a file nor a directory, so recursing into it would make
    ``os.listdir`` fail.

    Args:
        path: Directory whose contents are to be cleared.

    Raises:
        OSError: If *path* cannot be listed or an entry cannot be removed.
    """
    for entry in os.listdir(path):
        path_file = os.path.join(path, entry)
        if os.path.isdir(path_file) and not os.path.islink(path_file):
            del_file(path_file)
        else:
            os.remove(path_file)
def _sample_window(geotrans, x, y, size, im_width, im_height):
    """Locate the size*size pixel window centred on the map point (x, y).

    Only the origin and pixel-size terms of the geotransform are used, so
    the raster is treated as north-up (rotation terms are ignored).

    Returns:
        ``(col, row)`` of the window's upper-left pixel, or ``None`` when
        the window does not lie completely inside the raster.
    """
    half = size / 2.0
    xpix = (x - half * geotrans[1] - geotrans[0]) / geotrans[1]
    ypix = (y - half * geotrans[5] - geotrans[3]) / geotrans[5]
    col, row = int(xpix), int(ypix)
    if col < 0 or row < 0 or col + size > im_width or row + size > im_height:
        return None
    return col, row


def _write_tile(outtif, pBuf, size, geotransform, projection):
    """Write the (bands, size, size) array *pBuf* to a GeoTIFF at *outtif*."""
    # Exact dtype-name lookup: a substring test would send signed int16 to
    # UInt16 and silently demote int32 / float64 to Float32.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL releases; fall back to Byte.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Byte),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(pBuf.dtype.name, gdal.GDT_Float32)
    bandscount = pBuf.shape[0]
    ds = gdal.GetDriverByName("GTiff").Create(
        outtif, size, size, bandscount, datatype, options=[])
    if ds is None:
        raise RuntimeError('Could not create ' + outtif)
    ds.SetProjection(projection)
    ds.SetGeoTransform(geotransform)
    for i in range(bandscount):
        ds.GetRasterBand(i + 1).WriteArray(pBuf[i])
    ds.FlushCache()
    # Dropping the last reference closes the dataset and flushes it to disk.
    ds = None


def _write_outline_shp(strVectorFile, srs, geom_type, fieldName, corners):
    """Create a shapefile holding the tile outline and one string field.

    *corners* is the ordered list of (x, y) outline vertices, without the
    closing vertex.  The outline is written as a polygon, or as a closed
    line string when *geom_type* is ``ogr.wkbLineString``, so that the
    geometry matches the type of the layer it is stored in.
    """
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
    gdal.SetConfigOption("SHAPE_ENCODING", "")
    oDriver = ogr.GetDriverByName("ESRI Shapefile")
    if oDriver is None:
        raise RuntimeError('ESRI Shapefile driver is not available')
    oDS = oDriver.CreateDataSource(strVectorFile)
    if oDS is None:
        raise RuntimeError('Could not create ' + strVectorFile)
    oLayer = oDS.CreateLayer("TestPolygon", srs, geom_type, [])
    if oLayer is None:
        raise RuntimeError('Could not create a layer in ' + strVectorFile)
    oFieldName = ogr.FieldDefn(fieldName, ogr.OFTString)
    oFieldName.SetWidth(50)
    oLayer.CreateField(oFieldName, 1)

    ring = ",".join("{0} {1}".format(px, py)
                    for px, py in list(corners) + [corners[0]])
    if geom_type == ogr.wkbLineString:
        wkt = "LINESTRING ({0})".format(ring)
    else:
        wkt = "POLYGON (({0}))".format(ring)
    oFeature = ogr.Feature(oLayer.GetLayerDefn())
    oFeature.SetGeometry(ogr.CreateGeometryFromWkt(wkt))
    oLayer.CreateFeature(oFeature)
    # Release in dependency order so the shapefile is flushed and closed.
    oFeature = None
    oLayer = None
    oDS = None


def sampleClip(shp, tif, outputdir, sampletype, size, fieldName='cls', n=None):
    """Cut one fixed-size sample tile out of *tif* around every point in *shp*.

    For each point feature a folder ``0000000<n>_V1`` is created under
    *outputdir* containing ``0000000<n>.tif`` (the size*size pixel window
    centred on the point) and a shapefile with the tile outline and an empty
    string field named *fieldName*.  Existing output is not cleared first.

    The raster is treated as north-up: only the origin and pixel-size terms
    of its geotransform are used.  The tile's geotransform is snapped to the
    source pixel grid so the tile stays registered with the source image.

    Points without geometry, or whose window is not completely inside the
    raster, are reported and skipped without consuming a sample number.

    Args:
        shp: Path of the point shapefile holding the tile centres.
        tif: Path of the raster to clip.
        outputdir: Directory that receives one sub-folder per sample.
        sampletype: Contains "poly" or "line" (any case); selects the
            geometry type and the name suffix of the output shapefile.
        size: Tile edge length in pixels.
        fieldName: Name of the string field added to the output shapefile.
        n: Number of the first sample; defaults to 1.

    Returns:
        The number to use for the next sample, so that calls can be chained
        over several rasters.

    Raises:
        ValueError: If *sampletype* names neither "poly" nor "line".
        RuntimeError: If a tile cannot be read or an output cannot be created.
        SystemExit: If *tif* or *shp* cannot be opened.
    """
    # time.clock() was removed in Python 3.8; perf_counter() replaces it.
    time1 = time.perf_counter()
    size = int(size)
    if n is None:
        n = 1

    kind = sampletype.lower()
    if "poly" in kind:
        shp_suffix, geom_type = "_V1_POLY.shp", ogr.wkbPolygon
    elif "line" in kind:
        shp_suffix, geom_type = "_V1_LINE.shp", ogr.wkbLineString
    else:
        raise ValueError(
            "sampletype must contain 'poly' or 'line', got {0!r}".format(sampletype))

    gdal.AllRegister()
    lc = gdal.Open(tif)
    if lc is None:
        print('Could not open ' + tif)
        sys.exit(1)
    im_width = lc.RasterXSize
    im_height = lc.RasterYSize
    im_geotrans = lc.GetGeoTransform()
    bandscount = lc.RasterCount
    im_proj = lc.GetProjection()
    print(im_width, im_height)

    gdal.SetConfigOption("gdal_FILENAME_IS_UTF8", "YES")
    driver = ogr.GetDriverByName('ESRI Shapefile')
    dsshp = driver.Open(shp, 0)
    if dsshp is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = dsshp.GetLayer()
    # The output shapefiles reuse the spatial reference of the input points.
    prosrs = layer.GetSpatialRef()
    m = layer.GetFeatureCount()
    print("tif_bands:{0},samples_nums:{1},sample_type:{2},sample_size:{3}*{3}".format(bandscount, m, sampletype, size))

    for feature in layer:
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print("feature without geometry skipped")
            continue
        x = geometry.GetX()
        y = geometry.GetY()
        print(x, y)
        print(im_geotrans)

        window = _sample_window(im_geotrans, x, y, size, im_width, im_height)
        if window is None:
            print("point ({0}, {1}) skipped: window is outside {2}".format(x, y, tif))
            continue
        col, row = window
        print("#################")
        print(col, row)

        # Map coordinates of the upper-left corner of the pixels actually
        # read, so the tile's georeferencing agrees with its content.
        left = im_geotrans[0] + col * im_geotrans[1]
        top = im_geotrans[3] + row * im_geotrans[5]
        right = left + size * im_geotrans[1]
        bottom = top + size * im_geotrans[5]
        newform = list(im_geotrans)
        newform[0] = left
        newform[3] = top
        print(newform[0], newform[3])

        dirname = "0000000" + str(n)
        dirpath = os.path.join(outputdir, dirname + "_V1")
        os.makedirs(dirpath, exist_ok=True)

        # ************** create tif **********************
        pBuf = lc.ReadAsArray(col, row, size, size)
        if pBuf is None:
            raise RuntimeError(
                'Could not read window ({0}, {1}) from {2}'.format(col, row, tif))
        if pBuf.ndim == 2:
            # A single-band raster is returned as a 2-D array; add the band
            # axis so that pBuf[i] is always one whole band.
            pBuf = pBuf.reshape((1,) + pBuf.shape)
        _write_tile(os.path.join(dirpath, dirname + ".tif"), pBuf, size,
                    tuple(newform), im_proj)

        # ************** create shp **********************
        corners = [(left, top), (right, top), (right, bottom), (left, bottom)]
        _write_outline_shp(os.path.join(dirpath, dirname + shp_suffix),
                           prosrs, geom_type, fieldName, corners)

        print("{0} ok".format(dirname))
        n = n + 1

    time2 = time.perf_counter()
    print('Process Running time: %s min' % ((time2 - time1) / 60))
    return n
def mkdir(path):
    """Create directory *path* if nothing exists there yet.

    Missing parent directories are created as well.  If the directory
    appears between the existence check and the creation (for example when
    another process creates it), that is not treated as an error.

    Args:
        path: Directory to create.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
if __name__ == "__main__":
    # Rasters to clip; sorted so that sample numbering is reproducible
    # between runs (glob returns files in arbitrary order).
    tifList = sorted(glob.glob('D:/weitu/download/heli*/Level18/*.tif'))  # image list
    outputdir = 'D:/2020/jf_project/hardshand3'  # output path
    mkdir(outputdir)
    sampletype = "poly"  # sample type (line or poly)
    size = 600  # sample size in pixels
    n = 1  # number of the first sample
    fieldName = 'cls'  # field name
    for tif in tifList:
        # The point shapefile is expected in the same folder as the raster.
        subRoot = os.path.split(tif)[0]
        shpList = sorted(glob.glob(f'{subRoot}/*.shp'))
        if not shpList:
            # Checked before indexing: taking [0] of an empty list would
            # raise IndexError before any explanatory message is shown.
            raise FileNotFoundError(f'check you shp file: no *.shp found in {subRoot}')
        shp = shpList[0]
        # sampleClip returns the next free number, keeping numbering
        # continuous across rasters.
        n = sampleClip(shp, tif, outputdir, sampletype, size, fieldName, n)
4. 效果预览