Summary: Besides the VAE, the Generative Adversarial Network (GAN) is another very popular unsupervised generative model. A GAN consists of two core networks, and training one is difficult: many details must be handled carefully before it generates high-quality images.
Besides the VAE, the Generative Adversarial Network (Generative Adversarial Nets, GAN) is another very popular unsupervised generative model.
A GAN consists of two core networks:
- Generator: denoted G. By learning from a large number of samples, it can produce convincingly fake samples, similar in spirit to the VAE
- Discriminator: denoted D. It receives both real samples and samples produced by G, and learns to tell the two apart
- G and D play an adversarial game: as training proceeds, G's generating ability and D's discriminating ability both improve until they converge
Training a GAN is notoriously difficult; many details must be handled carefully before it produces high-quality images.
Here we use MNIST as an example and implement a GAN with TensorFlow. Because the model relies on deep convolutional networks, it is also known as a DCGAN (Deep Convolutional GAN).
Principle
Given noise z drawn from a random distribution, the generator produces a fake sample through a complex mapping function:
\hat{x}=G(z;\theta_g)
The discriminator uses another complex mapping function: given a real or generated sample, it outputs a value between 0 and 1, where larger values indicate the sample is more likely to be real:
s=D(x;\theta_d)
The overall objective function is:
\min_{G}\max_{D} V(D,G)=\mathbb{E}_{x\sim p_{data}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1-D(G(z)))]
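Early in training, when D confidently rejects G's samples, the \log(1-D(G(z))) term saturates and gives G very weak gradients. The code below therefore uses the non-saturating variant suggested in the original GAN paper: instead of minimizing \log(1-D(G(z))), the generator maximizes

\max_{G}\mathbb{E}_{z\sim p_z}[\log D(G(z))]

which is exactly what the loss_g defined later (cross-entropy against all-ones labels) implements.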
Implementation
Load the libraries
```python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import os, imageio
```
Load the data
```python
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data')
```
Define some constants, the network inputs, and two helper functions
```python
batch_size = 100
z_dim = 100
OUTPUT_DIR = 'samples'
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)

def sigmoid_cross_entropy_with_logits(x, y):
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)
```
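As a quick sanity check of the `lrelu` helper (a throwaway snippet, not part of the original code; it opens a temporary session just for this evaluation):

```python
# Leaky ReLU passes positive values through and scales negatives by `leak`:
with tf.Session() as tmp_sess:
    print(tmp_sess.run(lrelu(tf.constant([-1.0, 0.0, 2.0]))))  # [-0.2  0.   2. ]
```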
The discriminator
```python
def discriminator(image, reuse=None, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        h0 = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))

        h1 = tf.layers.conv2d(h0, kernel_size=5, filters=128, strides=2, padding='same')
        h1 = lrelu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))

        h2 = tf.layers.conv2d(h1, kernel_size=5, filters=256, strides=2, padding='same')
        h2 = lrelu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))

        h3 = tf.layers.conv2d(h2, kernel_size=5, filters=512, strides=2, padding='same')
        h3 = lrelu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))

        h4 = tf.contrib.layers.flatten(h3)
        h4 = tf.layers.dense(h4, units=1)
        return tf.nn.sigmoid(h4), h4
```
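Each convolution uses 'same' padding with stride 2, so the spatial resolution halves (rounding up) at every layer. A tiny standalone sketch of the arithmetic, just to confirm the feature-map sizes (not from the original post):

```python
# 'same' padding with stride 2: output size = ceil(input / 2)
size = 28
for filters in [64, 128, 256, 512]:
    size = (size + 1) // 2
    print(size, filters)  # 14x14x64, 7x7x128, 4x4x256, 2x2x512
# flatten: 2 * 2 * 512 = 2048 features -> dense layer -> 1 logit
```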
The generator
```python
def generator(z, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('generator', reuse=None):
        d = 3
        h0 = tf.layers.dense(z, units=d * d * 512)
        h0 = tf.reshape(h0, shape=[-1, d, d, 512])
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))

        h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='same')
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))

        h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='same')
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))

        h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='same')
        h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))

        h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=1, strides=1, padding='valid',
                                        activation=tf.nn.tanh, name='g')
        return h4
```
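The generator runs the same arithmetic in reverse: the three stride-2 'same' transposed convolutions double the size from 3 to 24, and the final stride-1 'valid' transposed convolution grows it to exactly 28. A small check of that last step (a standalone sketch, not from the post):

```python
# 'valid' transposed conv with stride 1: output = (input - 1) * stride + kernel_size
size = 3
for _ in range(3):           # three stride-2 'same' deconvs: 3 -> 6 -> 12 -> 24
    size *= 2
print((size - 1) * 1 + 5)    # 28 -- matches the 28x28 MNIST images
```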
Define the loss functions. Note that the discriminator is instantiated twice here, but the two instances share their parameters.
```python
g = generator(noise)
d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)

vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

loss_d_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_real_logits, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.zeros_like(d_fake)))
loss_g = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake
```
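The two `discriminator()` calls share one set of weights because the second call re-enters the same variable scope with `reuse=True`, so no new variables are created. One way to convince yourself (an optional check, not in the original post):

```python
# Every discriminator variable appears exactly once despite two discriminator() calls.
for var in vars_d:
    print(var.name)  # all names start with 'discriminator/' and none are duplicated
```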
Define the optimizers. Note that each loss must be paired with the matching set of trainable variables; the `control_dependencies` on `UPDATE_OPS` ensures the batch-norm moving statistics are updated on every training step.
```python
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)
```
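Each `batch_norm` layer registers its moving-mean/variance update ops in the `UPDATE_OPS` collection; without the `control_dependencies` wrapper those ops would never run and inference with `is_training: False` would use stale statistics. A quick way to see that the collection is non-empty (optional, not in the original post):

```python
# Roughly one moving-mean/variance update pair per batch_norm layer per network call.
print(len(update_ops))
```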
Define a helper function that tiles multiple images into a single grid for display
```python
def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    m = np.ones((images.shape[1] * n_plots + n_plots + 1,
                 images.shape[2] * n_plots + n_plots + 1)) * 0.5
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m
```
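To see what `montage` produces, one can feed it random data (an illustrative snippet, not from the post):

```python
# 16 random 28x28 "images" become one 4x4 grid with 1-pixel gray borders:
demo = np.random.rand(16, 28, 28)
print(montage(demo).shape)  # (117, 117) = (4*28 + 4 + 1, 4*28 + 4 + 1)
```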
Start training. In each iteration G is trained twice (a common trick to keep the discriminator from overpowering the generator).
```python
sess = tf.Session()
sess.run(tf.global_variables_initializer())
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}

for i in range(60000):
    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    batch = (batch - 0.5) * 2  # rescale from [0, 1] to [-1, 1] to match the generator's tanh output

    d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict={X: batch, noise: n, is_training: True})
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    sess.run(optimizer_d, feed_dict={X: batch, noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, noise: n, is_training: True})

    if i % 1000 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2
        imgs = [img[:, :, 0] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs, cmap='gray')
        plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
        plt.show()
        samples.append(gen_imgs)

plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig('Loss.png')
plt.show()
imageio.mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=5)
```
The generated images are shown below. Because the loss function involves no pixel-wise comparison, the edges in the generated digits do not come out blurry.
Save the model for later use
```python
saver = tf.train.Saver()
saver.save(sess, './mnist_dcgan', global_step=60000)
```
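In TF 1.x, `saver.save` writes several files rather than a single one; the `.meta` file holds the graph definition, which is what the loading script below imports (a hedged note, as exact file names can vary slightly across TensorFlow versions):

```python
# Typical checkpoint layout after saving:
#   mnist_dcgan-60000.meta                  graph definition
#   mnist_dcgan-60000.index                 variable index
#   mnist_dcgan-60000.data-00000-of-00001   variable values
#   checkpoint                              text file naming the latest checkpoint
print(tf.train.latest_checkpoint('.'))  # e.g. './mnist_dcgan-60000'
```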
Load the model when needed, for example to generate samples on a standalone machine
```python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

batch_size = 100
z_dim = 100

def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    m = np.ones((images.shape[1] * n_plots + n_plots + 1,
                 images.shape[2] * n_plots + n_plots + 1)) * 0.5
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph('./mnist_dcgan-60000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()
g = graph.get_tensor_by_name('generator/g/Tanh:0')
noise = graph.get_tensor_by_name('noise:0')
is_training = graph.get_tensor_by_name('is_training:0')

n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, 0] for img in gen_imgs]
gen_imgs = montage(imgs)
plt.axis('off')
plt.imshow(gen_imgs, cmap='gray')
plt.show()
```