Medical AI Lab (Part 1)

A first look at PyTorch/CUDA; the project is implemented as a Jupyter notebook.

The lab consists of four parts in total; starting from PyTorch, it works step by step toward deep learning training on medical images.

Reference GitHub: https://github.com/sixitingting/PyTorchMedicalAI

You can download the whole repository and study it. This post covers the shlomo_dl_0001_cuda_collab_pytorch.ipynb part. Project overview: this lab starts from inspecting CUDA/GPU information, installing PyTorch, loading data with PyTorch, and data augmentation. It is an introductory lesson. If you hit an installation step you cannot figure out, feel free to skip it rather than waste too much time.

----------------------------------------------- Main content ------------------------------------------------

Contents

DataSets:

Using C/CUDA code from Python

Use 'nvidia-smi' to watch GPU activity while running deep learning tasks. For this command and the one above to work, go to 'Runtime > Change runtime type > Hardware Accelerator > GPU'.

Memory footprint support libraries/code

002 PyTorch DataLoaders

The DataSet class ... PyTorch Datasets

Download the data

Export your Kaggle API Key

Now we will use the PyTorch DataSet, NOT a simple image read

Augmentation / Transformations

Data Augmenting using PyTorch is quite easy.

Test the data loader

A simple Custom Data Augmentation


Author:

Shlomo Kashani, Head of AI at www.DeepOncology.AI, [email protected]

Synopsis:

This is the hands-on deep learning tutorial series for the 2018/2019 Medical AI course. The series will guide you through the most basic building blocks, from installing CUDA to training advanced CNNs such as SeNet.

DataSets:

We foster the use of Medical Data Sets (https://grand-challenge.org/All_Challenges/), predominantly (but not only) those available via Kaggle.

About PyTorch:

PyTorch is an open source library for numerical computation using computation graphs. Nodes in the graph represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) communicated between them.

Similar to python programming, we can add and execute a node to the computation graph immediately. This property makes it easy to debug the code and inspect the values in the network.
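A tiny illustration of this eager behaviour (a sketch added for this write-up, not from the notebook): every operation runs as soon as it is written, and the intermediate tensor can be printed right away.

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
y = x * 2 + 1            # executes immediately, no session or compile step
print(y)                 # the intermediate tensor can be inspected on the spot
print(y.mean().item())   # and reduced to a plain Python number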

The other Notebooks in this series:

If you are familiar with these topics, feel free to jump to other modules.

Using C/CUDA code from Python

#@title Sample Header for Jupyter HW submission
#@markdown Forms support many types of fields.

your_kaggle_name = 'DeepOncology___'  #@param {type: "string"}
select_dataset = "seeds" #@param ["seeds", "dsb2018", "dsb2019"] {allow-input: true}
kaggle_acc_score = 77  # @param {type: "slider", min: 0, max: 100}
kaggle_log_loss_score = 75  # @param {type: "slider", min: 0, max: 100}
kaggle_IOU_score = 75  # @param {type: "slider", min: 0, max: 100}
date = '2010-11-05'  #@param {type: "date"}
pick_me = "monday"  #@param ['monday', 'tuesday', 'wednesday', 'thursday']

#@markdown ---

Run the code one cell at a time, since this is a Jupyter notebook.

%reset -f
# Do we have cuda?!
!which nvcc  
!nvcc --version

These are originally terminal commands; to run them in Jupyter you just prepend an exclamation mark. Neat!

/usr/local/cuda-9.2/bin/nvcc
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Tue_Jun_12_23:07:04_CDT_2018
Cuda compilation tools, release 9.2, V9.2.148
# Let's check if a GPU accelerator card is attached in our machine:
!ls -l /dev/nv*

For English words you don't know, Youdao Translate's word lookup is very convenient, and it translates these accurately.

crw-rw-rw- 1 root root 195,   0 Jan  2 17:14 /dev/nvidia0
crw-rw-rw- 1 root root 195,   1 Jan  2 17:14 /dev/nvidia1
crw-rw-rw- 1 root root 195, 255 Jan  2 17:14 /dev/nvidiactl
crw-rw-rw- 1 root root 243,   0 Jan  2 17:14 /dev/nvidia-uvm
!nvidia-smi -L
Your GPU list:
GPU 0: TITAN X (Pascal) (UUID: GPU-06cc647a-2c75-cece-58c0-eb8ca1196991)
GPU 1: TITAN X (Pascal) (UUID: GPU-6b92c2ec-dd63-9b17-fb35-c4512ca64b08)

Use 'nvidia-smi' to watch GPU activity while running deep learning tasks. For this command and the one above to work, go to 'Runtime > Change runtime type > Hardware Accelerator > GPU'.

! nvidia-smi

!lscpu |grep 'Model name'
Model name:            Intel(R) Xeon(R) CPU E5-2680 v4 @ 2.40GHz
#no.of threads each core is having
!lscpu | grep 'Thread(s) per core'
Thread(s) per core:    1
#memory that we can use
!cat /proc/meminfo | grep 'MemAvailable'
MemAvailable:   80832892 kB
#hard disk that we can use
!df -h / | awk '{print $4}'
Avail
30G
import os
os.environ['PATH'] += ':/usr/local/cuda/bin'

# This magic will create a new CU file
# To write a CUDA C program, we need to:

#Create a source code file with the special file name extension of .cu.
#Compile the program using the CUDA nvcc compiler.
#Run the executable file from the command line, which contains the kernel code executable on the GPU.

%%file version.cu
#include <thrust/version.h>
#include <iostream>

int main(void)
{
  int major = THRUST_MAJOR_VERSION;
  int minor = THRUST_MINOR_VERSION;

  std::cout << "Thrust v" << major << "." << minor << std::endl;

  return 0;
}

# nvcc is the CUDA compiler 
!nvcc version.cu -o version
!./version

This part is nvcc programming; I couldn't get it to work, so I simply skipped it.
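If that cell fails for you as well, one quick sanity check (a sketch, assuming the toolkit lives under /usr/local/cuda as above) is to confirm the compiler is actually visible from the notebook's environment before trying to compile anything:

import os, shutil, subprocess

os.environ['PATH'] += ':/usr/local/cuda/bin'   # same PATH tweak as earlier
print(shutil.which('nvcc'))                    # None means nvcc is not reachable from here
if shutil.which('nvcc'):
    print(subprocess.check_output(['nvcc', '--version'], universal_newlines=True))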

Memory footprint support libraries/code

!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install gputil
!pip install psutil
!pip install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
 process = psutil.Process(os.getpid())
 print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " I Proc size: " + humanize.naturalsize( process.memory_info().rss))
 print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm()
Requirement already satisfied: gputil in /root/anaconda3/lib/python3.6/site-packages (1.4.0)
Requirement already satisfied: psutil in /root/anaconda3/lib/python3.6/site-packages (5.4.5)
Requirement already satisfied: humanize in /root/anaconda3/lib/python3.6/site-packages (0.5.1)
[]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-22-7434c070a6a4> in <module>()
     10 print(GPUs)
     11 # XXX: only one GPU on Colab and isn’t guaranteed
---> 12 gpu = GPUs[0]
     13 def printm():
     14  process = psutil.Process(os.getpid())

IndexError: list index out of range

My GPU list came back empty, which is why the code after it raised an error.
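A defensive version of that check (a small sketch, not part of the original notebook) avoids the IndexError when no GPU is visible:

import GPUtil as GPU

gpus = GPU.getGPUs()
if not gpus:
    print("No GPU visible to GPUtil; skipping the GPU memory report.")
else:
    gpu = gpus[0]
    print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(
        gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil * 100, gpu.memoryTotal))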

002 PyTorch DataLoaders

We are going to create a data loader for a neural network to classify images.
We are NOT going to run the neural network; just to load data.

import sys
sys.version
'3.6.5 |Anaconda, Inc.| (default, Apr 29 2018, 16:14:56) \n[GCC 7.2.0]'
import torch

If you don't have it, install it:

# !pip3 install torch==0.4
# !pip3 install torchvision

!pip3 install 'torch==0.4.0'
!pip3 install 'torchvision==0.2.1'
!pip3 install --no-cache-dir -I 'pillow==5.1.0'

# Restart Kernel
# This workaround is needed to properly upgrade PIL on Google Colab.
import os
os._exit(00)

If the installation is slow, you can download the packages from the links and install them with pip in a terminal. I installed them from the terminal myself.

Import PyTorch once again

import matplotlib.pyplot as plt
import time
from shutil import copyfile
from os.path import isfile, join, abspath, exists, isdir, expanduser
from os import listdir, makedirs, getcwd, remove
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid
import pandas as pd
import numpy as np
import torch
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as func
import torchvision
from torchvision import transforms, datasets, models
import random 

Let's print the versions

import sys
print('__Python VERSION:', sys.version)
print('__pyTorch VERSION:', torch.__version__)
print('__CUDA VERSION')
from subprocess import call
# call(["nvcc", "--version"]) does not work
! nvcc --version
print('__CUDNN VERSION:', torch.backends.cudnn.version())
print('__Number CUDA Devices:', torch.cuda.device_count())
print('__Devices')
# call(["nvidia-smi", "--format=csv", "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"])
print('Active CUDA Device: GPU', torch.cuda.current_device())

print ('Available devices ', torch.cuda.device_count())
print ('Current cuda device ', torch.cuda.current_device())

use_cuda = torch.cuda.is_available()
# use_cuda = False

print("USE CUDA=" + str (use_cuda))
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
Tensor = FloatTensor
__Python VERSION: 3.6.5 |Anaconda, Inc.| (default, Apr 29 2018, 16:14:56) 
[GCC 7.2.0]
__pyTorch VERSION: 0.4.0
__CUDA VERSION
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Tue_Jun_12_23:07:04_CDT_2018
Cuda compilation tools, release 9.2, V9.2.148
__CUDNN VERSION: 7102
__Number CUDA Devices: 2
__Devices
Active CUDA Device: GPU 0
Available devices  2
Current cuda device  0
USE CUDA=True
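An equivalent, slightly more compact way to express the same switch (a sketch; torch.device is available from PyTorch 0.4 onward) is to keep a single device object and move tensors to it with .to(device):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.zeros(3, 3).to(device)   # lands on the GPU when one is available
print(x.device)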

Fixing the Random Seed

manualSeed = 2222
def fixSeed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


if manualSeed is None:
    manualSeed = 999
fixSeed(manualSeed)

The DataSet class ... PyTorch Datasets

To create a dataset, we subclass Dataset and define a constructor, a __len__ method, and a __getitem__ method. A minimal sketch of that contract is shown below, followed by the full example used in this notebook:
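A minimal sketch first (the ToyDataset name and its in-memory samples are made up for illustration):

from torch.utils.data import Dataset

class ToyDataset(Dataset):
    def __init__(self, samples):      # constructor: keep a reference to the data
        self.samples = samples        # e.g. a list of (feature, label) pairs

    def __len__(self):                # total number of samples
        return len(self.samples)

    def __getitem__(self, idx):       # return one (feature, label) pair by index
        return self.samples[idx]

toy = ToyDataset([(0.1, 0), (0.7, 1), (0.4, 0)])
print(len(toy), toy[1])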


import os
import sys
import time
import random
import fnmatch
from glob import glob
from shutil import copyfile
from os import listdir, makedirs, getcwd, remove
from os.path import isfile, join, abspath, exists, isdir, expanduser

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from PIL import Image

import torch
import torch.nn.functional as func
import torchvision
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

class GenericDataset(Dataset):
  def __init__(self, labels, root_dir, subset=False, transform=None):
    self.labels = labels
    self.root_dir = root_dir
    self.transform = transform

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    img_name = self.labels.iloc[idx, 0]  # file name: contents of row idx, column 0
    fullname = join(self.root_dir, img_name)
    image = Image.open(fullname).convert('RGB')
    labels = self.labels.iloc[idx, 2]  # category_id
    #         print (labels)
    if self.transform:
      image = self.transform(image)
    return image, labels

  @staticmethod
  def find_classes(fullDir):
    classes = [d for d in os.listdir(fullDir) if os.path.isdir(os.path.join(fullDir, d))]
    # for the seedlings data used later, classes = Black-grass, Charlock, Cleavers, and so on
    classes.sort()  # sort alphabetically
    class_to_idx = {classes[i]: i for i in range(len(classes))}  # map each class name to an index 0, 1, 2, ...
    num_to_class = dict(zip(range(len(classes)), classes))

    train = []
    for index, label in enumerate(classes):
      path = fullDir + label + '/'
      for file in listdir(path):
        train.append(['{}/{}'.format(label, file), label, index])

    df = pd.DataFrame(train, columns=['file', 'category', 'category_id', ])

    return classes, class_to_idx, num_to_class, df

Download the data

Kaggle 162 H&E

This dataset consists of 5547 breast histology images of size 50 x 50 x 3, curated from Andrew Janowczyk's website and used for a data science tutorial at Epidemium. The goal is to classify cancerous images (IDC: invasive ductal carcinoma) vs non-IDC images.

Download:

https://www.kaggle.com/simjeg/lymphoma-subtype-classification-fl-vs-cll or http://andrewjanowczyk.com/wp-static/IDC_regular_ps50_idx5.zip

Download:

https://www.kaggle.com/c/plant-seedlings-classification

kaggle competitions download -c plant-seedlings-classification

data_dir = '/home/data/bone/train/'

If this doesn't make sense yet, just keep going and download the data.

# kaggle competitions download -c plant-seedlings-classification
!wget http://andrewjanowczyk.com/wp-static/IDC_regular_ps50_idx5.zip

This dataset gets downloaded but is not actually used later on.

Export your Kaggle API Key

!mkdir -p ~/.kaggle
%%file ~/.kaggle/kaggle.json
  
{"username":"solomonk","key":"af21d853c5e242e7d4c3e0e6a588309b"}
Writing /root/.kaggle/kaggle.json
# from google.colab import drive
# drive.mount('/content/gdrive')
!pip install kaggle --upgrade
!kaggle competitions download -c plant-seedlings-classification

We use this seedlings dataset; you can download it from Kaggle yourself. As long as you have the following files, you are set:

sample_submission.csv.zip
test.zip
train.zip
plant-seedlings-classification
! ls -la
total 3377220
drwxr-xr-x 6 root root      12288 Jan  5 13:03 .
drwxr-xr-x 3 root root       4096 Jan  3 20:33 ..
drwxr-xr-x 2 root root       4096 Jan  3 20:10 assets
-rw-r--r-- 1 root root 1644892042 Jan  5 12:28 IDC_regular_ps50_idx5.zip
drwxr-xr-x 2 root root       4096 Jan  3 20:10 .idea
drwxr-xr-x 2 root root       4096 Jan  3 20:10 .ipynb_checkpoints
-rw-r--r-- 1 root root    1045244 Jan  5 12:31 plant-seedlings-classification
-rw-r--r-- 1 root root       6920 Jan  3 20:10 README.md
-rw-r--r-- 1 root root       5251 Jan  5 12:28 sample_submission.csv.zip
drwxr-xr-x 2 root root       4096 Jan  5 12:32 seeds
-rw-r--r-- 1 root root    1012909 Jan  5 13:03 shlomo_dl_0001_cuda_collab_pytorch.ipynb
-rw-r--r-- 1 root root      19454 Jan  3 20:10 shlomo_dl_0002_tensors_collab_pytorch.ipynb
-rw-r--r-- 1 root root      57193 Jan  3 20:10 shlomo_dl_0003_pytorch_ffn_mnist.ipynb
-rw-r--r-- 1 root root     376106 Jan  3 20:10 shlomo_dl_0006_pretrained_cnn_collab_pytorch.ipynb
-rw-r--r-- 1 root root   88080384 Jan  5 12:28 test.zip
-rw-r--r-- 1 root root 1718530069 Jan  5 13:01 train.zip
-rw-r--r-- 1 root root    4185119 Jan  5 02:20 wget-log
!mkdir seeds
!mv train.zip seeds/train.zip
!unzip seeds/train.zip 
! ls seeds
train.zip
!ls train
Black-grass  Common Chickweed  Loose Silky-bent   Shepherds Purse
Charlock     Common wheat      Maize		  Small-flowered Cranesbill
Cleavers     Fat Hen	       Scentless Mayweed  Sugar beet

You can open the folder to check whether your files are all there.

import os
from glob import glob
from matplotlib.pyplot import imshow
import numpy as np
from PIL import Image

dataset='train/' # 
data_dir= './' +  dataset
# !pip install --no-cache-dir -I pillow
# !pip install Pillow==4.0.0
# !pip install PIL
# !pip install image

import matplotlib.pyplot as plt
%matplotlib inline

# import matplotlib as mpl
# mpl.rcParams['axes.grid'] = False
# mpl.rcParams['image.interpolation'] = 'nearest'
# mpl.rcParams['figure.figsize'] = 15, 25

imageList = glob(data_dir + '/**/*.png', recursive=True)  # match files recursively; ** is a wildcard; returns a list
print("Number of images: {}".format(len(imageList)))
for img in imageList[0:5]:
    print(img)
    
%matplotlib inline
pil_im = Image.open(imageList[56], 'r')
imshow(np.asarray(pil_im))
Number of images: 4750
./train/Loose Silky-bent/5a60a6eb4.png
./train/Loose Silky-bent/5dbd18569.png
./train/Loose Silky-bent/eafe89ea6.png
./train/Loose Silky-bent/b39cf3ed0.png
./train/Loose Silky-bent/8c796e67b.png

Out[11]:

<matplotlib.image.AxesImage at 0x7f70a7de37f0>

i_ = 0
plt.rcParams['figure.figsize'] = (10.0, 10.0)
plt.subplots_adjust(wspace=0, hspace=0)
for l in imageList[:15]:
    pil_im = Image.open(l, 'r')        
    plt.subplot(5, 5, i_+1) #.set_title(l)
    plt.imshow(np.asarray(pil_im)); 
    plt.axis('off')
    i_ += 1

Now we will use the PyTorch DataSet, NOT a simple image read

classes, class_to_idx, num_to_class, df = GenericDataset.find_classes(data_dir)

print (classes)
print (class_to_idx)
print (num_to_class)
df.head(5)
['Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common wheat', 'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet']
{'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}
{0: 'Black-grass', 1: 'Charlock', 2: 'Cleavers', 3: 'Common Chickweed', 4: 'Common wheat', 5: 'Fat Hen', 6: 'Loose Silky-bent', 7: 'Maize', 8: 'Scentless Mayweed', 9: 'Shepherds Purse', 10: 'Small-flowered Cranesbill', 11: 'Sugar beet'}

Target distribution: Train set

import seaborn as sns   # statistical plotting library
sns.set(color_codes=True)  
pal = sns.color_palette()
sns.set_style("whitegrid")

labels = df['category'].apply(lambda x: x.split(' '))
from collections import Counter, defaultdict
counts = defaultdict(int)
for l in labels:
    for l2 in l:
        counts[l2] += 1

counts_df = pd.DataFrame.from_dict(counts, orient='index')
counts_df.columns = ['count']
counts_df.sort_values('count', ascending=False, inplace=True)

fig, ax = plt.subplots()
ax = sns.barplot(x=counts_df.index, y=counts_df['count'], ax=ax)
fig.set_size_inches(20,8)
ax.set_xticklabels(ax.xaxis.get_majorticklabels(), rotation=-45);

Augmentation / Transformations

  • The train dataset includes data augmentation techniques such as cropping to size 224 and horizontal flips. The train and validation datasets are normalized with mean: [0.485, 0.456, 0.406], and standard deviation: [0.229, 0.224, 0.225].

  • These values are the means and standard deviations of the ImageNet images.

  • We use these values because pretrained models are usually trained on ImageNet.

  • from __future__ import absolute_import

    import math
    import numbers
    import random

    import numpy as np
    import torch
    import torchvision
    from torchvision.transforms import *

    from PIL import Image, ImageDraw, ImageOps
    import PIL.Image as im
    import PIL.ImageEnhance as ie
    
    # adapted from https://github.com/kuangliu/pytorch-retinanet/blob/master/transform.py
    # https://github.com/mratsim/Amazon-Forest-Computer-Vision/blob/master/src/p_data_augmentation.py
    
    normalize_img = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    

    Data Augmenting using PyTorch is quite easy.

  • We can use the transforms provided in torchvision: torchvision.transforms.

  • To compose more than a few transforms together, we use torchvision.transforms.Compose and provide the transforms as a list.

  • The transforms are applied following the list order.
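For instance (a small sketch, not taken from the notebook), transforms.ToTensor must come before transforms.Normalize in the list, because Normalize operates on a tensor rather than on a PIL image:

from torchvision import transforms

preprocess = transforms.Compose([
    transforms.CenterCrop(224),                       # still a PIL image at this point
    transforms.ToTensor(),                            # PIL image -> FloatTensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # operates on the tensor
                         std=[0.229, 0.224, 0.225]),
])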

Central note:

  •  For training a CNN, transforms.ToTensor is used to convert the images to a PyTorch tensor and transforms.Normalize to normalize the images according to the pretrained network that you will train. We are omitting these steps since we focus on augmenting images.

  • image_size = 224
    
    normalize_img = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    
    
    train_trans = transforms.Compose([
        transforms.RandomSizedCrop(image_size),
        transforms.ToTensor(),
    ])
    
    ## Normalization only for validation and test
    valid_trans = transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    #     normalize_img
    ])
    
    batch_size = 8
    train_data = df.sample(frac=0.85)
    valid_data = df[~df['file'].isin(train_data['file'])]
    
    train_set = GenericDataset(train_data, data_dir, transform = train_trans)
    valid_set = GenericDataset(valid_data, data_dir, transform = valid_trans)
            
    
    t_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    v_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0)
    # test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=sys.cpu_count -1)
    
    dataset_sizes = {
        'train': len(t_loader.dataset), 
        'valid': len(v_loader.dataset)
    }
    
    
    print (dataset_sizes)
    print (train_data["category_id"].value_counts())
    
    train_data['category_id'].value_counts().plot(kind='bar')
    
{'train': 4038, 'valid': 712}
6     566
3     513
8     445
10    421
5     408
1     327
11    316
2     247
0     225
9     202
7     185
4     183
Name: category_id, dtype: int64

valid_data['category_id'].value_counts().plot(kind='bar')

Test the data loader

imagesToShow=4

def flaotTensorToImage(img, mean=0, std=1):
        """convert a tensor to an image"""
        img = np.transpose(img.numpy(), (1, 2, 0))
        img = (img*std+ mean)*255
        img = img.astype(np.uint8)    
        return img    

if __name__ == '__main__':  
    for i, data in enumerate(t_loader, 0):
        print('i=%d: '%(i))            
        images, labels = data            
        num = len(images)

        ax = plt.subplot(1, imagesToShow, i + 1)
        plt.tight_layout()
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')

        for n in range(num):
            image=images[n]
            label=labels[n]
            plt.imshow (flaotTensorToImage(image))
        if i==imagesToShow-1:
            break

A simple Custom Data Augmentation

The augmentations available on PyTorch are simple. What if we want to perform more interesting augmentations?

Let's use https://github.com/zhunzhong07/Random-Erasing (random erasing) to achieve that.

class RandomErasing(object):
    def __init__(self, EPSILON = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.EPSILON = EPSILON
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1
       
    def __call__(self, img):

        if random.uniform(0, 1) > self.EPSILON:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]
       
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1/self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size()[2] and h <= img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    #img[0, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    #img[1, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    #img[2, x1:x1+h, y1:y1+w] = random.uniform(0, 1)
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]
                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]
                    #img[:, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(3, h, w))
                else:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[1]
                    # img[0, x1:x1+h, y1:y1+w] = torch.from_numpy(np.random.rand(1, h, w))
                return img

        return img
image_size = 224

normalize_img = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])


train_trans = transforms.Compose([
    transforms.RandomSizedCrop(image_size),
#     PowerPIL(),
    transforms.ToTensor(),
#     normalize_img,
    RandomErasing()
])

## Normalization only for validation and test
valid_trans = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
#     normalize_img
])

batch_size = 8
train_data = df.sample(frac=0.85)
valid_data = df[~df['file'].isin(train_data['file'])]

train_set = GenericDataset(train_data, data_dir, transform = train_trans)
valid_set = GenericDataset(valid_data, data_dir, transform = valid_trans)
        

t_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
v_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0)
# test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

dataset_sizes = {
    'train': len(t_loader.dataset), 
    'valid': len(v_loader.dataset)
}


print (dataset_sizes)
print (train_data["category_id"].value_counts())


imagesToShow=4

def flaotTensorToImage(img, mean=0, std=1):
        """convert a tensor to an image"""
        img = np.transpose(img.numpy(), (1, 2, 0))
        img = (img*std+ mean)*255
        img = img.astype(np.uint8)    
        return img    

if __name__ == '__main__':  
    for i, data in enumerate(t_loader, 0):
        print('i=%d: '%(i))            
        images, labels = data            
        num = len(images)

        ax = plt.subplot(1, imagesToShow, i + 1)
        plt.tight_layout()
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')

        for n in range(num):
            image=images[n]
            label=labels[n]
            plt.imshow (flaotTensorToImage(image))

        if i==imagesToShow-1:
            break


Reposted from blog.csdn.net/u014264373/article/details/85851050