Generative AI Series - What is a Generative Adversarial Network (GAN)?

1 How to understand GAN intuitively?

Since the generative adversarial network (GAN) was proposed by Ian Goodfellow in 2014, it has set off a wave of research. A GAN consists of a generator and a discriminator. The generator is responsible for producing samples, and the discriminator for judging whether the generator's samples are real. The generator tries to confuse the discriminator as much as possible, while the discriminator tries to distinguish generated samples from real ones as well as it can.

In the original GAN paper [1], the author compares the generator to a counterfeiter printing fake banknotes and the discriminator to the police. The counterfeiter works hard to make the notes look real, and the police keep improving their ability to spot fakes. The two compete, and over time both become stronger. By analogy, in an image generation task the generator keeps producing fake images that are as realistic as possible, the discriminator judges whether an image is real or generated, and the two are continuously optimized through this game. In the end, the generator produces images that the discriminator cannot tell apart from real ones.

2 Formal expression of GAN

The example above only sketches the idea of GAN; below is a more formal and specific definition. Usually, both the generator and the discriminator are implemented as neural networks. The informal definition above can then be expressed with the following model:

[Figure: GAN model — the generator G takes noise z as input and produces fake samples, which are mixed with real samples and fed to the discriminator D]

The left side of the model is the generator $G$, whose input is $z$. For the original GAN, $z$ is noise randomly sampled from a Gaussian distribution; passing the noise $z$ through the generator yields generated (fake) samples.

The generated fake samples and the real samples are pooled together, randomly selected, and fed to the discriminator $D$, which distinguishes whether each input sample is a generated fake sample or a real one. The whole process is simple and clear: the "adversarial" in "generative adversarial network" is mainly reflected in this contest between the generator and the discriminator.

3 What is the objective function of GAN?

To learn the parameters of the above networks, we first need an objective function. The GAN objective is defined as follows:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$$
This objective function can be understood in two parts:

Part 1: the discriminator is optimized via $\max_D V(D,G)$, where $V(D,G)$ is the discriminator's objective. Its first term, $\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)]$, is the expected log-probability that the discriminator assigns to samples from the real data distribution being real. For real samples, this predicted probability should of course be as close to 1 as possible, so we want to maximize this term. The second term, $\mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$, concerns samples drawn from the noise distribution $p_z(z)$, passed through the generator, and fed to the discriminator: it is the expected log of one minus the predicted probability. The larger this term, the closer $D(G(z))$ is to 0, and hence the better the discriminator.

Part 2: the generator is optimized via $\min_G(\max_D V(D,G))$. Note that the generator's goal is not $\min_G V(D,G)$: the generator does not simply minimize the discriminator's objective, but rather minimizes the maximum of the discriminator's objective. That maximum corresponds to the JS divergence between the real data distribution and the generated data distribution (the details can be found in the derivation in the appendix). JS divergence measures how similar two distributions are: the closer the two distributions, the smaller the JS divergence.
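To make the min-max objective concrete, here is a minimal PyTorch-style training sketch. The toy architectures, layer sizes, and learning rates are illustrative assumptions, not part of the original paper; only the two loss terms mirror the objective above.

```python
import torch
import torch.nn as nn

# Toy networks; architectures, sizes, and learning rates are illustrative assumptions.
G = nn.Sequential(nn.Linear(100, 128), nn.ReLU(), nn.Linear(128, 784), nn.Tanh())
D = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2), nn.Linear(128, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
eps = 1e-8  # numerical stability inside the logs

def train_step(real):                      # real: (N, 784) batch of real samples
    z = torch.randn(real.size(0), 100)     # z sampled from a Gaussian, as in the text

    # Discriminator step: maximize E[log D(x)] + E[log(1 - D(G(z)))]
    fake = G(z).detach()                   # detach so this step does not update G
    d_loss = -(torch.log(D(real) + eps).mean() +
               torch.log(1 - D(fake) + eps).mean())
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()

    # Generator step: minimize E[log(1 - D(G(z)))]
    g_loss = torch.log(1 - D(G(z)) + eps).mean()
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
    return d_loss.item(), g_loss.item()
```

In practice the generator is often trained with the "non-saturating" loss $-\log D(G(z))$ instead, which gives stronger gradients early in training.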

4 What is the difference between the GAN objective function and cross entropy?

The discriminator objective, written in discrete form, is:

$$V(D,G) = -\frac{1}{m}\sum_{i=1}^{m}\log D(x^i) - \frac{1}{m}\sum_{i=1}^{m}\log\big(1 - D(\tilde{x}^i)\big)$$

It can be seen that this objective function is consistent with cross entropy: the discriminator's goal is to minimize the cross-entropy loss, while the generator's goal is to minimize the JS divergence between the generated data distribution and the real data distribution.
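The correspondence can be verified numerically: with labels 1 for real and 0 for fake samples, binary cross entropy reproduces the discrete loss above. A small sanity-check sketch, where random probabilities stand in for the discriminator's outputs:

```python
import torch
import torch.nn.functional as F

d_real = torch.rand(8)          # stand-in for D(x^i): probabilities on real samples
d_fake = torch.rand(8)          # stand-in for D(x~^i): probabilities on fake samples

# Discrete discriminator loss: -1/m * sum log D(x) - 1/m * sum log(1 - D(x~))
manual = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())

# Same quantity via binary cross entropy with labels 1 (real) and 0 (fake)
bce = (F.binary_cross_entropy(d_real, torch.ones(8))
       + F.binary_cross_entropy(d_fake, torch.zeros(8)))

print(torch.allclose(manual, bce))  # True (up to floating point)
```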


[1]: Goodfellow, Ian, et al. “Generative adversarial nets.” Advances in neural information processing systems. 2014.

5 Why doesn't the GAN loss decrease?

Many GAN beginners wonder why, in practice, the GAN loss never seems to go down, and when a GAN converges. In fact, for a well-trained GAN the loss cannot simply decrease; to tell whether a GAN is trained well, one can mostly only inspect by eye whether the generated images look good. Many researchers have studied this lack of a good convergence metric; WGAN, mentioned later, proposes a new loss design that largely solves the problem of judging convergence. Let us analyze why the GAN loss cannot keep decreasing.
For the discriminator, the GAN loss is as follows:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_z(z)}[\log(1 - D(G(z)))]$$
From $\min_G \max_D V(D,G)$ it can be seen that the goals of the generator and the discriminator are opposite: the two networks are pitted against each other, and as one gets stronger the other gets weaker. So the loss cannot simply drop to a converged state.

  • For the generator, if its loss drops rapidly, the discriminator is likely too weak, so the generator "fools" it too easily.
  • For the discriminator, if its loss drops rapidly, the discriminator is very strong; a strong discriminator in turn suggests that the generator's images are not realistic enough and are easy to tell apart, which is what makes the loss fall so quickly.

That is, for both the discriminator and the generator, the level of the loss does not reflect the quality of the generator. For a well-trained GAN, the GAN loss typically fluctuates.

This may sound a bit discouraging: it seems the only way to judge convergence is to look at the quality of the generated images. In fact, WGAN, discussed later, proposes a new way of measuring the loss that lets us judge model convergence by concrete means.
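As a preview, here is a minimal sketch of the WGAN critic step with the weight clipping of the original WGAN. The critic outputs unbounded scores rather than probabilities; its architecture, the noise dimension, and the clipping constant are assumptions. The negative of the critic loss estimates the Wasserstein distance, which does fall meaningfully as sample quality improves, giving a usable convergence signal.

```python
import torch

def wgan_critic_step(critic, G, real, opt_c, z_dim=100, clip=0.01):
    z = torch.randn(real.size(0), z_dim)
    fake = G(z).detach()
    # The critic maximizes E[f(x)] - E[f(G(z))]; we minimize the negative.
    loss = -(critic(real).mean() - critic(fake).mean())
    opt_c.zero_grad(); loss.backward(); opt_c.step()
    for p in critic.parameters():          # weight clipping keeps f roughly Lipschitz
        p.data.clamp_(-clip, clip)
    return -loss.item()                    # estimate of the Wasserstein distance
```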

6 What is the difference between a generative model and a discriminative model?

For machine learning models, we can divide models into two categories according to how they model the data: generative models and discriminative models. Suppose we want to train a model for cat-vs-dog classification. A discriminative model only needs to learn the differences between the two, for example that cats are smaller than dogs. A generative model is different: it must learn what cats look like and what dogs look like, and then classify a sample by comparing it against those appearances. Specifically:

  • Generative model: learn the joint probability distribution P(X,Y) from the data, and then obtain the conditional distribution P(Y|X) = P(X,Y)/P(X), which is used as the prediction model. This approach models how an output Y is generated from a given input X.

  • Discriminative model: directly learn the decision function Y = f(X) or the conditional probability distribution P(Y|X) from the data and use it as the prediction model. Discriminative methods care about what output Y should be predicted for a given input X.

The two kinds of model may not be very intuitive from the text alone, so consider the gender classification problem solved with each:

1) With a generative model: we can train a model to learn the joint relationship between a person's features X and gender Y. For example, suppose we have the following data:

| X (feature) \ Y (gender) | Y = 0 | Y = 1 |
| --- | --- | --- |
| X = 0 | 1/4 | 3/4 |
| X = 1 | 3/4 | 1/4 |

This data can be obtained by counting: for each feature value X = 0, 1, ..., count the probability of the category being Y = 0, 1. Once the joint probability distribution P(X, Y) has been estimated, a model can be learned, for example by fitting a two-dimensional Gaussian to the data, so that the joint distribution of X and Y is captured. At prediction time, given an input feature X, we infer the category through the conditional distribution obtained from Bayes' formula:

$$P(Y|X) = \frac{P(X,Y)}{P(X)} = \frac{P(X,Y)}{\sum_{y} P(X|Y=y)\,P(Y=y)}$$
2) With a discriminative model: we can train a model whose input is a person's features X (facial features, clothing style, hairstyle, and so on) and whose output is the probability of each gender. This probability follows a distribution over just two values, male or female; denote the distribution Y. The process learns the conditional probability distribution P(Y|X), i.e., the distribution of Y given the input features X.

Clearly, from the above analysis, the discriminative model seems much more convenient: a generative model usually needs a lot of data to learn the joint distribution of X and Y, while a discriminative model needs comparatively little, since it attends only to differences among the input features. However, because the generative approach uses more data to model the joint distribution, it also carries more information. For instance, given a sample (X, Y) whose joint probability P(X, Y) turns out to be very small, the sample can be considered anomalous, so a generative model can also be used for outlier detection.
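A tiny sketch of this generative route on the table above: estimate the joint distribution, derive P(Y|X) by Bayes' rule, and flag low-probability pairs as outliers. The normalization step and the threshold are illustrative assumptions:

```python
import numpy as np

# Table above: rows X in {0, 1}, columns Y in {0, 1}
P_xy = np.array([[1/4, 3/4],
                 [3/4, 1/4]])
# The table's rows each sum to 1; dividing by the total makes the cells a
# joint distribution under the assumption that X is uniform.
P_xy = P_xy / P_xy.sum()

P_x = P_xy.sum(axis=1, keepdims=True)   # marginal P(X)
P_y_given_x = P_xy / P_x                # Bayes: P(Y|X) = P(X,Y) / P(X)
print(P_y_given_x)                      # row x gives the class distribution for X = x

# Generative models also support outlier detection: flag pairs with tiny P(X, Y)
threshold = 0.05                        # arbitrary illustrative threshold
print(P_xy < threshold)
```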

7 What is mode collapse?

Mode collapse means that a large number of repeated samples appear, concentrated in a few modes, for example:
[Figure: example of mode collapse (img/ch7/model_collpsing.png)]

In the figure above, the blue five-pointed stars on the left represent the real sample space, and the yellow points are the generated samples: they lack diversity and contain many duplicates. On the right of the figure, for example, the faces in the red boxes appear repeatedly.

8 How to solve mode collapse?

Method 1: Improving the objective function

To avoid the mode-hopping problem caused by the max-min optimization above, Unrolled GAN modifies the generator loss. Specifically, when updating the generator, Unrolled GAN does not use the current discriminator's loss, but the loss after the discriminator has taken k further update steps; these k look-ahead steps are used only to compute the generator loss and do not actually update the discriminator's parameters. This lets the generator take into account how the discriminator will change over the next k steps, avoiding the mode collapse caused by jumping between modes (a simplified sketch appears after this paragraph). Note that this must be distinguished from iterating the generator k times and then iterating the discriminator once [8]. DRAGAN introduces the no-regret algorithm from game theory and reshapes its loss to relieve mode collapse [9]. EBGAN, mentioned earlier, adds a VAE-style reconstruction error to address mode collapse.
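A simplified sketch of the unrolled generator loss, under stated assumptions: the discriminator is copied, the copy takes k throwaway gradient steps, and the generator loss is evaluated against the result. The full method also backpropagates through these k updates, which this sketch omits for brevity; the optimizer choice and learning rate are assumptions.

```python
import copy
import torch

def unrolled_generator_loss(G, D, real, z, k=5, lr=1e-4):
    D_k = copy.deepcopy(D)                 # throwaway copy; the real D is untouched
    opt = torch.optim.SGD(D_k.parameters(), lr=lr)
    for _ in range(k):                     # k look-ahead discriminator steps
        fake = G(z).detach()
        d_loss = -(torch.log(D_k(real)).mean() +
                   torch.log(1 - D_k(fake)).mean())
        opt.zero_grad(); d_loss.backward(); opt.step()
    # Generator loss measured against the k-step-unrolled discriminator.
    return torch.log(1 - D_k(G(z))).mean()
```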

Method 2: Improving the network structure

Multi-agent diverse GAN (MAD-GAN) uses multiple generators together with a single discriminator to ensure diversity of the generated samples. The structure is as follows:

[Figure: MAD-GAN architecture (img/ch7/MAD_GAN.png)]

Compared with an ordinary GAN, there are several extra generators, and a regularization term is added to the loss: it uses cosine distance to penalize the samples produced by the different generators for being too similar to one another, as illustrated in the sketch below.
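As an illustration of such a diversity regularizer (not MAD-GAN's exact term, whose details differ), one can penalize pairwise cosine similarity between the batches produced by the different generators:

```python
import torch
import torch.nn.functional as F

def diversity_penalty(samples_per_generator):
    """samples_per_generator: list of (N, D) tensors, one batch per generator.
    Penalizes high pairwise cosine similarity so generators cover different modes."""
    penalty = 0.0
    gens = samples_per_generator
    for i in range(len(gens)):
        for j in range(i + 1, len(gens)):
            # Mean cosine similarity between corresponding flattened samples.
            penalty = penalty + F.cosine_similarity(gens[i], gens[j], dim=1).mean()
    return penalty  # add (scaled) to the generators' loss
```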

MRGAN adds an extra discriminator to penalize mode collapse in the generated samples. The structure is as follows:

[Figure: MRGAN architecture (img/ch7/MRGAN.png)]

An input sample $x$ is encoded by an encoder into a latent variable $E(x)$, and the generator then reconstructs the sample from this latent variable. Training involves three losses: $D_M$ and $R$ (the reconstruction error) guide the generation of realistic samples, while $D_D$ discriminates between the samples generated from $E(x)$ and those generated from $z$. Since both kinds are fake samples, this discriminator is mainly used to judge whether the generated samples are diverse, that is, whether mode collapse has occurred.

Method 3: Mini-batch Discrimination

Mini-batch discrimination adds a mini-batch layer at an intermediate layer of the discriminator that computes batch statistics based on L1 distance. The statistic captures how close each sample in a batch is to the other samples; the discriminator can use this information to spot samples that lack diversity, which in turn pushes the generator to produce diverse samples.
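A sketch of a mini-batch discrimination layer in the spirit of Salimans et al. (2016); the kernel sizes are illustrative assumptions:

```python
import torch
import torch.nn as nn

class MinibatchDiscrimination(nn.Module):
    """Sketch of mini-batch discrimination: in_features A, num_kernels B,
    kernel_dim C are illustrative sizes, not values from the original paper."""
    def __init__(self, in_features, num_kernels, kernel_dim):
        super().__init__()
        self.T = nn.Parameter(torch.randn(in_features, num_kernels * kernel_dim) * 0.1)
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim

    def forward(self, f):
        # f: (N, A) features from an intermediate discriminator layer.
        M = (f @ self.T).view(-1, self.num_kernels, self.kernel_dim)  # (N, B, C)
        diff = M.unsqueeze(0) - M.unsqueeze(1)     # (N, N, B, C) pairwise differences
        l1 = diff.abs().sum(dim=3)                 # (N, N, B) L1 distances per kernel
        o = torch.exp(-l1).sum(dim=1) - 1.0        # (N, B); -1 removes self-similarity
        # Concatenate the closeness statistic to the original features.
        return torch.cat([f, o], dim=1)            # (N, A + B)
```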
