Generative Adversarial Networks

Generative Adversarial Networks, or GANs for short, are a type of neural network that can be used to generate data rather than classify it. Although slightly disturbing, the following site provides an impressive example of what GANs can do.

A generative adversarial network is composed of two parts: a generator, which learns to produce plausible data, and a discriminator, which learns to distinguish the generator’s fake data from real data. The discriminator penalizes the generator whenever it detects fake data.

[Figure: the GAN architecture, with the generator producing fake samples and the discriminator classifying real versus fake]

The training phases of the discriminator and the generator are kept separate. In other words, the generator’s weights remain fixed while it produces examples for the discriminator to train on, and vice versa when it is the generator’s turn to train. Typically, we alternate between training the discriminator and the generator for one or more epochs.
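
In Keras, this freezing is typically handled with the trainable flag, which is captured when a model is compiled. Here is a minimal toy sketch of the pattern (hypothetical one-layer models, not the article’s networks):

from keras.layers import Dense, Input
from keras.models import Model, Sequential

# Toy models purely for illustration.
discriminator = Sequential([Dense(1, activation='sigmoid', input_dim=2)])
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

generator = Sequential([Dense(2, input_dim=3)])

# Freeze the discriminator before compiling the combined model. Keras
# captures the trainable flag at compile time, so the standalone
# discriminator (compiled above) still learns when trained directly, while
# the copy inside the combined model stays fixed during generator updates.
discriminator.trainable = False
z = Input(shape=(3,))
combined = Model(z, discriminator(generator(z)))
combined.compile(loss='binary_crossentropy', optimizer='adam')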

Training the discriminator is comparable to training any other neural network: it classifies both real samples and fake data from the generator, its loss function penalizes it for misclassifying a real instance as fake or a fake instance as real, and its weights are updated via backpropagation.

Similarly, the generator produces samples that the discriminator then classifies as real or fake. The results are fed into a loss function that penalizes the generator for failing to fool the discriminator, and backpropagation is used to update the generator’s weights.
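
Both penalties are two sides of one objective. In the original GAN formulation (Goodfellow et al., 2014), the discriminator $D$ and generator $G$ play the minimax game

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))],$$

where $z$ is the noise vector fed to the generator. The discriminator tries to maximize this value while the generator tries to minimize it; in practice (and in the code below), the generator instead maximizes $\log D(G(z))$, which gives stronger gradients early in training.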

As the generator improves with training, the discriminator’s performance gets worse, because it becomes harder to distinguish real from fake. If the generator succeeds perfectly, the discriminator is left with 50% accuracy (no better than random chance). The latter poses a real problem for the convergence of the GAN as a whole: if training continues past the point where the discriminator is giving completely random feedback, the generator starts to train on junk feedback, and its own quality may degrade.
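
This 50% figure falls out of the theory: for a fixed generator, the optimal discriminator is

$$D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)},$$

where $p_g$ is the distribution of the generator’s samples. Once the generator matches the data distribution, so that $p_g = p_{\text{data}}$, we get $D^*(x) = 1/2$ everywhere.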

Python Code

Let’s take a look at how we could go about implementing a generative adversarial network in Python. To begin, we import the following libraries.

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os  # used to create the output directory for sample images

We’ll be using the MNIST dataset which contains 28 by 28 images of handwritten digits. We create a class called GAN with the following parameters.

class GAN():
    def __init__(self):
        self.image_rows = 28
        self.image_cols = 28
        self.channels = 1
        self.image_shape = (self.image_rows, self.image_cols, self.channels)
        self.input_dim = 100
        optimizer = Adam(0.0002, 0.5)
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
        self.generator = self.build_generator()
        # Combined model: noise -> generator -> (frozen) discriminator.
        # Keras captures the trainable flag at compile time, so the
        # standalone discriminator compiled above still trains normally.
        _in = Input(shape=(self.input_dim,))
        image = self.generator(_in)
        self.discriminator.trainable = False
        validity = self.discriminator(image)
        self.combined = Model(_in, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

We define the generator network.

    def build_generator(self):
        model = Sequential()
        model.add(Dense(256, input_dim=self.input_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.image_shape), activation='tanh'))
        model.add(Reshape(self.image_shape))
        model.summary()
        noise = Input(shape=(self.input_dim,))
        image = model(noise)
        return Model(noise, image)

We define the discriminator network.

    def build_discriminator(self):
        model = Sequential()
        model.add(Flatten(input_shape=self.image_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        image = Input(shape=self.image_shape)
        validity = model(image)
        return Model(image, validity)
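
As a quick sanity check (hypothetical usage, once the methods above are assembled into the GAN class and an instance has been created as at the end of the article), the generator maps a 100-dimensional noise vector to a 28×28×1 image with tanh outputs in [-1, 1], and the discriminator maps that image back to a single probability:

gan = GAN()
noise = np.random.normal(0, 1, (1, gan.input_dim))
fake_image = gan.generator.predict(noise)
print(fake_image.shape)                        # (1, 28, 28, 1)
print(gan.discriminator.predict(fake_image))   # a probability in (0, 1)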

Next, we define a function to train the model. We begin by normalizing the pixels of each image so that they range from negative one to positive one, matching the generator’s tanh output. We use NumPy to create random noise, which the generator turns into fake images. The discriminator is trained on the generated images in addition to samples that are known to be real. Lastly, the generator is trained through the combined model, where its loss penalizes it whenever the discriminator correctly labels its output as fake.

    def train(self, epochs, batch_size=128, sample_interval=50):
        (X_train, _), (_, _) = mnist.load_data()
        # Rescale pixel values from [0, 255] to [-1, 1] to match tanh output.
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)
        # Ground-truth labels: 1 for real images, 0 for generated ones.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # Train the discriminator on a random batch of real images
            # and an equally sized batch of generated images.
            index = np.random.randint(0, X_train.shape[0], batch_size)
            images = X_train[index]
            noise = np.random.normal(0, 1, (batch_size, self.input_dim))
            gen_images = self.generator.predict(noise)
            d_loss_real = self.discriminator.train_on_batch(images, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_images, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # Train the generator (through the combined model, with the
            # discriminator frozen) to make its output be labeled as real.
            noise = np.random.normal(0, 1, (batch_size, self.input_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print("%d [Discriminator loss: %f, acc.: %.2f%%] [Generator loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

We periodically save the output in order to evaluate the model’s performance throughout the training process.

    def sample_images(self, epoch):
        os.makedirs("images", exist_ok=True)  # make sure the output directory exists
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.input_dim))
        gen_images = self.generator.predict(noise)
        # Rescale from [-1, 1] back to [0, 1] for display.
        gen_images = 0.5 * gen_images + 0.5
        fig, axs = plt.subplots(r, c)
        count = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_images[count, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                count += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()

Finally, we create an instance of the GAN class and train the model.

gan = GAN()
gan.train(epochs=100000, batch_size=128, sample_interval=10000)
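
Note that each “epoch” here is really a single batch update, so 100,000 of them can take a while on a CPU. For a quick smoke test, smaller settings (hypothetical values; image quality will be poor) finish much faster:

gan = GAN()
gan.train(epochs=2000, batch_size=64, sample_interval=500)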

Initially, the output of the GAN is just random noise.

[Figure: generator output early in training, showing random noise]

However, by the end, the output begins to look like handwritten digits.

[Figure: generator output late in training, showing recognizable handwritten digits]

