Générer un modèle - GAN

1. Modèle génératif et modèle discriminant

Modèle génératif et modèle discriminant
Nos chapitres précédents ont principalement introduit le modèle discriminant dans l'apprentissage automatique. La forme de ce modèle est principalement de déduire certaines propriétés de l'image basée sur l'image originale, comme déduire le nom du numéro basé sur l'image numérique , et basé sur l'image de la scène naturelle.Déduire les limites des objets ;

Le modèle génératif est tout le contraire : l'entrée est généralement donnée comme la propriété de l'image et la sortie est l'image correspondant à la propriété. Ce type de modèle génératif équivaut à construire la distribution d'images, donc en utilisant ce type de modèle, nous pouvons compléter le travail de génération automatique d'images (échantillonnage), de complétion des informations sur les images, etc.

Il y a eu de nombreux modèles génératifs avant l'apprentissage profond, mais comme les modèles génératifs sont difficiles à décrire et à modéliser, les chercheurs ont rencontré de nombreux défis et l'émergence de l'apprentissage profond les a aidés à résoudre de nombreux problèmes.

Modèles génératifs basés sur des idées d'apprentissage profond - GAN et VAE, et variantes de modèles de GAN.

1.1 Génération du modèle

  • générer une image
  • génération de visages
  • génération de photos
  • Générer des personnages de dessins animés
  • conversion d'images
  • Conversion de texte en image
  • Conversion sémantique d'image en photo
  • Génération d'images de visage
  • Générer de nouvelles poses humaines
  • Conversion de photos en Emoji
  • retouche photo
  • mélange d'images
  • super résolution
  • réparation d'image
  • changement de vêtements
  • prédiction vidéo
  • Génération d'objets 3D

insérer la description de l'image ici

insérer la description de l'image ici

insérer la description de l'image ici

insérer la description de l'image ici

DEUX PIEDS

VAE-Variational Autoencoder
Variational Autoencoder
Imaginez un tel réseau, l'entrée est un ensemble de vecteurs tous 1, la cible est une tête de chat, après plusieurs cycles d'entraînement. Il suffit de saisir ce vecteur avec des 1 pour obtenir le visage de ce chat.

En fait, cela est dû au fait que pendant le processus de formation, grâce à une formation continue, le réseau a enregistré les paramètres de cette image de chat.

insérer la description de l'image ici

En fait, le sens de ce travail est déjà visible : grâce à un réseau, un visage dans un espace de grande dimension est mappé sur un vecteur dans un espace de basse dimension.

Donc si, nous essayons d'utiliser plus d'images. Cette fois, nous utilisons des vecteurs uniques au lieu de tous. Nous utilisons [1, 0, 0, 0] pour les chats et [0, 1, 0, 0] pour les chiens. Même si cela convient, nous ne pouvons stocker que jusqu'à 4 images.

Ainsi, nous pouvons augmenter la longueur du vecteur et les paramètres du réseau, nous pouvons alors obtenir plus d’images.

Par exemple, si ce vecteur est défini comme quadridimensionnel et que quatre visages différents sont exprimés à l'aide d'une expression unique, alors ce réseau peut exprimer quatre visages. Saisissez différentes données, il produira différents visages.

insérer la description de l'image ici

Cependant, ces vecteurs sont rares. Pour résoudre ce problème, nous souhaitons utiliser des vecteurs de valeurs réelles au lieu de vecteurs de 0, 1. On peut considérer ce vecteur de valeur réelle comme une sorte d’encodage de l’image originale, ce qui conduit au concept d’encodage/décodage.

Par exemple, [3.3, 4.5, 2.1, 9.8] représente un chat et [3.4, 2.1, 6.7, 4.2] représente un chien.

Ce vecteur initial connu peut être notre variable latente.

Ce n'est pas un bon moyen d'initialiser aléatoirement certains vecteurs pour représenter l'encodage des images comme je l'ai fait ci-dessus. Nous espérons que l'ordinateur pourra encoder automatiquement pour nous. Dans le modèle d'encodeur automatique, nous ajoutons un encodeur, qui peut nous aider à encoder des images en vecteurs. Le décodeur peut alors restituer ces vecteurs en images.

insérer la description de l'image ici

Dans la figure ci-dessous, nous décrivons la forme finale du visage à travers six facteurs, et différentes valeurs de ces facteurs représentent différentes caractéristiques.

insérer la description de l'image ici

3. CEPENDANT

3.1 Réseau contradictoire de génération GAN

Qu'est-ce que le Generative Adversarial Network, GAN – Generative Adversarial Network,

  1. Le réseau de confrontation comporte un générateur (Generator) et un discriminateur (Discriminator) ;
  2. Le générateur génère des images à partir de bruit aléatoire. Puisque ces images sont toutes imaginées par le générateur, nous appelons cela Fake Image ;
  3. La photo fausse image générée par le générateur et l'image réelle dans l'ensemble de formation seront transmises au discriminateur, et le discriminateur jugera si elles sont réelles ou fausses.

Alors, comment former le réseau ? Quel genre d’objectif souhaitez-vous atteindre ?

  1. Nous espérons que les images générées par le générateur sont suffisamment réalistes pour tromper le discriminateur ;
  2. Nous espérons également que le discriminateur sera suffisamment « intelligent » pour faire la distinction entre les images réelles et les images générées ;
  3. Finalement, lors de l'entraînement, le générateur et le discriminateur atteignent un équilibre en « confrontation », et l'entraînement se termine.
  4. A ce moment, nous séparons le générateur, ce qui peut nous aider à « générer » l'image souhaitée.

insérer la description de l'image ici

Nous devons comprendre 2 problèmes lors de l'utilisation du GAN

  1. qu'avons-nous
    Par exemple, dans l'image ci-dessus, tout ce que nous avons, c'est l'ensemble de données d'échantillon de visage réel, c'est tout, et le point clé est que nous n'avons même pas les étiquettes de classe de l'ensemble de données de visage, c'est-à-dire que nous n'avons pas savoir quel visage correspond à Qui est-ce.
  2. Qu'allons-nous obtenir ?
    Quant à ce qu'il faut obtenir, différentes tâches obtiennent des choses différentes. Nous ne parlons que de l'objectif le plus primitif du GAN, c'est-à-dire que nous voulons simuler une image de visage en entrant un bruit. Cette image peut être si réaliste qu'elle peut être confondue avec le vrai.

Tout d'abord, le modèle discriminant est le réseau dans la moitié droite de la figure. Intuitivement, il s'agit d'une structure de réseau neuronal simple. L'entrée est une image et la sortie est une valeur de probabilité, qui est utilisée pour juger si elle est vrai ou faux (si la valeur de probabilité est supérieure à 0,5, c'est vrai. , moins de 0,5 est faux), vrai et faux ne sont que la probabilité définie par les gens.

Le second est le modèle de génération, qui peut également être considéré comme un modèle de réseau neuronal : l'entrée est un ensemble de nombres aléatoires Z et la sortie est une image au lieu d'une valeur.

Comme le montre la figure, il y aura deux ensembles de données, l’un est le véritable ensemble de données et l’autre est le faux ensemble de données.

Objectifs des GAN :

  1. Le but du réseau discriminant est de pouvoir distinguer si une image d’entrée provient d’un ensemble d’échantillons réel ou d’un faux ensemble d’échantillons. Si l'entrée est un échantillon réel, la sortie réseau sera proche de 1, si l'entrée est un faux échantillon, la sortie réseau sera proche de 0, ce qui permet d'atteindre l'objectif d'une bonne discrimination.
  2. Le but du réseau de génération : le réseau de génération est de fabriquer des échantillons, et son but est de rendre la capacité de fabrication d'échantillons aussi forte que possible et de rendre impossible au réseau discriminant de juger s'il s'agit d'un échantillon réel ou d'un faux. échantillon.

Les objectifs du réseau de génération et du réseau discriminant sont exactement opposés : l'un dit que je peux bien discriminer, et l'autre dit que je vous laisse mal discriminer.

C'est pour ça qu'on appelle ça une confrontation, ça s'appelle un jeu.

Alors, qui gagnera à la fin ?

Cela dépend du designer, que nous voulons gagner.

En tant que concepteurs, notre objectif est d'obtenir des échantillons qui ressemblent à des échantillons réels, nous espérons donc naturellement que les échantillons générés gagneront, c'est-à-dire que nous espérons que les échantillons générés sont réels et que la capacité du réseau discriminant n'est pas suffisante pour distinguer entre les vrais et les faux échantillons.

3.2 GAN-Générer la formation du réseau de confrontation

Formation itérative alternative seule
insérer la description de l'image ici

3.2.1 Formation sur le modèle discriminant :

En supposant que le modèle de réseau généré existe déjà (bien sûr, ce n'est peut-être pas le meilleur réseau généré), alors étant donné un tas de tableaux aléatoires, un tas de faux ensembles d'échantillons seront obtenus (car ce n'est pas le modèle généré final, le Le réseau généré peut maintenant être dans le réseau discriminant. L'inconvénient est que les échantillons générés ne sont pas très bons et peuvent être facilement identifiés par le réseau discriminant et dire que le produit est contrefait).

Supposons que nous ayons maintenant un faux ensemble d'échantillons et que le véritable ensemble d'échantillons ait toujours été là. Maintenant, nous définissons artificiellement les étiquettes des ensembles d'échantillons vrais et faux, car nous voulons que la sortie de l'ensemble d'échantillons réel soit égale à 1 autant que possible, et le faux ensemble d'échantillons est égal à 0. Évidemment, nous avons supposé ici que toutes les étiquettes de classe du véritable ensemble d'échantillons sont 1 et que toutes les étiquettes de classe du faux ensemble d'échantillons sont 0.

Nous avons donc maintenant de vrais échantillons et leurs étiquettes (tous deux 1), de faux échantillons et leurs étiquettes (tous deux 0)

De cette façon, en ce qui concerne le réseau discriminant, le problème devient alors un simple problème de classification binaire supervisée, qui peut être directement envoyé au modèle de réseau neuronal pour la formation.

3.2.2 Formation du réseau de production :

Pensez à notre objectif, générer des échantillons aussi réalistes que possible.
Alors, comment savoir si l’échantillon généré par le réseau génératif d’origine est vrai ?
Il est envoyé au réseau discriminant, donc lors de la formation du réseau génératif, nous devons combiner le réseau discriminant pour atteindre l'objectif de la formation.
Connectez tout à l'heure le réseau discriminant derrière le réseau générateur, afin que nous connaissions la vérité et le mensonge, et qu'il y ait des erreurs.

Par conséquent, la formation du réseau de génération est en fait la formation de la concaténation du réseau de génération-discrimination.

Pour les échantillons, nous devons définir les étiquettes des faux échantillons générés sur 1 , ce qui signifie que ces faux échantillons sont considérés comme de vrais échantillons lorsque le réseau est formé.

Alors pourquoi? Réfléchissons-y, est-il possible de confondre le discriminateur de cette manière, et également de faire en sorte que les faux échantillons générés se rapprochent progressivement des vrais échantillons.

Maintenant, pour la formation du réseau généré, nous avons un ensemble d'échantillons (uniquement un faux ensemble d'échantillons, pas de véritable ensemble d'échantillons) et une étiquette correspondante (tous 1).

Notez que lors de la formation de ce réseau concaténé, une opération très importante n'est pas de mettre à jour les paramètres du réseau discriminant, mais de transmettre l'erreur jusqu'au bout, puis de mettre à jour les paramètres du réseau généré après l'avoir transmise au réseau généré.

Après avoir terminé la formation sur le réseau de génération, nous pouvons générer de nouveaux faux échantillons pour le bruit Z précédent sur la base du réseau de nouvelle génération actuel.

Et les faux échantillons après la formation devraient être plus réels.

De cette façon, nous disposons d’un nouvel ensemble d’échantillons vrais et faux, de sorte que le processus ci-dessus puisse être répété.

Nous appelons ce processus la formation individuelle en alternance .

4. LeakyReLU

Lorsque la valeur d'entrée de Relu est négative, la sortie est toujours 0 et sa dérivée première est toujours 0, ce qui empêchera le neurone de mettre à jour les paramètres, c'est-à-dire que le neurone n'apprendra pas. Ce phénomène est appelé "mort". Neurone".

Afin de résoudre le défaut de la fonction Relu, une valeur de fuite (Leaky) est introduite dans le demi-intervalle négatif de la fonction Relu, elle est donc appelée fonction Leaky Relu. Autrement dit, ReLU n'a pas de dégradé dans la partie où la valeur est inférieure à zéro, et LeakyReLU donne un petit dégradé dans la partie où la valeur est inférieure à 0.
insérer la description de l'image ici

5. Exemple de code GAN

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("./images/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=2000, batch_size=32, sample_interval=200)

Je suppose que tu aimes

Origine blog.csdn.net/m0_63260018/article/details/132440111
conseillé
Classement