pytorch 7月24日学习---pix2pix代码学习1

一. 库和参数

import argparse

import os

import numpy as np

import math

import itertools

import time

import datetime

import sys

import torchvision.transforms as transforms

from torchvision.utils import save_image

from torch.utils.data import DataLoader

from torchvision import datasets

from torch.autograd import Variable

from models import *

from datasets import *

import torch.nn as nn

import torch.nn.functional as F

import torch

1. itertools库

itertools 是 Python 标准库中用于创建和组合迭代器的模块(如 chain、product、count 等)。迭代器(生成器是其中一种)在 Python 中是一种很常用也很好用的数据结构。与列表(list)相比,迭代器最大的优势是延迟计算、按需生成,从而节省内存、提高运行效率。因此在 Python 3 中,map、filter 等函数返回的不再是列表,而是迭代器。

2. datetime库

datetime.date.today() 返回当前的系统日期

datetime.date.fromtimestamp(time.time()) 将时间戳转换为日期

datetime.datetime.now() 返回当前的系统日期和时间

current_time.replace(2016, 5, 12) 返回一个新的 datetime 对象:年、月、日被替换为指定值,其余字段与 current_time 相同(原对象不变)

datetime.datetime.strptime("21/11/06 16:30", "%d/%m/%y %H:%M") 将字符串按指定格式解析为 datetime 对象

参数设置

# Command-line options for pix2pix training; parsed once into the module-level `opt`.
parser = argparse.ArgumentParser()
# A non-zero --epoch makes the script load saved checkpoints instead of re-initialising weights.
parser.add_argument('--epoch', type=int, default=0, help='epoch to start training from')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--dataset_name', type=str, default="facades", help='name of the dataset')
parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
# Adam hyper-parameters, shared by the generator and discriminator optimizers.
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
# beta2 is the decay rate of Adam's second-order (squared-gradient) moment estimate.
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay')
# Passed to DataLoader as num_workers (worker processes) for the training loader.
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--img_height', type=int, default=256, help='size of image height')
parser.add_argument('--img_width', type=int, default=256, help='size of image width')
parser.add_argument('--channels', type=int, default=3, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=500, help='interval between sampling of images from generators')
# NOTE(review): -1 presumably disables checkpointing; the consuming check is not visible here.
parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between model checkpoints')
opt = parser.parse_args()
print(opt)

二. Loss函数

# Loss functions
# Presumably the adversarial criterion (MSE on discriminator outputs, least-squares
# GAN style); its call site is not part of this section.
criterion_GAN: torch.nn.Module = torch.nn.MSELoss(reduction='mean')
# Pixel-wise reconstruction criterion: mean absolute (L1) error.
criterion_pixelwise: torch.nn.Module = torch.nn.L1Loss(reduction='mean')

# Weight applied to the L1 pixel-wise loss between the translated image and the
# real image (the pix2pix paper also uses lambda = 100).
lambda_pixel = 100

1. torch.nn.MSELoss()

均方损失函数

公式为:$\ell(x_i, y_i) = (x_i - y_i)^2$(逐元素计算;默认 reduction='mean' 时再对所有元素求平均)

>>> loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)

>>> input = torch.autograd.Variable(torch.randn(3,4))

>>> target = torch.autograd.Variable(torch.randn(3,4))

>>> loss = loss_fn(input, target)

>>> print(input); print(target); print(loss)

tensor([[-1.3524,  0.5194,  1.0586,  0.1549],

        [ 1.6697,  0.4262,  0.0257,  1.1458],

        [ 0.6460,  0.3691,  0.5229, -2.1614]])

tensor([[-0.8691,  1.7308,  1.0579,  0.2359],

        [-0.3626, -0.7589, -0.0547,  0.9764],

        [ 0.3606, -0.5090, -0.9875, -0.6050]])

tensor([[ 2.3352e-01,  1.4676e+00,  4.3082e-07,  6.5639e-03],

        [ 4.1302e+00,  1.4044e+00,  6.4638e-03,  2.8681e-02],

        [ 8.1454e-02,  7.7096e-01,  2.2813e+00,  2.4222e+00]])

2. torch.nn.L1Loss()

公式为:$\ell(x_i, y_i) = |x_i - y_i|$(逐元素计算;默认 reduction='mean' 时再对所有元素求平均)

>>> import torch

>>> loss_fn = torch.nn.L1Loss(reduce=False, size_average=False)

>>> input = torch.autograd.Variable(torch.randn(3,4))

>>> target = torch.autograd.Variable(torch.randn(3,4))

>>> loss = loss_fn(input, target)

>>> print(input); print(target); print(loss)

tensor([[-0.2028,  1.0140, -0.9712,  1.6227],

        [ 1.0678, -1.3599, -1.1543,  1.6353],

        [-0.1146, -0.2229,  0.1262, -0.8661]])

tensor([[-0.7508, -0.7450,  0.0223, -0.8037],

        [ 1.3009,  0.3976, -0.3933,  0.6665],

        [ 0.0281,  1.9780, -1.6017, -1.6238]])

tensor([[ 0.5479,  1.7590,  0.9935,  2.4265],

        [ 0.2331,  1.7575,  0.7610,  0.9688],

        [ 0.1427,  2.2009,  1.7279,  0.7577]])

三. 载入模型

# Build both networks of the pix2pix pair.
generator = GeneratorUNet()
discriminator = Discriminator()

if opt.epoch == 0:
    # Fresh run: apply weights_init_normal (from the star imports) to every submodule.
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)
else:
    # Resume: restore the state_dicts saved for epoch `opt.epoch`, generator first.
    for _net, _ckpt_fmt in (
        (generator, 'saved_models/%s/generator_%d.pth'),
        (discriminator, 'saved_models/%s/discriminator_%d.pth'),
    ):
        _net.load_state_dict(torch.load(_ckpt_fmt % (opt.dataset_name, opt.epoch)))

# Optimizers: both use Adam with the same command-line lr and betas.
_adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=_adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=_adam_betas)

pytorch 提供了 state_dict() 和 load_state_dict() 两个方法用来保存和加载模型参数:前者以字典形式(键为参数名,值为张量)返回模型的参数和缓冲区,后者将这样的字典载入到模型当中。常见的“部分加载预训练参数”流程如下:

1. 首先, 读取当前模型参数

# Snapshot of the current model's state as a name -> tensor mapping.
model_dict = model.state_dict()

2. 读取预训练模型, 并选取要保留的部分

# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
pre_dict = torch.load('path')

# Keep only entries whose name exists in the current model AND whose tensor shape
# matches. A same-named tensor with a different shape (e.g. a resized output layer)
# would otherwise make load_state_dict fail with a size-mismatch error.
pre_dict = {k: v for k, v in pre_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape}

3. 使用预训练的模型更新当前模型参数

# Overwrite the kept entries; keys absent from pre_dict keep the model's current values.
model_dict.update(pre_dict)

4. 加载模型参数

# model_dict contains exactly the model's own keys, so the default strict load accepts it.
model.load_state_dict(model_dict)

# Loads the checkpoint directly with no key filtering: its keys must match the generator exactly.
generator.load_state_dict(torch.load('saved_models/%s/generator_%d.pth' % (opt.dataset_name, opt.epoch)))

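上面这句把“读取 checkpoint”和“加载参数”合成了一步,但省略了第 2、3 步的键过滤与合并。因此它要求 checkpoint 中的键与模型完全一致(load_state_dict 默认 strict=True),否则会报错。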

四. 配置dataloaders

# Configure dataloaders
# Resize to (img_height, img_width) with bicubic resampling, convert to a tensor,
# then normalise each channel as (x - 0.5) / 0.5 (maps [0, 1] to [-1, 1] for
# 8-bit images converted by ToTensor).
_half = (0.5, 0.5, 0.5)
transforms_ = [
    transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(_half, _half),
]

_data_root = "../../data/%s" % opt.dataset_name

# Training split: shuffled batches of opt.batch_size, opt.n_cpu worker processes.
dataloader = DataLoader(
    ImageDataset(_data_root, transforms_=transforms_),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

# Validation split: batches of 10, consumed by sample_images for preview grids.
val_dataloader = DataLoader(
    ImageDataset(_data_root, transforms_=transforms_, mode='val'),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

五. 图片保存

def sample_images(batches_done):
    """Save a grid of generator outputs for one validation batch.

    Draws one batch from ``val_dataloader``, runs ``generator`` on it and writes
    ``images/<dataset_name>/<batches_done>.png``. For each sample, the grid shows
    the input, the generated image and the target concatenated along dim -2.

    Args:
        batches_done: Value used only to name the output file.
    """
    imgs = next(iter(val_dataloader))
    # Note the swap: the dataset's 'B' image is the generator input (real_A)
    # and 'A' is the target (real_B).
    real_A = imgs['B'].type(Tensor)
    real_B = imgs['A'].type(Tensor)
    # Inference only: don't build an autograd graph for the sampling pass.
    with torch.no_grad():
        fake_B = generator(real_A)
    # dim -2 is the height axis assuming NCHW tensors, so input / output / target
    # end up stacked vertically for each sample.
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    # Make sure the output directory exists; save_image does not create it.
    os.makedirs('images/%s' % opt.dataset_name, exist_ok=True)
    # Previously dedented to module level, which ran at import with undefined names.
    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)

第24000次图片:

源代码网址:https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/pix2pix

猜你喜欢

转载自blog.csdn.net/weixin_42445501/article/details/81204663