Image completion (also known as image inpainting) aims to fill the missing regions of an image with visually realistic and semantically appropriate content, and is widely used in image processing. Convolutional neural networks (CNNs) have driven great progress in computer vision thanks to their strong texture-modeling ability, but they are poor at understanding global structure. The recent development of transformers has demonstrated their ability to model long-range dependencies (global structure), yet their computational complexity hinders their application to high-resolution images. ICT combines the strengths of both approaches for image completion: it first uses a transformer to reconstruct an appearance prior, recovering diverse, coherent structures along with some coarse textures, and then uses a CNN to replenish texture, enhancing the local texture details of the coarse prior under the guidance of the high-resolution masked image. The basic principle is illustrated in the figure below.
If you are interested in MindSpore, you can follow the 昇思MindSpore community.
I. Environment Setup
1. Open the ModelArts website
The cloud platform helps users quickly create and deploy models and manage full-lifecycle AI workflows. Select the cloud platform below to start using MindSpore: obtain the installation command and install the MindSpore 2.0.0-alpha version. You can reach the ModelArts website from the MindSpore tutorials.
Select CodeLab below to try it immediately.
Wait for the environment setup to complete.
2. Use CodeLab to try a Notebook instance
Download the sample Notebook code ("Mask-RCNN for instance segmentation"); the .ipynb file is the sample code. Choose ModelArts Upload Files to upload the .ipynb file.
Select the Kernel environment.
Switch to the GPU environment and choose the first free-for-a-limited-time option.
Go to the MindSpore official website and click Install at the top.
Get the installation command.
Go back to the Notebook and add the following command before the first code cell:
conda update -n base -c defaults conda
Install the MindSpore 2.0 GPU version:
conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
Install mindvision:
pip install mindvision
Install the download package:
pip install download
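You can then verify the installation; mindspore.run_check() prints the installed version and checks that MindSpore can run on the current device:
python -c "import mindspore;mindspore.run_check()"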
II. Data Preparation
ImageNet dataset
In the tutorial below, we use example code to show how to set up the network and the optimizer and how to compute the loss. We use the ImageNet dataset for training and inference; you can download it from the official ImageNet website. After extraction, arrange the files into the following directory structure:
# ImageNet dataset
.imagenet/
├── train/ (1000 directories and 1281167 images)
│   ├── n04347754/
│   │   ├── 000001.jpg
│   │   ├── 000002.jpg
│   │   └── ....
│   └── n04347756/
│       ├── 000001.jpg
│       ├── 000002.jpg
│       └── ....
└── val/ (1000 directories and 50000 images)
    ├── n04347754/
    │   ├── 000001.jpg
    │   ├── 000002.jpg
    │   └── ....
    └── n04347756/
        ├── 000001.jpg
        ├── 000002.jpg
        └── ....
Mask dataset
For the image completion task we also need a mask dataset to mask the images, producing images with partially corrupted pixels. Download the mask dataset yourself and extract it into the following directory structure:
# Mask dataset
.mask/
└── testing_mask_dataset/
    ├── 000001.png
    ├── 000002.png
    ├── 000003.png
    └── ....
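For intuition, "masking" an image keeps the known pixels and replaces the missing ones with a constant (white, matching the generator input defined later). A minimal sketch with NumPy and PIL; the file paths are placeholders following the directory layout above:
import numpy as np
from PIL import Image

img = Image.open('.imagenet/train/n04347754/000001.jpg').convert('RGB')
msk = Image.open('.mask/testing_mask_dataset/000001.png').convert('L').resize(img.size, Image.NEAREST)
img = np.asarray(img, dtype=np.float32)
msk = (np.asarray(msk, dtype=np.float32) / 255. > 0.5).astype(np.float32)[..., None]  # 1 = missing pixel
masked = img * (1. - msk) + 255. * msk  # keep known pixels, fill missing ones with white
Image.fromarray(masked.astype(np.uint8)).save('masked_example.png')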
Parameter settings
First, set the parameters for training and inference and choose the execution mode. This case uses the dynamic-graph (PyNative) mode for both training and inference by default.
Note that the training stage depends on a VGG19 model: the weight file VGG19.ckpt must be downloaded in advance, and the path in the parameters must match where you place it.
import os
import argparse
import mindspore
import numpy as np
from mindspore import context
def parse_args():
"""Parse args."""
parser = argparse.ArgumentParser()
# Training parameters
parser.add_argument('--data_path', type=str, default='/data0/imagenet2012/train',
help='Indicate where is the training set')
parser.add_argument('--mask_path', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
help='Where is the mask')
parser.add_argument('--n_layer', type=int, default=35)
parser.add_argument('--n_head', type=int, default=8)
parser.add_argument('--n_embd', type=int, default=1024)
parser.add_argument('--GELU_2', action='store_true', help='use the new activation function')
parser.add_argument('--use_ImageFolder', action='store_true', help='Using the original folder for ImageNet dataset')
parser.add_argument('--random_stroke', action='store_true', help='Use the generated mask')
parser.add_argument('--train_epoch', type=int, default=5, help='How many epochs')
parser.add_argument('--learning_rate', type=float, default=3e-4, help='Value of learning rate.')
parser.add_argument('--input', type=str, default='/data0/imagenet2012/train',
help='path to the input images directory or an input image')
parser.add_argument('--mask', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
help='path to the masks directory or a mask file')
parser.add_argument('--prior', type=str, default='', help='path to the edges directory or an edge file')
parser.add_argument('--kmeans', type=str, default='../kmeans_centers.npy', help='path to the kmeans')
# Adjust this to the actual path of VGG19.ckpt; the VGG19 weights can be downloaded from the address given in the inference section
parser.add_argument('--vgg_path', type=str, default='../ckpts_ICT/VGG19.ckpt', help='path to the VGG')
parser.add_argument('--image_size', type=int, default=256, help='the size of origin image')
parser.add_argument('--prior_random_degree', type=int, default=1, help='during training, how far deviate from')
parser.add_argument('--use_degradation_2', action='store_true', help='use the new degradation function')
parser.add_argument('--mode', type=int, default=1, help='1:train, 2:test')
parser.add_argument('--mask_type', type=int, default=2)
parser.add_argument('--max_iteration', type=int, default=25000, help='How many run iteration')
parser.add_argument('--lr', type=float, default=0.0001, help='Value of learning rate.')
parser.add_argument('--D2G_lr', type=float, default=0.1,
help='Value of discriminator/generator learning rate ratio')
parser.add_argument("--l1_loss_weight", type=float, default=1.0)
parser.add_argument("--style_loss_weight", type=float, default=250.0)
parser.add_argument("--content_loss_weight", type=float, default=0.1)
parser.add_argument("--inpaint_adv_loss_weight", type=float, default=0.1)
parser.add_argument('--ckpt_path', type=str, default='', help='model checkpoints path')
parser.add_argument('--save_path', type=str, default='./checkpoint', help='save checkpoints path')
parser.add_argument('--device_id', type=int, default=0)
parser.add_argument('--device_target', type=str, default='GPU', help='GPU or Ascend')
parser.add_argument('--prior_size', type=int, default=32, help='Input sequence length = prior_size*prior_size')
parser.add_argument('--batch_size', type=int, default=2, help='The number of train batch size')
parser.add_argument("--beta1", type=float, default=0.9, help="Value of beta1")
parser.add_argument("--beta2", type=float, default=0.95, help="Value of beta2")
# Inference parameters
parser.add_argument('--image_url', type=str, default='/data0/imagenet2012/val', help='the folder of image')
parser.add_argument('--mask_url', type=str, default='/home/ict/ICT/mask/testing_mask_dataset',
help='the folder of mask')
parser.add_argument('--top_k', type=int, default=40)
parser.add_argument('--save_url', type=str, default='./sample', help='save the output results')
parser.add_argument('--condition_num', type=int, default=1, help='Use how many BERT output')
args = parser.parse_known_args()[0]
return args
# Use the dynamic-graph (PyNative) execution mode on a GPU
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
opts = parse_args()
Training
Transformer training
Transformer network structure
Because the Transformer is computationally expensive, images are downsampled before being fed into the network; for ImageNet, images are downsampled to a 32×32 resolution to obtain the low-resolution image prior. The Transformer structure is shown below:

The code implementation of the Transformer follows:
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.common.initializer import Normal
class CausalSelfAttention(nn.Cell):
"""
The CausalSelfAttention part of transformer.
Args:
n_embd (int): The size of the vector space in which words are embedded.
n_head (int): The number of multi-head.
block_size (int): The context size(Input sequence length).
resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
attn_pdrop (float): The probability of attn_pdrop. Default: 0.1
Returns:
Tensor, output tensor.
"""
def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
super().__init__()
# key, query, value projections for all heads
self.key = nn.Dense(in_channels=n_embd, out_channels=n_embd,
weight_init=Normal(sigma=0.02, mean=0.0))
self.query = nn.Dense(in_channels=n_embd, out_channels=n_embd,
weight_init=Normal(sigma=0.02, mean=0.0))
self.value = nn.Dense(in_channels=n_embd, out_channels=n_embd,
weight_init=Normal(sigma=0.02, mean=0.0))
# regularization
self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_pdrop)
self.resid_drop = nn.Dropout(keep_prob=1.0 - resid_pdrop)
# output projection
self.proj = nn.Dense(in_channels=n_embd, out_channels=n_embd,
weight_init=Normal(sigma=0.02, mean=0.0))
# Causal mask buffer; note that it is registered here but never applied in
# construct, so the attention below is effectively bidirectional
tril = nn.Tril()
self.mask = mindspore.Parameter(
    tril(P.Ones()((block_size, block_size), mindspore.float32)).view(1, 1, block_size, block_size),
    requires_grad=False)
self.n_head = n_head
def construct(self, x):
B, T, C = P.Shape()(x)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head)
k = k.transpose(0, 2, 1, 3)
q = self.query(x).view(B, T, self.n_head, C // self.n_head)
q = q.transpose(0, 2, 1, 3)
v = self.value(x).view(B, T, self.n_head, C // self.n_head)
v = v.transpose(0, 2, 1, 3)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
k_shape = k.shape[-1]
sz = 1.0 / (Tensor(k_shape) ** Tensor(0.5))
q = mindspore.ops.Cast()(q, mindspore.float16)
k = mindspore.ops.Cast()(k, mindspore.float16)
sz = mindspore.ops.Cast()(sz, mindspore.float16)
att = (ops.matmul(q, k.transpose(0, 1, 3, 2)) * sz)
att = mindspore.ops.Cast()(att, mindspore.float32)
att = P.Softmax()(att)
att = self.attn_drop(att)
att = mindspore.ops.Cast()(att, mindspore.float16)
v = mindspore.ops.Cast()(v, mindspore.float16)
y = mindspore.ops.matmul(att, v)
y = mindspore.ops.Cast()(y, mindspore.float32)
y = y.transpose(0, 2, 1, 3).view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
return y
class GELU2(nn.Cell):
"""
The new gelu2 activation function.
Returns:
Tensor, output tensor.
"""
def construct(self, x):
return x * P.Sigmoid()(1.702 * x)
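# Side note (illustrative, not part of the original script): GELU2 is the
# sigmoid approximation of GELU, x * sigmoid(1.702 * x). A quick numeric check
# of how close the two activations are:
#   import math
#   import numpy as np
#   x = np.linspace(-3., 3., 601)
#   gelu = np.array([0.5 * v * (1. + math.erf(v / math.sqrt(2.))) for v in x])
#   gelu2 = x / (1. + np.exp(-1.702 * x))
#   print(np.abs(gelu - gelu2).max())  # about 0.02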
class Block_2(nn.Cell):
"""
Transformer block with the GELU2 activation.
Args:
n_embd (int): The size of the vector space in which words are embedded.
n_head (int): The number of multi-head.
block_size (int): The context size(Input sequence length).
resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
attn_pdrop (float): The probability of attn_pdrop. Default: 0.1
Returns:
Tensor, output tensor.
"""
def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
super().__init__()
self.ln1 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
self.ln2 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
self.attn = CausalSelfAttention(n_embd, n_head, block_size, resid_pdrop, attn_pdrop)
self.mlp = nn.SequentialCell([
nn.Dense(in_channels=n_embd, out_channels=4 * n_embd,
weight_init=Normal(sigma=0.02, mean=0.0)),
GELU2(),
nn.Dense(in_channels=4 * n_embd, out_channels=n_embd,
weight_init=Normal(sigma=0.02, mean=0.0)),
nn.Dropout(keep_prob=1.0 - resid_pdrop),
])
def construct(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class Block(nn.Cell):
"""
Transformer block with original GELU.
Args:
n_embd (int): The size of the vector space in which words are embedded.
n_head (int): The number of multi-head.
block_size (int): The context size(Input sequence length).
resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
attn_pdrop (float): The probability of attn_pdrop. Default: 0.1
Returns:
Tensor, output tensor.
"""
def __init__(self, n_embd: int, n_head: int, block_size: int, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
super().__init__()
self.ln1 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
self.ln2 = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
self.attn = CausalSelfAttention(n_embd, n_head, block_size, resid_pdrop, attn_pdrop)
self.mlp = nn.SequentialCell([
nn.Dense(in_channels=n_embd, out_channels=4 * n_embd,
weight_init=Normal(sigma=0.02, mean=0.0)),
nn.GELU(),
nn.Dense(in_channels=4 * n_embd, out_channels=n_embd,
weight_init=Normal(sigma=0.02, mean=0.0)),
nn.Dropout(keep_prob=1.0 - resid_pdrop),
])
def construct(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Cell):
"""
The full GPT language model, with a context size of block_size.
Args:
vocab_size (int): The size of the vocabulary in the embedded data.
n_embd (int): The size of the vector space in which words are embedded.
n_layer (int): The number of attention layer.
n_head (int): The number of multi-head.
block_size (int): The context size(Input sequence length).
use_gelu2 (bool): Use the new gelu2 activation function.
embd_pdrop (float): The probability of embd_pdrop. Default: 0.1
resid_pdrop (float): The probability of resid_pdrop. Default: 0.1
attn_pdrop (float): The probability of attn_pdrop. Default: 0.1
Returns:
Tensor, output tensor.
"""
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, block_size: int, use_gelu2: bool,
embd_pdrop: float = 0.1, resid_pdrop: float = 0.1, attn_pdrop: float = 0.1):
super().__init__()
self.tok_emb = nn.Embedding(vocab_size, n_embd, embedding_table=Normal(sigma=0.02, mean=0.0))
self.pos_emb = mindspore.Parameter(P.Zeros()((1, block_size, n_embd), mindspore.float32))
self.drop = nn.Dropout(keep_prob=1.0 - embd_pdrop)
# transformer
if use_gelu2:
self.blocks = nn.SequentialCell(
[*[Block_2(n_embd, n_head, block_size, resid_pdrop, attn_pdrop) for _ in range(n_layer)]])
else:
self.blocks = nn.SequentialCell(
[*[Block(n_embd, n_head, block_size, resid_pdrop, attn_pdrop) for _ in range(n_layer)]])
# decoder head
self.ln_f = nn.LayerNorm(normalized_shape=[n_embd], epsilon=1e-05)
self.head = nn.Dense(in_channels=n_embd, out_channels=vocab_size, has_bias=False,
weight_init=Normal(sigma=0.02, mean=0.0))
self.block_size = block_size
def get_block_size(self):
return self.block_size
def construct(self, idx, masks):
_, t = idx.shape
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
masks = P.ExpandDims()(masks, 2)
token_embeddings = token_embeddings * (1 - masks)
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)
return logits
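As a quick sanity check (an illustrative sketch, not part of the original flow; it assumes the GPU PyNative context set earlier), a tiny GPT can be instantiated to confirm that the output is one vocabulary distribution per input position; the configuration below is made up for the test:
import numpy as np
tiny_gpt = GPT(vocab_size=512, n_embd=64, n_layer=2, n_head=4, block_size=16, use_gelu2=True)
idx = Tensor(np.random.randint(0, 512, (2, 16)), mindspore.int32)  # token indices
msk = Tensor(np.random.randint(0, 2, (2, 16)).astype(np.float32))  # 1 = missing position
print(tiny_gpt(idx, msk).shape)  # (2, 16, 512)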
Define the loss function
class TransformerWithLoss(nn.Cell):
"""
Wrap the network with loss function to return Transformer with loss.
Args:
backbone (Cell): The target network to wrap.
"""
def __init__(self, backbone):
super(TransformerWithLoss, self).__init__(auto_prefix=False)
self.backbone = backbone
self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
def construct(self, x, targets, masks):
logits = self.backbone(x, masks)
loss = self.loss_fn(logits.view(-1, logits.shape[-1]), targets.view(-1))
masks = P.ExpandDims()(masks, 2)
masks = masks.view(-1)
loss *= masks  # only masked (missing) positions contribute to the loss
loss = P.ReduceMean()(loss)
return loss
Define the dataset and the model
# Change the current directory to ICT/Transformer
if os.getcwd().endswith('ICT') or os.getcwd().endswith('ict'):
os.chdir("./Transformer")
# Enable the ImageNet dataset option
opts.use_ImageFolder = True
from datasets.dataset import load_dataset
from transformer_utils.util import AverageMeter
kmeans = np.load('../kmeans_centers.npy')
kmeans = np.rint(127.5 * (kmeans + 1.0))
# Define the dataset
train_dataset = load_dataset(opts.data_path, kmeans, mask_path=opts.mask_path, is_train=True,
use_imagefolder=opts.use_ImageFolder, prior_size=opts.prior_size,
random_stroke=opts.random_stroke)
train_dataset = train_dataset.batch(opts.batch_size)
step_size = train_dataset.get_dataset_size()
# Define the model
block_size = opts.prior_size * opts.prior_size
transformer = GPT(vocab_size=kmeans.shape[0], n_embd=opts.n_embd, n_layer=opts.n_layer, n_head=opts.n_head,
block_size=block_size, use_gelu2=opts.GELU_2, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0)
model = TransformerWithLoss(backbone=transformer)
# Define the optimizer
optimizer = nn.Adam(model.trainable_params(), learning_rate=opts.learning_rate, beta1=opts.beta1, beta2=opts.beta2)
train_net = nn.TrainOneStepCell(model, optimizer)
train_loss = AverageMeter()
best_loss = 10000000000.
if not os.path.exists(opts.save_path):
os.mkdir(opts.save_path)
opts.train_epoch = 1  # train a single epoch for this demonstration
for epoch in range(opts.train_epoch):
train_loss.reset()
for i, sample in enumerate(train_dataset.create_dict_iterator()):
x = sample['data']
y = sample['mask']
y = mindspore.ops.Cast()(y, mindspore.float32)
loss = train_net(x, x, y)
train_loss.update(loss, 1)
if i % 2 == 0:
    print(f"Epoch: [{epoch} / {opts.train_epoch}], "
          f"step: [{i} / {step_size}], "
          f"loss: {train_loss.avg}")
if train_loss.avg < best_loss:
best_loss = train_loss.avg
mindspore.save_checkpoint(transformer, os.path.join(opts.save_path, 'ImageNet_best.ckpt'))
if i != 0 and i % 100 == 0:  # stop early for this demonstration
break
mindspore.save_checkpoint(transformer, os.path.join(opts.save_path, 'ImageNet_latest.ckpt'))
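Guided Upsample training
Upsample network structure
The second stage refines the low-resolution appearance prior produced by the Transformer, guided by the high-resolution masked image. It consists of a convolutional generator (encoder, residual blocks, decoder) and a convolutional discriminator, implemented below: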
class Generator(nn.Cell):
def __init__(self, residual_blocks=8):
super(Generator, self).__init__()
self.encoder = nn.SequentialCell(
nn.Pad(((0, 0), (0, 0), (3, 3), (3, 3)), mode='REFLECT'),
nn.Conv2d(in_channels=6, out_channels=64, kernel_size=7, pad_mode='pad', padding=0, has_bias=True),
nn.ReLU(),
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1,
has_bias=True),
nn.ReLU(),
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, pad_mode='pad', padding=1,
has_bias=True),
nn.ReLU()
)
blocks = []
for _ in range(residual_blocks):
block = ResnetBlock_remove_IN(256, 2)
blocks.append(block)
self.middle = nn.SequentialCell(*blocks)
self.decoder = nn.SequentialCell(
nn.Conv2dTranspose(in_channels=256, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1,
has_bias=True),
nn.ReLU(),
nn.Conv2dTranspose(in_channels=128, out_channels=64, kernel_size=4, stride=2, pad_mode='pad', padding=1,
has_bias=True),
nn.ReLU(),
nn.Pad(((0, 0), (0, 0), (3, 3), (3, 3)), mode='REFLECT'),
nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, pad_mode='pad', padding=0, has_bias=True),
)
def construct(self, images, edges, masks):
images_masked = (images * P.Cast()((1 - masks), mindspore.float32)) + masks
x = P.Concat(axis=1)((images_masked, edges))
x = self.encoder(x)
x = self.middle(x)
x = self.decoder(x)
x = (P.Tanh()(x) + 1) / 2
return x
class ResnetBlock_remove_IN(nn.Cell):
def __init__(self, dim, dilation=1):
super(ResnetBlock_remove_IN, self).__init__()
self.conv_block = nn.SequentialCell([
nn.Pad(((0, 0), (0, 0), (dilation, dilation), (dilation, dilation)), mode='REFLECT'),
nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, pad_mode='pad', dilation=dilation,
has_bias=True),
nn.ReLU(),
nn.Pad(((0, 0), (0, 0), (1, 1), (1, 1)), mode='SYMMETRIC'),
nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, pad_mode='pad', dilation=1, has_bias=True)
])
def construct(self, x):
out = x + self.conv_block(x)
return out
class Discriminator(nn.Cell):
def __init__(self, in_channels, use_sigmoid=True):
super(Discriminator, self).__init__()
self.use_sigmoid = use_sigmoid
self.conv1 = nn.SequentialCell(
nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, pad_mode='pad', padding=1),
nn.LeakyReLU()
)
self.conv2 = nn.SequentialCell(
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, pad_mode='pad', padding=1),
nn.LeakyReLU()
)
self.conv3 = nn.SequentialCell(
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, pad_mode='pad', padding=1),
nn.LeakyReLU()
)
self.conv4 = nn.SequentialCell(
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, pad_mode='pad', padding=1),
nn.LeakyReLU()
)
self.conv5 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, pad_mode='pad', padding=1)
def construct(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
conv5 = self.conv5(conv4)
outputs = conv5
if self.use_sigmoid:
outputs = P.Sigmoid()(conv5)
return outputs, [conv1, conv2, conv3, conv4, conv5]
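A quick shape check with random inputs (an illustrative sketch, not part of the original flow; it assumes the GPU PyNative context set earlier):
import numpy as np
g = Generator()
d = Discriminator(in_channels=3)
images = Tensor(np.random.rand(1, 3, 256, 256).astype(np.float32))
edges = Tensor(np.random.rand(1, 3, 256, 256).astype(np.float32))  # stands in for the prior image
masks = Tensor((np.random.rand(1, 1, 256, 256) > 0.5).astype(np.float32))
out = g(images, edges, masks)  # (1, 3, 256, 256), values in (0, 1)
score, feats = d(out)          # patch-level scores plus intermediate feature maps
print(out.shape, score.shape)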
Define the dataset and the network models
# Change the current directory to ICT/Guided_Upsample
if os.getcwd().endswith('Transformer'):
os.chdir("../Guided_Upsample")
elif os.getcwd().endswith('ICT') or os.getcwd().endswith('ict'):
os.chdir("./Guided_Upsample")
from Guided_Upsample.datasets.dataset import load_dataset
from Guided_Upsample.models.loss import GeneratorWithLoss, DiscriminatorWithLoss
train_dataset = load_dataset(image_flist=opts.input, edge_flist=opts.prior, mask_filst=opts.mask,
image_size=opts.image_size, prior_size=opts.prior_size, mask_type=opts.mask_type,
kmeans=opts.kmeans, use_degradation_2=opts.use_degradation_2,
prior_random_degree=opts.prior_random_degree,
augment=True, training=True)
train_dataset = train_dataset.batch(opts.batch_size)
step_size = train_dataset.get_dataset_size()
generator = Generator()
discriminator = Discriminator(in_channels=3)
model_G = GeneratorWithLoss(generator, discriminator, opts.vgg_path, opts.inpaint_adv_loss_weight, opts.l1_loss_weight,
opts.content_loss_weight, opts.style_loss_weight)
model_D = DiscriminatorWithLoss(generator, discriminator)
Define the optimizer and the loss function
class PSNR(nn.Cell):
def __init__(self, max_val):
super(PSNR, self).__init__()
base10 = P.Log()(mindspore.Tensor(10.0, mindspore.float32))
max_val = P.Cast()(mindspore.Tensor(max_val), mindspore.float32)
self.base10 = mindspore.Parameter(base10, requires_grad=False)
self.max_val = mindspore.Parameter(20 * P.Log()(max_val) / base10, requires_grad=False)
def __call__(self, a, b):
a = P.Cast()(a, mindspore.float32)
b = P.Cast()(b, mindspore.float32)
mse = P.ReduceMean()((a - b) ** 2)
if mse == 0:
return mindspore.Tensor(0)
return self.max_val - 10 * P.Log()(mse) / self.base10
psnr_func = PSNR(255.0)
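# Quick sanity check of the metric (illustrative, not part of the original
# flow): a uniform error of one gray level on a 0-255 scale should give
# 20 * log10(255), i.e. about 48.13 dB.
a = mindspore.Tensor(np.zeros((1, 3, 8, 8)), mindspore.float32)
print(psnr_func(a, a + 1.0))  # about 48.13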
# Define the optimizer
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=opts.lr, beta1=opts.beta1, beta2=opts.beta2)
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=opts.lr * opts.D2G_lr, beta1=opts.beta1,
beta2=opts.beta2)
from Guided_Upsample.models.loss import TrainOneStepCell
from Guided_Upsample.upsample_utils.util import postprocess
train_net = TrainOneStepCell(model_G, model_D, optimizer_G, optimizer_D)
epoch = 0
iteration = 0
keep_training = True
psnr = AverageMeter()
mae = AverageMeter()
psnr_best = 0.0
opts.max_iteration = 100  # run only 100 iterations for this demonstration
if not os.path.exists(opts.save_path):
os.mkdir(opts.save_path)
while keep_training:
epoch += 1
for i, sample in enumerate(train_dataset.create_dict_iterator()):
images = sample['images']
edges = sample['edges']
masks = sample['masks']
loss_D, loss_G = train_net(images, edges, masks)
iteration += 1
if i % 2 == 0:
outputs = generator(images, edges, masks)
outputs_merged = (outputs * masks) + (images * (1 - masks))
psnr.update(psnr_func(postprocess(images), postprocess(outputs_merged)), 1)
mae.update((P.ReduceSum()(P.Abs()(images - outputs_merged)) / P.ReduceSum()(images)), 1)
print(f"Epoch: [{
epoch}], "
f"step: [{
i} / {
step_size}], "
f"psnr: {
psnr.avg}, mae: {
mae.avg}")
if psnr.avg > psnr_best:
psnr_best = psnr.avg
mindspore.save_checkpoint(generator, os.path.join(opts.save_path, 'InpaintingModel_gen_best.ckpt'))
mindspore.save_checkpoint(discriminator, os.path.join(opts.save_path, 'InpaintingModel_dis_best.ckpt'))
if iteration >= opts.max_iteration:
keep_training = False
break
mindspore.save_checkpoint(generator, os.path.join(opts.save_path, 'InpaintingModel_gen_latest.ckpt'))
mindspore.save_checkpoint(discriminator, os.path.join(opts.save_path, 'InpaintingModel_dis_latest.ckpt'))
Inference
Below we run inference with the trained model on the ImageNet validation set. Because the ImageNet dataset is very large, this case picks a few images from it as inference test cases.
The trained model weights can be downloaded from ict_ckpt; after downloading, place them under ckpts_ICT. The weight directory is shown below: origin contains weights converted from the PyTorch model, while ms_train contains weights trained with the MindSpore framework:
.ckpts_ICT/
├── origin/
│   ├── Transformer/
│   │   ├── ImageNet.ckpt
│   │   ├── FFHQ.ckpt
│   │   └── Places2_Nature.ckpt
│   └── Upsample/
│       ├── FFHQ/
│       │   ├── InpaintingModel_dis.ckpt
│       │   └── InpaintingModel_gen.ckpt
│       ├── ImageNet/
│       │   ├── InpaintingModel_dis.ckpt
│       │   └── InpaintingModel_gen.ckpt
│       └── Places2_Nature/
│           ├── InpaintingModel_dis.ckpt
│           └── InpaintingModel_gen.ckpt
├── ms_train/
│   ├── Transformer/
│   │   └── ImageNet_best.ckpt
│   └── Upsample/
│       ├── InpaintingModel_dis_best.ckpt
│       └── InpaintingModel_gen_best.ckpt
└── VGG19.ckpt
import numpy as np
import mindspore
if not (os.getcwd().endswith('ICT') or os.getcwd().endswith('ict')):
os.chdir("..")
opts = parse_args()
input_path = './input'
if not os.path.exists(input_path):
os.mkdir(input_path)
if os.path.exists(opts.image_url):
opts.image_url = os.path.join(opts.image_url, 'n01440764')
os.system('cp {} {}'.format(os.path.join(opts.image_url, 'ILSVRC2012_val_00000293.JPEG'), input_path))
os.system('cp {} {}'.format(os.path.join(opts.image_url, 'ILSVRC2012_val_00002138.JPEG'), input_path))
opts.image_url = input_path
C = np.load('./kmeans_centers.npy')
C = np.rint(127.5 * (C + 1.0))
C = mindspore.Tensor.from_numpy(C)
img_list = sorted(os.listdir(opts.image_url))
mask_list = sorted(os.listdir(opts.mask_url))
n_samples = opts.condition_num
First Stage: Transformer
Define the network and load the weights
from mindspore.train import Model
# Define the model
block_size = opts.prior_size * opts.prior_size
transformer = GPT(vocab_size=C.shape[0], n_embd=opts.n_embd, n_layer=opts.n_layer, n_head=opts.n_head,
block_size=block_size, use_gelu2=opts.GELU_2, embd_pdrop=0.0, resid_pdrop=0.0, attn_pdrop=0.0)
# Adjust the path to the corresponding weight file; if you change it, use an absolute path to avoid unnecessary errors
opts.ckpt_path = './ckpts_ICT/ms_train/Transformer/ImageNet_best.ckpt'
if os.path.exists(opts.ckpt_path):
print('Start loading the model parameters from %s' % (opts.ckpt_path))
checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
mindspore.load_param_into_net(transformer, checkpoint)
print('Finished loading the model')
transformer.set_train(False)
model = Model(transformer)
from PIL import Image
from transformer_utils.util import sample_mask
for x_name, y_name in zip(img_list, mask_list):
# load image
print(x_name)
image_url = os.path.join(opts.image_url, x_name)
x = Image.open(image_url).convert("RGB")
x = x.resize((opts.prior_size, opts.prior_size), resample=Image.BILINEAR)
x = mindspore.Tensor.from_numpy(np.array(x)).view(-1, 3)
x = P.Cast()(x, mindspore.float32)
x = ((x[:, None, :] - C[None, :, :]) ** 2).sum(-1).argmin(1)
# load mask
mask_url = os.path.join(opts.mask_url, y_name)
y = Image.open(mask_url).convert("L")
y = y.resize((opts.prior_size, opts.prior_size), resample=Image.NEAREST)
y = (np.array(y) / 255.) > 0.5
y = mindspore.Tensor.from_numpy(y).view(-1)
y = P.Cast()(y, mindspore.float32)
x_list = [x] * n_samples
x_tensor = P.Stack()(x_list)
y_list = [y] * n_samples
y_tensor = P.Stack()(y_list)
x_tensor = P.Cast()(x_tensor * (1 - y_tensor), mindspore.int32)
outputs = sample_mask(model, x_tensor, y_tensor, length=opts.prior_size * opts.prior_size,
top_k=opts.top_k)
img_name = x_name  # keep the original file name
for i in range(n_samples):
current_url = os.path.join(opts.save_url, 'condition_%d' % (i + 1))
os.makedirs(current_url, exist_ok=True)
current_img = C[outputs[i]].view(opts.prior_size, opts.prior_size, 3).asnumpy().astype(np.uint8)
tmp = Image.fromarray(current_img)
tmp.save(os.path.join(current_url, img_name))
Second Stage: Upsample
Use the image prior produced by the Transformer to perform guided upsampling and restore a high-resolution image.
def stitch_images(inputs, *outputs, img_per_row=2):
gap = 5
columns = len(outputs) + 1
width, height = inputs[0][:, :, 0].shape
img = Image.new('RGB',
(width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
images = [inputs, *outputs]
for ix in range(len(inputs)):
xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
yoffset = int(ix / img_per_row) * height
for cat in range(len(images)):
im = images[cat][ix].asnumpy().astype(np.uint8).squeeze()
im = Image.fromarray(im)
img.paste(im, (xoffset + cat * width, yoffset))
return img
from upsample_utils.util import postprocess, imsave
opts.mask_type = 3
opts.mode = 2
opts.input = './input/'
opts.kmeans = './kmeans_centers.npy'
# Make sure the second stage's prior input matches the first stage's output
opts.prior = opts.save_url
generator = Generator()
generator.set_train(False)
# Adjust the path to the corresponding weight file; if you change it, use an absolute path to avoid unnecessary errors
opts.ckpt_path = './ckpts_ICT/ms_train/Upsample/InpaintingModel_gen_best.ckpt'
if os.path.exists(opts.ckpt_path):
print('Start loading the model parameters from %s' % (opts.ckpt_path))
checkpoint = mindspore.load_checkpoint(opts.ckpt_path)
mindspore.load_param_into_net(generator, checkpoint)
print('Finished loading the model')
psnr_func = PSNR(255.0)
test_dataset = load_dataset(image_flist=opts.input, edge_flist=opts.prior, mask_filst=opts.mask,
image_size=opts.image_size, prior_size=opts.prior_size, mask_type=opts.mask_type,
kmeans=opts.kmeans, condition_num=opts.condition_num,
augment=False, training=False)
index = 0
psnr = AverageMeter()
mae = AverageMeter()
test_batch_size = 1
test_dataset = test_dataset.batch(test_batch_size)
for sample in test_dataset.create_dict_iterator():
name = sample['name'].asnumpy()[0]
images = sample['images']
edges = sample['edges']
masks = sample['masks']
inputs = (images * (1 - masks)) + masks
index += test_batch_size
outputs = generator(images, edges, masks)
outputs_merged = (outputs * masks) + (images * (1 - masks))
psnr.update(psnr_func(postprocess(images), postprocess(outputs_merged)), 1)
mae.update((P.ReduceSum()(P.Abs()(images - outputs_merged)) / P.ReduceSum()(images)), 1)
result_merge = stitch_images(
postprocess(images),
postprocess(inputs),
postprocess(outputs_merged),
img_per_row=1
)
result_merge.show()
output = postprocess(outputs_merged)[0]
path = os.path.join(opts.save_url, name[:-4] + "_%d" % (index % opts.condition_num) + '.png')
imsave(output, path)
print('PSNR: {}, MAE: {}'.format(psnr.avg, mae.avg))