1、GAN的原理和结构
生成器不仅可以用于生成各类图像和自然语言数据,还推动了各类半监督学习和无监督学习的发展。
机器学习的模型分为:
(1)生成模型:生成我们想要的类型图像?GAN?VAE?
(2)判别模型:判断一张图片是猫还是狗?
GAN是一种深度神经网络架构,由:生成网络和判别网络组成。
生成网络学习真实图像中的分布,以产生“假”数据,并试图欺骗判别网络;
判别网络对生成数据进行真伪鉴别,试图正确识别所有“假”数据;
两个网络在训练迭代的过程中,不断进行对抗和进化;
生成器会生成越来越逼真的数据,判别器也会越来越精确的去判别;
直至训练结束,达到平衡,最终的生成模型生成的数据非常接近真实,以至于判别器无法识别到底为真还是为假。
整个生成对抗模型中,判别器希望对生成器生成的图像判别为0,对真实图像判别为1;而对生成器而言,它希望生成的图像被判别为1;这就产生了一种对抗;因为这两类模型的优化目标是不一致的。
关键:
就是损失函数的处理。
判别器。主要是判别一张图片是真是假的二分类问题;
生成器。损失函数的定义就不是那么容易的了!我们希望生成器可以生成接近真实的图像,我们肉眼可以很好判别是否真实,但是在代码中,很抽象,很难用数学公式来定义。
所以我们就将生成器的输出交给判别器来进行判别处理,这样就将生成器、对抗器组合成了“生成对抗网络”。
2、算法流程和公式
G是生成图片的网络,接受一个随机噪声z,通过噪声生成图片G(z)。
(输入噪声,可以增加随机性和多样性);
D是判断网络,输入一张图片x,输出D(x)为真实图片的概率;为1就是100%是真实图片;输出为0,就代表肯定不是真实图片。
在训练过程中,将随机噪声输入生成网络G,得到生成的图片;判别器接收生成的图片和真实的图片,并尽量将两者区分开来。在这个计算过程中,
能否正确区分生成的图片和真实的图片将作为判别器的损失;
而能否生成近似真实的图片并使得判别器将生成的图片判定为真将作为生成器的提失。
生成器的损失是通过判别器的输出来计算的,
而判别器的输出是一个概率值,我们可以通过交叉熵计算。
Goodfellow从理论上证明了GAN算法的收敛性以及在模型收敛时生成数据具有和真实数据相同的分布。GAN的公式如图:
3、GAN的代码解析
以手写数字minist为例,生成手写数字图。
3.1 鉴别器D的模型搭建
鉴别器主要是输入真实图片,或者是经过噪声生成的虚假图片。
如果比如是:猫狗二分类,输出为[batch_size, 2],然后与对应的标签求loss。
# 鉴别器
class Dis28x28(nn.Module):
def __init__(self):
super(Dis28x28, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(1, 20, kernel_size=5, stride=1, padding=0),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(20, 50, kernel_size=5, stride=1, padding=0),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(50, 500, kernel_size=4, stride=1, padding=0),
nn.PReLU(),
nn.Conv2d(500, 2, kernel_size=1, stride=1, padding=0),
)
def forward(self, x):
out = self.model(x)
return out.squeeze()
3.2 生成器G的模型搭建
生成器主要是输入噪声,生成与真实图片维度、尺寸一样的虚假图片。
# 生成器
class Gen28x28(nn.Module):
def __init__(self, latent_dims):
super(Gen28x28, self).__init__()
self.model = nn.Sequential(
nn.ConvTranspose2d(latent_dims, 1024, kernel_size=4, stride=1),
nn.BatchNorm2d(1024, affine=False),
nn.PReLU(),
nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(512, affine=False),
nn.PReLU(),
nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(256, affine=False),
nn.PReLU(),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128, affine=False),
nn.PReLU(),
nn.ConvTranspose2d(128, 1, kernel_size=6, stride=1, padding=1),
nn.Sigmoid())
def forward(self, x):
x = x.view(x.size(0), x.size(1), 1, 1)
out = self.model(x)
return out
3.3 GAN的模型搭建
class MNISTGanTrainer(object):
def __init__(self, batch_size=64, latent_dims=100):
super(MNISTGanTrainer, self).__init__()
# 实例化鉴别器
self.dis = Dis28x28()
# 实例化生成器
self.gen = Gen28x28(latent_dims)
# 鉴别器网络训练时的优化器
self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.0005)
# 生成器网络训练时的优化器
self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.0005)
# 真实图片对应的标签,1,与batch_size一致
self.true_labels = Variable(torch.LongTensor(np.ones(batch_size, dtype=np.int)))
# 噪声生成的虚假图片对应的标签,0,
self.fake_labels = Variable(torch.LongTensor(np.zeros(batch_size, dtype=np.int)))
self.dis.apply(xavier_weights_init)
self.gen.apply(xavier_weights_init)
# 鉴别器的前向传播。其实也就是3步:
def dis_update(self, images, noise):
self.dis.zero_grad()
# 1.真实图片输入进D的输出,与真实标签1求loss。
true_outputs = self.dis(images)
true_loss = nn.functional.cross_entropy(true_outputs, self.true_labels)
_, true_predicts = torch.max(true_outputs.data, 1)
true_acc = (true_predicts == 1).sum()/(1.0*true_predicts.size(0))
# 2.经过G生成的虚假图片输入进D的输出,与真实标签0求loss。
fake_images = self.gen(noise)
fake_outputs = self.dis(fake_images)
fake_loss = nn.functional.cross_entropy(fake_outputs, self.fake_labels)
_, fake_predicts = torch.max(fake_outputs.data, 1)
fake_acc = (fake_predicts == 0).sum() / (1.0 * fake_predicts.size(0))
# 3.鉴别器的总loss进行反向传播,目的是让鉴别器能够理想地鉴别出真与假。
d_loss = true_loss + fake_loss
d_loss.backward()
self.dis_opt.step()
return 0.5 * (true_acc + fake_acc)
# 鉴别器的前向传播。其实也就是2步:
def gen_update(self, noise):
self.gen.zero_grad()
# 1.将噪声输入进G,生成与真实图片维度一致的虚假图片。
fake_images = self.gen(noise)
# 2.将虚假图片输入进D,但是,这里的输出与真实标签1进行求loss。
# 目的就是产生“生成与对抗”的过程。
fake_outputs = self.dis(fake_images)
fake_loss = nn.functional.cross_entropy(fake_outputs, self.true_labels)
fake_loss.backward()
self.gen_opt.step()
return fake_images
3.4 train.py
提供一个训练GAN的数字生成脚本代码。
训练完成,可以:
1.使用生成器进行图像生成等“生成式”任务;
2.使用鉴别器进行图像分类等“鉴别”任务。
# 实例化手写数字识别数据集
train_dataset = dsets.MNIST(root='../data',
train=True,
transform=transforms.ToTensor(),
download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=config.batch_size,
shuffle=True)
# 实例化GAN网络模型
trainer = MNISTGanTrainer(config.batch_size, config.latent_dims)
trainer.cuda()
# 开始训练GAN
iterations = 0
while iterations < config.max_iter:
for it, (images, labels) in enumerate(train_loader):
if images.size(0) != config.batch_size:
continue
images = Variable(images.cuda())
noise = Variable(torch.randn(config.batch_size, config.latent_dims)).cuda()
accuracy = trainer.dis_update(images, noise)
noise = Variable(torch.randn(config.batch_size, config.latent_dims)).cuda()
fake_images = trainer.gen_update(noise)
# 进行权重保存,并保存G的生成结果。
if iterations % config.snapshot_iter == 0 and iterations > 0:
dirname = os.path.dirname(config.snapshot_prefix)
if not os.path.isdir(dirname):
os.mkdir(dirname)
img_filename = '%s_gen_%08d.jpg' % (config.snapshot_prefix, iterations)
torchvision.utils.save_image(config.scale*(fake_images.data-config.bias), img_filename)
gen_filename = '%s_gen_%08d.pkl' % (config.snapshot_prefix, iterations)
dis_filename = '%s_dis_%08d.pkl' % (config.snapshot_prefix, iterations)
print("Save generator to %s" % gen_filename)
print("Save discriminator to %s" % dis_filename)
torch.save(trainer.gen.state_dict(), gen_filename)
torch.save(trainer.dis.state_dict(), dis_filename)
if iterations >= config.max_iter:
break
iterations += 1