一、合并波段
from osgeo import gdal
import os
import numpy as np
def readRGBTif(fileName, im_width=6060, im_height=6060):
    """Read a window of a multi-band GeoTIFF and split out its first three bands.

    The window always starts at pixel (0, 0); its size defaults to the
    6060 x 6060 tile size this pipeline was written for.

    Args:
        fileName: Path of the raster to open with GDAL.
        im_width: Number of columns to read.
        im_height: Number of rows to read.

    Returns:
        ``None`` when GDAL cannot open the file (a message is printed).
        Otherwise the tuple ``(im_data, im_width, im_height, im_bands,
        im_geotrans, im_proj, im_dtype, im_blueBand, im_greenBand,
        im_redBand)``, where ``im_data`` is the array returned by
        ``ReadAsArray`` and the three band entries are slices of it.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        print(fileName+"文件无法打开")
        return None

    im_bands = dataset.RasterCount  # number of bands in the file
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection definition

    # This code labels array index 0/1/2 as blue/green/red; the file must
    # therefore contain at least three bands.
    im_blueBand = im_data[0, 0:im_height, 0:im_width]
    im_greenBand = im_data[1, 0:im_height, 0:im_width]
    im_redBand = im_data[2, 0:im_height, 0:im_width]

    im_dtype = im_data.dtype.name
    return (im_data, im_width, im_height, im_bands, im_geotrans, im_proj,
            im_dtype, im_blueBand, im_greenBand, im_redBand)
def readDEMTif(fileName, im_width=6060, im_height=6060):
    """Read a DEM raster into an array of exactly ``im_height`` x ``im_width`` pixels.

    The window starts at pixel (0, 0). When the raster is smaller than the
    requested window, the missing rows/columns are filled with zeros on the
    bottom/right so the result always has the requested spatial size. Only
    the two spatial axes are padded; a leading band axis, if present, is
    left untouched.

    Args:
        fileName: Path of the raster to open with GDAL.
        im_width: Number of columns in the returned array.
        im_height: Number of rows in the returned array.

    Returns:
        ``None`` when GDAL cannot open the file (a message is printed).
        Otherwise the tuple ``(im_data, im_width, im_height, im_bands,
        im_geotrans, im_proj, im_dtype, im_DEMBand)``. ``im_DEMBand`` is a
        view of ``im_data``, so in-place edits to one show up in the other.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        print(fileName+"文件无法打开")
        return None

    im_width_or = dataset.RasterXSize  # columns actually present in the file
    im_height_or = dataset.RasterYSize  # rows actually present in the file

    # Never request pixels beyond the raster extent; pad the shortfall instead.
    read_width = min(im_width, im_width_or)
    read_height = min(im_height, im_height_or)
    im_data = dataset.ReadAsArray(0, 0, read_width, read_height)

    pad_rows = im_height - read_height
    pad_cols = im_width - read_width
    if pad_rows or pad_cols:
        # Pad by the exact deficit and only along the last two (row, column)
        # axes, so a band axis is never padded.
        leading_axes = [(0, 0)] * (im_data.ndim - 2)
        im_data = np.pad(im_data, leading_axes + [(0, pad_rows), (0, pad_cols)], 'constant')

    im_bands = dataset.RasterCount  # number of bands in the file
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection definition
    im_DEMBand = im_data[..., 0:im_height, 0:im_width]
    im_dtype = im_data.dtype.name
    return (im_data, im_width, im_height, im_bands, im_geotrans, im_proj,
            im_dtype, im_DEMBand)
def writeTiff(im_data, im_width, im_height, im_bands, im_geotrans, im_proj, path):
    """Write a 2-D or 3-D array to a GeoTIFF file.

    Args:
        im_data: Array shaped ``(rows, cols)`` or ``(bands, rows, cols)``.
        im_width: Kept for backward compatibility; the size is taken from
            ``im_data`` itself.
        im_height: Kept for backward compatibility; see ``im_width``.
        im_bands: Kept for backward compatibility; see ``im_width``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection string passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data must not be stored as UInt16, or negative values
        # (common in elevation rasters) are lost.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 2:
        # Promote a single band to (1, rows, cols) so the write loop is uniform.
        im_data = im_data[np.newaxis, ...]
    elif im_data.ndim != 3:
        raise ValueError("im_data must be 2-D or 3-D, got %d-D" % im_data.ndim)
    # Always trust the array shape over the caller-supplied sizes.
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        print(path+"文件无法创建")
        return

    dataset.SetGeoTransform(im_geotrans)  # affine geotransform
    dataset.SetProjection(im_proj)  # projection definition
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    dataset.FlushCache()
    # Dropping the last reference closes the file.
    del dataset
def get_file_names(data_dir, file_type=('tif', 'tiff')):
    """Recursively collect files under ``data_dir`` whose extension is in ``file_type``.

    Args:
        data_dir: Directory to walk.
        file_type: One extension (a string such as ``'tif'``) or an iterable
            of extensions, given without the leading dot. Matching is exact
            and case-sensitive.

    Returns:
        A ``(paths, names)`` pair of parallel lists: ``paths`` holds
        ``directory + '/' + filename`` for every match, ``names`` holds the
        bare file names.
    """
    # A bare string must be treated as ONE extension; otherwise `in` would
    # do a substring test and 'tif' would also accept '.t', '.if' or ''.
    if isinstance(file_type, str):
        wanted = {file_type}
    else:
        wanted = set(file_type)

    result_dir = []
    result_name = []
    for maindir, _subdirs, file_name_list in os.walk(data_dir):
        for filename in file_name_list:
            _stem, dot, ext = filename.rpartition('.')
            if dot and ext in wanted:
                # Keep the '/' separator: callers split the path on '/'.
                result_dir.append(maindir + '/' + filename)
                result_name.append(filename)
    return result_dir, result_name
# Merge every RGB tile in in_dir1 with the same-named DEM tile in in_dir2
# into one Byte GeoTIFF in out_dir (DEM bands first, then blue/green/red).
in_dir1 = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/sat_train'
in_dir2 = '/mnt/sdb1/fenghaixia/dsm/dataset/all/sat_train'
out_dir = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp'
# A list rather than a bare string, so the extension is matched exactly
# instead of by substring.
file_type = ['tif']
data_dir_list1, _ = get_file_names(in_dir1, file_type)
data_dir_list2, _ = get_file_names(in_dir2, file_type)

# GDAL cannot create a file inside a directory that does not exist.
os.makedirs(out_dir, exist_ok=True)
driver = gdal.GetDriverByName("GTiff")

for each_dir in data_dir_list1:
    tile_name = each_dir.split('/')[-1]
    dem_path = in_dir2 + '/' + tile_name
    print(dem_path)
    if not os.path.exists(dem_path):
        continue

    rgb_result = readRGBTif(each_dir)
    dem_result = readDEMTif(dem_path)
    if rgb_result is None or dem_result is None:
        # The readers return None (after printing a message) when a file
        # cannot be opened; skip the pair instead of failing on unpacking.
        continue
    img1, width1, height1, bands1, geotrans1, proj1, dtype1, blueband, greenband, redband = rgb_result
    img2, width2, height2, bands2, geotrans2, proj2, dtype2, DEMband = dem_result

    # Zero out values above 100, then stretch the remaining 0..100 range
    # to 0..255 for the Byte output.
    DEMband[DEMband > 100] = 0
    DEMband = DEMband * 2.55
    print(each_dir)
    print(dtype1)

    each_out_dir = out_dir + '/' + tile_name
    print('each_out_dir: ', each_out_dir)
    new_dataset = driver.Create(each_out_dir, width1, height1, bands1 + bands2, gdal.GDT_Byte)
    print(type(new_dataset))
    if new_dataset is None:
        raise RuntimeError('GDAL could not create ' + each_out_dir)
    new_dataset.SetGeoTransform(geotrans1)
    new_dataset.SetProjection(proj1)

    # NOTE(review): indexing DEMband[0..2] requires the DEM array to carry at
    # least three entries on its first axis, i.e. presumably a 3-band DEM
    # raster -- confirm against the data.
    output_layers = (DEMband[0], DEMband[1], DEMband[2], blueband, greenband, redband)
    for band_number, layer in enumerate(output_layers, start=1):
        new_dataset.GetRasterBand(band_number).WriteArray(layer)
    new_dataset.FlushCache()
    # Dropping the last reference closes the file.
    del new_dataset
print('combine over')
二、裁剪
裁剪sat(将合并后的影像切成512×512的小块)
import os
from osgeo import gdal
# Cut every raster in inPath into non-overlapping 512 x 512 tiles and save
# them to outPath as '<name>(<row>,<col>)@<index>_sat.tif'.
inPath = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp/'
outPath= '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp2/'

# Tile size in pixels.
block_xsize = 512  # tile width (columns)
block_ysize = 512  # tile height (rows)

# GDAL cannot create a file inside a directory that does not exist.
os.makedirs(outPath, exist_ok=True)
gtif_driver = gdal.GetDriverByName("GTiff")

for f in os.listdir(inPath):
    # Open the source image to be tiled.
    imgPath = inPath + f.strip()
    in_ds = gdal.Open(imgPath)
    if in_ds is None:
        print(f.strip() + "open tif file failed")
        continue
    print(f.strip()+"open tif file succeed")

    width = in_ds.RasterXSize  # raster width (columns)
    height = in_ds.RasterYSize  # raster height (rows)
    outbandsize = in_ds.RasterCount  # number of bands
    datatype = in_ds.GetRasterBand(1).DataType  # output uses band 1's type

    k = 0  # running tile index within this image
    for j in range(height // block_ysize):  # tile row
        for i in range(width // block_xsize):  # tile column
            # NOTE: [:-4] assumes a four-character extension such as '.tif'.
            filename = outPath + f.strip()[:-4] + '({},{})@{:04d}_sat.tif'.format(j, i, k)
            k += 1
            out_ds = gtif_driver.Create(filename, block_xsize, block_ysize, outbandsize, datatype)
            if out_ds is None:
                raise RuntimeError('GDAL could not create ' + filename)
            # Copy every band, so the tile is fully populated no matter how
            # many bands the source has. ReadAsArray takes (xoff, yoff,
            # xsize, ysize): column offset first, then row offset.
            for band_index in range(1, outbandsize + 1):
                tile = in_ds.GetRasterBand(band_index).ReadAsArray(
                    i * block_xsize, j * block_ysize, block_xsize, block_ysize)
                out_ds.GetRasterBand(band_index).WriteArray(tile)
            # Flush to disk, then close the tile by dropping the reference.
            out_ds.FlushCache()
            del out_ds
            print("FlushCache succeed")
print("End!")
找出mask多出的图(在sat目录中没有同名影像的mask)
import os
import cv2
# source = 'dataset/sat_train/'
# Print the path of every mask file in pre_path that has no same-named
# file in real_path.
real_path ="/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp/"
pre_path ="/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/mask_train/"
# Lazily filtered directory listings: keep names that contain 'tif'.
real_names = (entry for entry in os.listdir(real_path) if 'tif' in entry)
pre_names = (entry for entry in os.listdir(pre_path) if 'tif' in entry)
for f in pre_names:
    mask_file = f.strip()
    if os.path.exists(real_path + mask_file):
        continue
    # Only reported here; call os.remove(pre_path + mask_file) at this point
    # to delete the unmatched mask instead.
    print(pre_path + mask_file)
裁剪mask(切成512×512的小块,并删除全黑的mask及其对应的sat小块)
import cv2
import os
# Cut every mask image in inPath2 into non-overlapping 512 x 512 tiles and
# save them to outPath as '<name>(<row>,<col>)@<index>_mask.png'.
inPath2 = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/mask_train/'
outPath= '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp2/'

# Tile size in pixels.
heightBlock = 512
widthBlock = 512

# cv2.imwrite only returns False when the target directory is missing.
os.makedirs(outPath, exist_ok=True)

for f in os.listdir(inPath2):
    path = inPath2 + f.strip()
    print(path)
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None rather
        # than raising.
        print(path + ' could not be read as an image, skipped')
        continue
    height, width = img.shape[:2]

    heightCutNum = height // heightBlock
    widthCutNum = width // widthBlock
    tile_index = 0  # running tile index within this image
    for i in range(heightCutNum):  # tile row
        for j in range(widthCutNum):  # tile column
            cutImage = img[i*heightBlock:(i+1)*heightBlock, j*widthBlock:(j+1)*widthBlock]
            # NOTE: [:-4] assumes a four-character extension such as '.tif'.
            savePath = outPath + f.strip()[:-4]+'({},{})@{:04d}_mask.png'.format(i, j, tile_index)
            tile_index += 1
            if not cv2.imwrite(savePath, cutImage):
                print(savePath + ' could not be written')
                continue
            print(savePath)
print("finish!")
# Delete every all-black mask tile in outPath together with the satellite
# tile that shares its name prefix.
mask_suffix = '_mask.png'
# Match on the exact suffix: a plain 'mask' substring test would also pick
# up any other file whose name happens to contain 'mask'.
mask_names = [name for name in os.listdir(outPath) if name.endswith(mask_suffix)]
for f in mask_names:
    path = outPath + f.strip()
    if not os.path.exists(path):
        continue
    img = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale
    if img is None:
        # cv2.imread signals an unreadable file by returning None.
        print(f + ' could not be read, skipped')
        continue
    if cv2.countNonZero(img) == 0:
        print(f+'Image is black')
        sat_path = outPath + f[:-len(mask_suffix)] + "_sat.tif"
        os.remove(path)
        # The paired satellite tile may be absent; removing it is best-effort.
        if os.path.exists(sat_path):
            os.remove(sat_path)
# NOTE(review): stand-alone fragment. `root` is not defined in the visible
# part of this file, and `id` resolves to the Python builtin unless unseen
# code rebinds it; presumably both come from the data-loader function this
# was copied from -- confirm before running.
# Open the '<id>_sat.tif' file located under `root` with GDAL.
in_ds = gdal.Open(os.path.join(root,'{}_sat.tif').format(id))
# Read the dataset into an array.
img = in_ds.ReadAsArray()
# Move axis 0 to the end; assumes a 3-D (bands, rows, cols) array, giving
# (rows, cols, bands) -- TODO confirm.
img=img.transpose(1,2,0)
三、测试