仅个人记录:RGB-D 数据的处理与运行流程

 一、合并波段

from osgeo import gdal
import os
import numpy as np
def readRGBTif(fileName, im_width=6060, im_height=6060):
    """Read the top-left window of a multi-band GeoTIFF and split out three bands.

    Args:
        fileName: Path of the raster to open with GDAL.
        im_width: Number of columns to read, starting at column 0.  Defaults
            to 6060, the window size the rest of this script works with.
        im_height: Number of rows to read, starting at row 0.  Defaults to
            6060.

    Returns:
        ``None`` if GDAL cannot open the file (a message is printed first).
        Otherwise a 10-tuple ``(im_data, im_width, im_height, im_bands,
        im_geotrans, im_proj, im_dtype, im_blueBand, im_greenBand,
        im_redBand)`` where ``im_data`` is the array returned by
        ``ReadAsArray`` and the three band entries are its indices 0, 1 and 2
        along the first axis (treated by this script as blue, green, red).

    Note:
        The array is indexed as ``im_data[band, row, col]``, so the raster
        must have at least three bands and be at least as large as the
        requested window.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        print(fileName+"文件无法打开")
        return None

    im_bands = dataset.RasterCount  # number of bands in the file
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection definition

    im_blueBand = im_data[0, 0:im_height, 0:im_width]
    im_greenBand = im_data[1, 0:im_height, 0:im_width]
    im_redBand = im_data[2, 0:im_height, 0:im_width]

    im_dtype = im_data.dtype.name
    return (im_data, im_width, im_height, im_bands, im_geotrans, im_proj,
            im_dtype, im_blueBand, im_greenBand, im_redBand)
 
def readDEMTif(fileName, im_width=6060, im_height=6060):
    """Read a DEM raster and force it to an ``im_height`` x ``im_width`` window.

    The window starts at the top-left corner.  If the raster is smaller than
    the requested window in either direction, the missing rows/columns are
    filled with zeros at the bottom/right; if it is larger, it is cropped.

    Args:
        fileName: Path of the raster to open with GDAL.
        im_width: Target number of columns.  Defaults to 6060.
        im_height: Target number of rows.  Defaults to 6060.

    Returns:
        ``None`` if GDAL cannot open the file (a message is printed first).
        Otherwise an 8-tuple ``(im_data, im_width, im_height, im_bands,
        im_geotrans, im_proj, im_dtype, im_DEMBand)``.  ``im_DEMBand`` is a
        view of ``im_data``, so in-place edits to one are visible in the other.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        print(fileName+"文件无法打开")
        return None

    im_width_or = dataset.RasterXSize  # columns actually present in the file
    im_height_or = dataset.RasterYSize  # rows actually present in the file

    # Never ask GDAL for more pixels than the file holds.
    read_width = min(im_width_or, im_width)
    read_height = min(im_height_or, im_height)
    im_data = dataset.ReadAsArray(0, 0, read_width, read_height)

    # Pad only the two trailing (row, column) axes, and by exactly the
    # shortfall.  Padding every axis by a fixed amount would also grow the
    # band axis of a multi-band raster and would not reach the target size
    # for rasters that are more than that amount too small.
    pad_rows = im_height - read_height
    pad_cols = im_width - read_width
    if pad_rows or pad_cols:
        pad_width = [(0, 0)] * (im_data.ndim - 2)
        pad_width += [(0, pad_rows), (0, pad_cols)]
        im_data = np.pad(im_data, pad_width, 'constant')

    im_bands = dataset.RasterCount  # number of bands in the file
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()  # projection definition
    # DEM data restricted to the target window (all bands kept, if several).
    im_DEMBand = im_data[..., 0:im_height, 0:im_width]

    im_dtype = im_data.dtype.name
    return (im_data, im_width, im_height, im_bands, im_geotrans, im_proj,
            im_dtype, im_DEMBand)
 
def writeTiff(im_data,im_width,im_height,im_bands,im_geotrans,im_proj,path):
    """Write an array to a GeoTIFF file.

    Args:
        im_data: 2-D ``(rows, cols)`` or 3-D ``(bands, rows, cols)`` NumPy
            array to write.
        im_width: Kept for interface compatibility; the width is taken from
            ``im_data``.
        im_height: Kept for interface compatibility; the height is taken from
            ``im_data``.
        im_bands: Kept for interface compatibility; the band count is taken
            from ``im_data``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_proj: Projection passed to ``SetProjection``.
        path: Output file path.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If the GTiff driver cannot create ``path``.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed 16-bit data must not be stored as UInt16, otherwise negative
        # values (common in elevation data) are lost.
        datatype = gdal.GDT_Int16
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Normalise to (bands, rows, cols) and derive the sizes from the data so
    # that they can never disagree with what is actually written.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, ...]
    if im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got shape {}".format(im_data.shape))
    im_bands, im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create {}".format(path))

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for i in range(im_bands):
        dataset.GetRasterBand(i+1).WriteArray(im_data[i])
    # Dropping the last reference makes GDAL flush and close the file.
    del dataset
 
def get_file_names(data_dir, file_type=('tif', 'tiff')):
    """Recursively collect files under ``data_dir`` with a given extension.

    Args:
        data_dir: Directory to walk (sub-directories included).
        file_type: Accepted extensions without the leading dot.  Either a
            collection such as ``['tif', 'tiff']`` or a single string such as
            ``'tif'``.  Matching is case-sensitive.

    Returns:
        A pair ``(result_dir, result_name)`` of parallel lists: the paths
        (directory and file name joined with ``'/'``) and the bare file names.
    """
    # A bare string would turn ``ext in file_type`` into a substring test
    # (so 't', 'if' or '' would match 'tif'); wrap it to compare whole
    # extensions only.
    if isinstance(file_type, str):
        file_type = (file_type,)

    result_dir = []
    result_name = []
    for maindir, _subdirs, file_name_list in os.walk(data_dir):
        for filename in file_name_list:
            _stem, dot, ext = filename.rpartition('.')
            if dot and ext in file_type:
                result_dir.append(maindir + '/' + filename)
                result_name.append(filename)
    return result_dir, result_name
 
 
# Merge step: for every image in in_dir1 that has a same-named DEM in in_dir2,
# write a 6-band GeoTIFF to out_dir (DEM bands first, then the image bands).
in_dir1 = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/sat_train'
in_dir2 = '/mnt/sdb1/fenghaixia/dsm/dataset/all/sat_train'
out_dir = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp'

# Pass the extension as a list: with the bare string 'tif' the membership test
# inside get_file_names is a substring test and also accepts extensions such
# as 't' or 'if'.
file_type = ['tif']
data_dir_list1,_ = get_file_names(in_dir1, file_type)
data_dir_list2,_ = get_file_names(in_dir2, file_type)  # not used below

driver = gdal.GetDriverByName("GTiff")

for each_dir in data_dir_list1:
    base_name = each_dir.split('/')[-1]
    dem_path = in_dir2 + '/' + base_name
    print(dem_path)
    if not os.path.exists(dem_path):
        continue

    # Both readers print a message and return None when GDAL cannot open the
    # file; skip the pair instead of failing on tuple unpacking.
    rgb_result = readRGBTif(each_dir)
    dem_result = readDEMTif(dem_path)
    if rgb_result is None or dem_result is None:
        continue
    img1, width1, height1, bands1, geotrans1, proj1, dtype1, blueband, greenband, redband = rgb_result
    img2, width2, height2, bands2, geotrans2, proj2, dtype2, DEMband = dem_result

    # Zero out values above 100, then scale so that 0..100 maps to 0..255.
    DEMband[DEMband>100]=0
    DEMband=DEMband*2.55
    print(each_dir)
    print(dtype1)

    each_out_dir = out_dir + '/' + base_name
    print('each_out_dir: ', each_out_dir)
    # The output is always written as 8-bit, whatever the input dtype is.
    new_dataset = driver.Create(each_out_dir, width1, height1, bands1+bands2, gdal.GDT_Byte)
    print(type(new_dataset))
    if new_dataset is None:
        raise RuntimeError('GDAL could not create ' + each_out_dir)
    new_dataset.SetGeoTransform(geotrans1)
    new_dataset.SetProjection(proj1)

    # NOTE(review): assumes both inputs have three bands (bands1 + bands2 == 6)
    # and that DEMband is indexed by band first -- confirm against the data.
    new_dataset.GetRasterBand(1).WriteArray(DEMband[0])
    new_dataset.GetRasterBand(2).WriteArray(DEMband[1])
    new_dataset.GetRasterBand(3).WriteArray(DEMband[2])
    new_dataset.GetRasterBand(4).WriteArray(blueband)
    new_dataset.GetRasterBand(5).WriteArray(greenband)
    new_dataset.GetRasterBand(6).WriteArray(redband)
    new_dataset.FlushCache()
    del new_dataset
    print('combine over')

二、裁剪

裁剪 sat(将合并后的影像切成 512×512 的小块)



 
import os
from osgeo import gdal
 
# Tiling step: cut every raster in inPath into non-overlapping 512x512 tiles
# and write them to outPath as '<name>(<row>,<col>)@<index>_sat.tif'.
inPath = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp/'
outPath=  '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp2/'

# Tile size in pixels.
block_xsize = 512  # tile width (columns)
block_ysize = 512  # tile height (rows)

gtif_driver = gdal.GetDriverByName("GTiff")

for f in os.listdir(inPath):
    # Open the raster that is going to be tiled.
    imgPath=inPath+f.strip()
    in_ds = gdal.Open(imgPath)
    if in_ds is None:
        # os.listdir returns every entry, not only rasters GDAL can open.
        print(f.strip()+" could not be opened, skipped")
        continue
    print(f.strip()+"open tif file succeed")
    width = in_ds.RasterXSize  # raster width in pixels
    height = in_ds.RasterYSize  # raster height in pixels
    outbandsize = in_ds.RasterCount  # number of bands
    datatype = in_ds.GetRasterBand(1).DataType

    # Every band of the input is copied, so the tiles carry exactly the bands
    # they are created with instead of leaving some of them empty.
    in_bands = [in_ds.GetRasterBand(b + 1) for b in range(outbandsize)]

    k = 0
    # j walks tile rows (y offset), i walks tile columns (x offset).  The row
    # count comes from the height and the column count from the width, so
    # non-square rasters are tiled correctly.
    for j in range(height // block_ysize):
        for i in range(width // block_xsize):
            filename = outPath + f.strip()[:-4] + '({},{})@{:04d}_sat.tif'.format(j, i, k)
            k += 1
            # Same band count and data type as the source raster.
            out_ds = gtif_driver.Create(filename, block_xsize, block_ysize, outbandsize, datatype)

            for band_number, in_band in enumerate(in_bands, start=1):
                tile = in_band.ReadAsArray(i * block_xsize, j * block_ysize, block_xsize, block_ysize)
                out_ds.GetRasterBand(band_number).WriteArray(tile)

            # Flush the tile to disk and close it.
            out_ds.FlushCache()
            del out_ds
    print("FlushCache succeed")
print("End!")

找出多余的 mask(列出在合并结果目录 tmp 中没有同名影像的 mask 文件)

import os
import cv2
# source = 'dataset/sat_train/'
# Report every mask file that has no same-named image in real_path.
# Nothing is deleted here; the paths are only printed.
real_path ="/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp/"
pre_path ="/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/mask_train/"

# Lazily keep only the entries whose name contains 'tif'.
real_names = (entry for entry in os.listdir(real_path) if 'tif' in entry)
pre_names = (entry for entry in os.listdir(pre_path) if 'tif' in entry)

for f in pre_names:
    real_name = real_path + f.strip()
    if os.path.exists(real_name):
        continue
    print(pre_path + f.strip())

裁剪 mask(切成 512×512 的小块,并删除全黑的 mask 小块及其对应的 sat 小块)


import cv2
import os
 
# Cutting the input image to h*w blocks
 
# Tiling step for masks: cut every image in inPath2 into non-overlapping
# 512x512 blocks and save them to outPath as
# '<name>(<row>,<col>)@<index>_mask.png'.
inPath2 = '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/mask_train/'
outPath=  '/mnt/sdb1/fenghaixia/DeepGlobe-Road-Extraction-link34-py3_test_all/dataset/a/tmp2/'

# Size of the blocks to cut, in pixels.
heightBlock = 512
widthBlock = 512

for f in os.listdir(inPath2):
    path = inPath2 + f.strip()
    print(path)
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None rather than
        # raising; skip such entries instead of failing on img.shape.
        print(path + ' could not be read as an image, skipped')
        continue
    height, width = img.shape[:2]
    heightCutNum = height // heightBlock
    widthCutNum = width // widthBlock

    tile_index = 0
    for i in range(heightCutNum):
        for j in range(widthCutNum):
            cutImage = img[i*heightBlock:(i+1)*heightBlock, j*widthBlock:(j+1)*widthBlock]
            savePath = outPath + f.strip()[:-4]+'({},{})@{:04d}_mask.png'.format(i, j, tile_index)
            tile_index += 1
            cv2.imwrite(savePath,cutImage)
            print(savePath)
print("finish!")
 
# Cleanup step: delete every all-black mask tile in outPath together with the
# sat tile that shares its name prefix.
mask_suffix = '_mask.png'
# Only names ending in the tile suffix are considered, because the sat tile
# name is derived by cutting exactly that suffix off the mask name.
mask_names = [name for name in os.listdir(outPath) if name.endswith(mask_suffix)]
for f in mask_names:
    path = outPath + f.strip()
    if not os.path.exists(path):
        continue
    img = cv2.imread(path,0)  # flag 0: load as single-channel grayscale
    if img is None:
        # Unreadable file: leave it in place rather than crash or delete it.
        print(f+' could not be read, skipped')
        continue
    if cv2.countNonZero(img) == 0:
        print(f+'Image is black')
        path2=f[:-len(mask_suffix)]
        os.remove(path)
        sat_path = outPath + path2 + "_sat.tif"
        # The matching sat tile may not exist; a missing one must not abort
        # the cleanup of the remaining tiles.
        if os.path.exists(sat_path):
            os.remove(sat_path)

 

    in_ds = gdal.Open(os.path.join(root,'{}_sat.tif').format(id))
    img = in_ds.ReadAsArray()
    img=img.transpose(1,2,0)

 三、测试

猜你喜欢

转载自blog.csdn.net/weixin_61235989/article/details/130203279