PyTorch DataLoader 读取 tif(完整代码)

Python 读取 tif 格式文件需要安装 libtiff,此外还需要安装 inferno。
本文适用于读取三维tif。

from torch.utils.data import DataLoader, Dataset
from inferno.io.transform.base import Transform, Compose
from inferno.io.transform.generic import Normalize, AsTorchBatch
from inferno.io.transform.image import RandomCrop, RandomRotate, RandomFlip 
from libtiff import TIFF
import os    
import torch
import numpy as np

定义一个MyDataSet

class MyDataSet(Dataset):
    """In-memory dataset of paired TIFF stacks (data and label).

    Every file in the two directories is read through ``tiff2Stack`` when the
    dataset is constructed; ``__getitem__`` only indexes the cached results.

    File names are ordered by the integer obtained after dropping their first
    three and last four characters (e.g. ``LR_20.tif`` -> ``20``), so data and
    label files are paired by that number's rank in each directory.

    NOTE(review): ``transform`` is invoked separately for each data file and
    each label file, once, at construction time. If it contains random
    operations, a data stack and its label receive independent calls --
    confirm this is intended.
    """

    def __init__(self, pathLst, transform):
        """Load all stacks found under ``pathLst = (data_dir, label_dir)``.

        Args:
            pathLst: Two-element sequence: data directory, label directory.
            transform: Forwarded to ``tiff2Stack`` for every file.
        """
        data_dir, label_dir = pathLst
        self.tifStreamData, self.tifStreamLabel = [], []
        # Both directories are listed before any file is read.
        data_names = os.listdir(data_dir)
        label_names = os.listdir(label_dir)

        def numeric_index(name):
            # "LR_20.tif" -> 20: strip a 3-char prefix and a 4-char suffix.
            return int(name[3:-4])

        sources = (
            (data_dir, data_names, self.tifStreamData),
            (label_dir, label_names, self.tifStreamLabel),
        )
        for folder, names, stacks in sources:
            for name in sorted(names, key=numeric_index):
                stacks.append(tiff2Stack(os.path.join(folder, name), transform))

        # Data and labels are paired by position, so the counts must agree.
        assert len(self.tifStreamData) == len(self.tifStreamLabel)

    def __len__(self):
        """Return the number of (data, label) pairs."""
        return len(self.tifStreamData)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair stored at position ``idx``."""
        return self.tifStreamData[idx], self.tifStreamLabel[idx]
         
def tiff2Stack(fileName, transform=None):
    """Read a multi-page TIFF into a float stack and optionally transform it.

    Args:
        fileName: Path of the TIFF file to open.
        transform: Optional callable applied to the stacked array. When it is
            truthy, its return value is what this function returns.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(pages, height, width)``, where
        height and width come from the first page, or the result of
        ``transform`` on that array.

    Raises:
        ValueError: If the file contains no images.
    """
    tif = TIFF.open(fileName, mode='r')
    try:
        # Decode the pages exactly once and reuse the list below.
        pages = list(tif.iter_images())
    finally:
        # The pages are already in memory, so the handle can be released.
        tif.close()

    if not pages:
        raise ValueError("no images found in TIFF file: %s" % fileName)

    height, width = pages[0].shape[0], pages[0].shape[1]
    # np.zeros defaults to float64; assigning integer pages into it performs
    # the int -> float conversion, which avoids torch's
    # "can't convert np.ndarray of type numpy.uint16." error downstream.
    tifArr = np.zeros((len(pages), height, width))
    for i, page in enumerate(pages):
        tifArr[i, :, :] = page

    if transform:
        tifArr = transform(tifArr)
    return tifArr

调用

def main():
    """Build the dataset and loader, then print the shape of every batch."""
    augmentations = Compose(
        RandomRotate(),
        RandomFlip(),
        Normalize(),
        AsTorchBatch(2),
    )
    # Placeholder directories: first the data TIFFs, then the label TIFFs.
    directories = ["/your/tif/image/Data/path/", "/your/tif/image/Label/path/"]
    dataset = MyDataSet(directories, transform=augmentations)
    loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

    for step, batch in enumerate(loader):
        print(step)
        data, label = batch
        print("data.shape",data.shape,"label.shape",label.shape)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

猜你喜欢

转载自blog.csdn.net/qq_36937684/article/details/110357321