Fun with Deep Learning | 16 The Astounding WGAN



Building on DCGAN, this article introduces the theory and implementation of WGAN, and puts it into practice on the LFW and CelebA face datasets.

Problems

GANs have long faced the following problems and challenges:

  • Training is difficult: the model architecture must be designed carefully, and the training of G and D must be delicately balanced
  • The loss functions of G and D give no indication of training progress; there is no meaningful metric that correlates with the quality of the generated images
  • Mode collapse: the generated images may look realistic, but they lack diversity

Theory

Compared with the standard GAN, WGAN makes only three simple changes:

  • Remove the sigmoid activation from the last layer of D
  • Drop the log from the losses of G and D, so sigmoid_cross_entropy_with_logits is no longer used
  • Constrain D to be Lipschitz continuous: by weight clipping in the original WGAN, or by a gradient penalty (WGAN-GP), as in the code below

In the original GAN, the loss function of G is

$$\mathbb{E}_{x \sim P_g}[\log(1 - D(x))]$$

The consequence is that if D is trained too well, this loss saturates and G cannot obtain useful gradients.

But if D is not trained well enough, G likewise cannot obtain useful gradients.

It is like a policeman and a thief: if the policeman is too strong, he simply takes the thief out; but if he is too weak, he cannot force the thief to get any better.

This loss function therefore makes GAN training particularly unstable, requiring the training of G and D to be carefully balanced.

The authors of GAN proposed an alternative version of G's loss function, the so-called -logD trick:

$$\mathbb{E}_{x \sim P_g}[-\log D(x)]$$

When D is optimal, minimizing this loss is equivalent to minimizing

$$KL(P_g \| P_r) - 2\, JS(P_r \| P_g)$$

where the first term is the KL divergence (Kullback–Leibler divergence) and the second is the JS divergence (Jensen–Shannon divergence).

Both measure the distance between two distributions: the smaller the value, the more similar the two distributions.
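For reference, for two distributions $P$ and $Q$ they are defined as

$$KL(P \| Q) = \mathbb{E}_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right]$$

$$JS(P \| Q) = \frac{1}{2} KL\left(P \,\middle\|\, \frac{P + Q}{2}\right) + \frac{1}{2} KL\left(Q \,\middle\|\, \frac{P + Q}{2}\right)$$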

This loss therefore tries to decrease the KL divergence while simultaneously increasing the JS divergence, pulling the two distributions together with one hand while pushing them apart with the other, which makes training unstable.

Beyond that, the asymmetry of the KL divergence leads to very different penalties in the following two cases:

  • G generates unrealistic images, i.e. it lacks accuracy: the penalty is high
  • G generates images close to real ones it already handles well, i.e. it lacks diversity: the penalty is low

As a result, G prefers to produce safe, similar-looking images rather than risk generating novel ones it is less sure about; this is the mode collapse problem.
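Both cases can be read directly off the integrand of $KL(P_g \| P_r)$:

$$P_g(x) \to 1,\; P_r(x) \to 0: \quad P_g(x) \log \frac{P_g(x)}{P_r(x)} \to +\infty \quad \text{(unrealistic sample, huge penalty)}$$

$$P_g(x) \to 0,\; P_r(x) \to 1: \quad P_g(x) \log \frac{P_g(x)}{P_r(x)} \to 0 \quad \text{(missed mode, negligible penalty)}$$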

The three changes made by WGAN solve the problems of difficult, unstable GAN training and mode collapse; moreover, the smaller G's loss, the higher the quality of the generated images.

The WGAN training procedure is shown below. The gradient penalty forces D to satisfy the 1-Lipschitz continuity condition; for the full derivation and details, see the related papers.

[Figure: the WGAN training algorithm, from the paper]
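Concretely, the losses implemented in the code below take the WGAN-GP form

$$L_D = \mathbb{E}_{x \sim P_g}[D(x)] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda\, \mathbb{E}_{\hat{x}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big]$$

$$L_G = -\mathbb{E}_{x \sim P_g}[D(x)]$$

where $\hat{x}$ is sampled uniformly along straight lines between pairs of real and generated images, and $\lambda$ corresponds to LAMBDA in the code.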

Some experimental results from the paper are shown below: WGAN takes longer to train, but converges much more stably.

[Figure: convergence comparison from the WGAN paper]

More importantly, WGAN provides a more robust GAN framework. In DCGAN, removing Batch Normalization from G causes training to collapse, whereas WGAN has no such restriction.

Implementing WGAN with a deep convolutional architecture gives results comparable to DCGAN, but within the WGAN framework, deeper and more complex networks can be used for G and D, for example ResNet ( arxiv.org/abs/1512.03… ), to achieve better generation quality.
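As a rough illustration of that direction (my sketch, not code from this article), a residual upsampling block for G in the same TensorFlow 1.x style might look like this; the kernel sizes and the 1x1 transposed-convolution shortcut are assumptions:

def residual_block_up(x, filters, is_training, momentum=0.9):
    # shortcut branch: upsample spatially and match the channel count
    shortcut = tf.layers.conv2d_transpose(x, kernel_size=1, filters=filters, strides=2, padding='same')
    # main branch: BN -> ReLU -> transposed conv -> BN -> ReLU -> conv
    h = tf.nn.relu(tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum))
    h = tf.layers.conv2d_transpose(h, kernel_size=3, filters=filters, strides=2, padding='same')
    h = tf.nn.relu(tf.contrib.layers.batch_norm(h, is_training=is_training, decay=momentum))
    h = tf.layers.conv2d(h, kernel_size=3, filters=filters, strides=1, padding='same')
    return h + shortcut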

Data

We use the same two face datasets as before:

  • LFW: vis-www.cs.umass.edu/lfw/ , Labeled Faces in the Wild, more than 13,000 images of 1,680 people
  • CelebA: mmlab.ie.cuhk.edu.hk/projects/Ce… , CelebFaces Attributes Dataset, more than 200,000 images of 10,177 people, where every image is also annotated with the locations of 5 facial landmarks and 40 binary attributes such as glasses, hat, and beard

Implementation

Load the libraries

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline
from imageio import imread, imsave, mimsave
import cv2
import glob
from tqdm import tqdm

Select the dataset

dataset = 'lfw_new_imgs' # LFW
# dataset = 'celeba' # CelebA
images = glob.glob(os.path.join(dataset, '*.*')) 
print(len(images))
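glob expects a single flat folder of image files. If your copy of LFW is still organized as one subfolder per person, a small preprocessing sketch like the following flattens it (the folder names lfw and lfw_new_imgs here are assumptions based on the dataset variable above):

import os
import glob
import shutil

# copy every image out of lfw/<person>/ into one flat folder
os.makedirs('lfw_new_imgs', exist_ok=True)
for path in glob.glob(os.path.join('lfw', '*', '*.jpg')):
    shutil.copy(path, os.path.join('lfw_new_imgs', os.path.basename(path)))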

Define some constants, the network inputs, and a helper function

batch_size = 100
z_dim = 100
WIDTH = 64
HEIGHT = 64
LAMBDA = 10 # gradient penalty coefficient
DIS_ITERS = 3 # 5 # number of D updates per G update

OUTPUT_DIR = 'samples_' + dataset
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

X = tf.placeholder(dtype=tf.float32, shape=[batch_size, HEIGHT, WIDTH, 3], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)

The discriminator. Note that Batch Normalization must be removed here: it introduces correlations between samples within a batch, which corrupts the per-sample gradient penalty computation.

def discriminator(image, reuse=None, is_training=is_training):
    with tf.variable_scope('discriminator', reuse=reuse):
        h0 = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))
        
        h1 = lrelu(tf.layers.conv2d(h0, kernel_size=5, filters=128, strides=2, padding='same'))
        
        h2 = lrelu(tf.layers.conv2d(h1, kernel_size=5, filters=256, strides=2, padding='same'))
        
        h3 = lrelu(tf.layers.conv2d(h2, kernel_size=5, filters=512, strides=2, padding='same'))
        
        h4 = tf.contrib.layers.flatten(h3)
        h4 = tf.layers.dense(h4, units=1)
        return h4
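If some normalization is still wanted in D, the WGAN-GP paper suggests Layer Normalization, which normalizes each sample independently and therefore does not couple samples within a batch. A minimal sketch of what the first layer would look like with this change:

# per-sample Layer Normalization keeps the gradient penalty valid
h0 = tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same')
h0 = lrelu(tf.contrib.layers.layer_norm(h0))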

The generator

def generator(z, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('generator', reuse=None):
        d = 4
        h0 = tf.layers.dense(z, units=d * d * 512)
        h0 = tf.reshape(h0, shape=[-1, d, d, 512])
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))
        
        h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='same')
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))
        
        h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='same')
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))
        
        h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='same')
        h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))
        
        h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=3, strides=2, padding='same', activation=tf.nn.tanh, name='g')
        return h4

Loss functions

g = generator(noise)
d_real = discriminator(X)
d_fake = discriminator(g, reuse=True)

# WGAN losses: D maximizes E[D(real)] - E[D(fake)], G maximizes E[D(fake)]
loss_d_real = -tf.reduce_mean(d_real)
loss_d_fake = tf.reduce_mean(d_fake)
loss_g = -tf.reduce_mean(d_fake)
loss_d = loss_d_real + loss_d_fake

# gradient penalty: evaluate the gradient of D at random interpolations
# between real and generated images and push its norm towards 1
alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
interpolates = alpha * X + (1 - alpha) * g
grad = tf.gradients(discriminator(interpolates, reuse=True), [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]))  # per-sample gradient norm over all non-batch axes
gp = tf.reduce_mean((slopes - 1.) ** 2)
loss_d += LAMBDA * gp

vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

Optimizers

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)
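The learning rate and beta1 above follow the DCGAN convention. For reference, the WGAN-GP paper itself trains with Adam(alpha=1e-4, beta1=0, beta2=0.9); swapping those in would look like the following (an alternative setting, not the one used for the results in this article):

# alternative hyperparameters from the WGAN-GP paper (Gulrajani et al., 2017)
optimizer_d = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0., beta2=0.9).minimize(loss_d, var_list=vars_d)
optimizer_g = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0., beta2=0.9).minimize(loss_g, var_list=vars_g)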

Function for reading an image

def read_image(path, height, width):
    image = imread(path)
    h = image.shape[0]
    w = image.shape[1]
    
    # center-crop the longer side so the image becomes square
    if h > w:
        image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
    else:
        image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]    
    
    # resize to the target size and scale pixel values to [0, 1]
    image = cv2.resize(image, (width, height))
    return image / 255.

Function for stitching images into a montage

def montage(images):    
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

Function for sampling a random batch

def get_random_batch(nums):
    img_index = np.arange(len(images))
    np.random.shuffle(img_index)
    img_index = img_index[:nums]
    batch = np.array([read_image(images[i], HEIGHT, WIDTH) for i in img_index])
    batch = (batch - 0.5) * 2 # scale to [-1, 1] to match the tanh output of G
    
    return batch

Train the model

sess = tf.Session()
sess.run(tf.global_variables_initializer())
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}

for i in tqdm(range(60000)):
    for j in range(DIS_ITERS): # several critic updates per generator update
        n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
        batch = get_random_batch(batch_size)
        _, d_ls = sess.run([optimizer_d, loss_d], feed_dict={X: batch, noise: n, is_training: True})
    
    _, g_ls = sess.run([optimizer_g, loss_g], feed_dict={X: batch, noise: n, is_training: True})
    
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)
    
    if i % 500 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2
        imgs = [img[:, :, :] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs)
        imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i), gen_imgs)
        plt.show()
        samples.append(gen_imgs)

plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig(os.path.join(OUTPUT_DIR, 'Loss.png'))
plt.show()
mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=10)
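Because the negative critic loss (before the penalty term is added) estimates the Wasserstein distance between the real and generated distributions, it provides exactly the quality-correlated training signal that vanilla GANs lack. A small monitoring sketch (my addition, not in the original script) that could go inside the logging branch:

# estimate of E[D(real)] - E[D(fake)]; should trend downward as G improves
ld_real, ld_fake = sess.run([loss_d_real, loss_d_fake],
                            feed_dict={X: batch, noise: n, is_training: False})
w_estimate = -(ld_real + ld_fake)
print('Wasserstein estimate:', w_estimate)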

The faces generated on LFW are shown below; compared with DCGAN, training is noticeably more stable.

[Figure: faces generated on LFW]

The faces generated on CelebA are shown below.

[Figure: faces generated on CelebA]

Save the model for later use

saver = tf.train.Saver()
saver.save(sess, os.path.join(OUTPUT_DIR, 'wgan_' + dataset), global_step=60000)

Generate faces with the saved model on a standalone machine

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

batch_size = 100
z_dim = 100
# dataset = 'lfw_new_imgs'
dataset = 'celeba'

def montage(images):    
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.import_meta_graph(os.path.join('samples_' + dataset, 'wgan_' + dataset + '-60000.meta'))
saver.restore(sess, tf.train.latest_checkpoint('samples_' + dataset))
graph = tf.get_default_graph()
g = graph.get_tensor_by_name('generator/g/Tanh:0')
noise = graph.get_tensor_by_name('noise:0')
is_training = graph.get_tensor_by_name('is_training:0')

n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, :] for img in gen_imgs]
gen_imgs = montage(imgs)
gen_imgs = np.clip(gen_imgs, 0, 1)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(gen_imgs)
plt.show()
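As a small extension (not in the original article), interpolating between two random latent vectors gives a quick visual check of how smooth the learned latent space is; z0, z1 and the linear schedule below are my own choices:

# walk a straight line between two latent codes and render the frames
z0 = np.random.uniform(-1.0, 1.0, [z_dim])
z1 = np.random.uniform(-1.0, 1.0, [z_dim])
alphas = np.linspace(0.0, 1.0, batch_size)
n = np.array([(1 - a) * z0 + a * z1 for a in alphas]).astype(np.float32)

gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(np.clip(montage(list(gen_imgs)), 0, 1))
plt.show()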
