实战生成对抗网络[2]:生成手写数字

栏目: 数据库 · 发布时间: 5年前

内容简介:在开始本文之前,让我们先看看一则报道:人民网讯 据英国广播电视公司10月25日报道,由人工智能创作的艺术作品以432000美元(约合300万人民币)的高价成功拍卖。看起来一则不起眼的新闻,其实意义深远,它意味着人们开始认可计算机创作的艺术价值,那些沾沾自喜认为不会被人工智能取代的艺术家也要瑟瑟发抖了。

在开始本文之前,让我们先看看一则报道:

人民网讯 据英国广播电视公司10月25日报道,由人工智能创作的艺术作品以432000美元(约合300万人民币)的高价成功拍卖。

看起来一则不起眼的新闻,其实意义深远,它意味着人们开始认可计算机创作的艺术价值,那些沾沾自喜认为不会被人工智能取代的艺术家也要瑟瑟发抖了。

这幅由人工智能创作的作品长啥样,有啥过人之处?

实战生成对抗网络[2]:生成手写数字

嗯,以我这种外行人士看来,实在不怎么样,但这不意味着人工智能不行。要知道,AlphaGo初出道时,也只敢挑战一下樊麾这样的二流棋手,接下来挑战顶级棋手李世石,人类还能勉力一战,等进化到AlphaGo Master,零封人类棋手。然而这还没有完,AlphaGo Zero不再学习人类棋譜,完全通过自学,碾压AlphaGo Master,对付人类棋手,更如我们捏死一只蚂蚁那么容易。

所以说,尽管人工智能创作的第一副作品如同鬼画桃符,但其潜力无可限量。

那么,接下来我们会探讨如何创作出一幅名画?No. No.

创作一副画并不是那么容易。这幅名为《埃德蒙·贝拉米肖像》的画作是由巴黎一个名为“显而易见”(Obvious)的艺术团体创作利用人工智能技术创作而成,这幅作品是用算法和15000幅从14世纪到20世纪的肖像画数据制作而成。

我们还没有那个条件去创作一副人工智能的画作,但我们可以先从基本的着手,生成手写数字。手写数字对于机器学习的同学来说,太熟悉不过了。既然是老朋友了,那让我们开始吧!

首先回顾一下《 实战生成对抗网络[1]:简介 》这篇文章的内容,GAN由生成器和判别器组成。简单起见,我们选择简单的二层神经网络来实现生成器和判别器。

生成器

实现生成器并不难,我们采取的全连接网络拓扑结构为:100 --> 128 --> 784,最后的输出为784是因为MNIST数据集就是由28 x 28像素的灰度图像组成。代码如下:

G_W1 = tf.Variable(initializer([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(initializer([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
theta_G = [G_W1, G_W2, G_b1, G_b2]

def generator(z):
  G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
  G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
  G_prob = tf.nn.sigmoid(G_log_prob)

  return G_prob
复制代码

判别器

判别器正好相反,以MNIST图像作为输入并返回一个代表真实图像的概率的标量,代码如下:

D_W1 = tf.Variable(initializer(shape=[784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(initializer(shape=[128, 1]), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name="D_W2")
theta_D = [D_W1, D_W2, D_b1, D_b2]

def discriminator(x):
  D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
  D_logit = tf.matmul(D_h1, D_W2) + D_b2
  D_prob = tf.nn.sigmoid(D_logit)

  return D_prob, D_logit
复制代码

训练算法

在论文 arXiv: 1406.2661, 2014 中给出了训练算法的伪代码:

实战生成对抗网络[2]:生成手写数字

TensorFlow中的优化器只能做最小化,因为为了最大化损失函数,我们在伪代码给出的损失函数前加上一个负号。

D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))
复制代码

接下来定义优化器:

# 仅更新D(X)的参数, var_list=theta_D
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
# 仅更新G(X)的参数, var_list=theta_G
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
复制代码

最后进行迭代,更新参数:

for it in range(60000):
  X_mb, _ = mnist.train.next_batch(mb_size)

  _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
  _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
复制代码

整个流程下来,其实和之前的深度学习算法差不多,非常容易理解。算法是不是有效果呢?我们可以将迭代过程中生成的手写数字显示出来:

实战生成对抗网络[2]:生成手写数字

嗯,结果虽然有点差强人意,但差不多是手写数字的字形,而且随着迭代,越来越接近手写数字,可以说GAN算法还是有效的。

小结

一个简单的GAN网络就这么几行代码就能搞定,看样子生成一副画也没有什么难的。先不要这么乐观,其实,GAN网络中的坑还是不少,比如在迭代过程中,就出现过如下提示:

Iter: 9000
D loss: nan
G_loss: nan
复制代码

从代码中我们可以看出,GAN网络依然采用的梯度下降法来迭代求解参数。梯度下降的启动会选择一个减小所定义问题损失的方向,但是我们并没有一个办法来确保利用GAN网络可以进入纳什均衡的状态,这是一个高维度的非凸优化目标。网络试图在接下来的步骤中最小化非凸优化目标,最终有可能导致进入振荡而不是收敛到底层正式目标。

另外还有模型坍塌、计数、角度以及全局结构方面的问题,要解决这些问题,需要使用一些特殊的技巧和方法,后面我们深入各种GAN模型时将会探讨。

本文完整的代码请参考: github.com/mogoweb/aie…


以上所述就是小编给大家介绍的《实战生成对抗网络[2]:生成手写数字》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

趣学算法

趣学算法

陈小玉 / 人民邮电出版社 / 2017-7-1 / 89.00元

本书内容按照算法策略分为7章。 第1章从算法之美、简单小问题、趣味故事引入算法概念、时间复杂度、空间复杂度的概念和计算方法,以及算法设计的爆炸性增量问题,使读者体验算法的奥妙。 第2~7章介绍经典算法的设计策略、实战演练、算法分析及优化拓展,分别讲解贪心算法、分治算法、动态规划、回溯法、分支限界法、线性规划和网络流。每一种算法都有4~10个实例,共50个大型实例,包括经典的构造实例和实......一起来看看 《趣学算法》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

SHA 加密
SHA 加密

SHA 加密工具

html转js在线工具
html转js在线工具

html转js在线工具