GAN入门实践(一)--Tensorflow实现

栏目: 数据库 · 发布时间: 6年前

内容简介:最近在网上看到一个据说是 Alex Smola 写的关于这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

最近在网上看到一个据说是 Alex Smola 写的关于 生成对抗网络 (Generative Adversarial Network, GAN)的入门教程,目的是从实践的角度讲解 GAN 的基本思想和实现过程。这个教程没有采用最普遍的用 GAN 生成图像作为例子,而是用生成一个特定的高斯分布作为切入点来讲解 GAN 的工作原理和效果。对 GAN 感兴趣的朋友不妨去看看这个简短的教程,直观易懂而不失有趣,链接如下:

https://github.com/zackchase/mxnet-the-straight-dope/blob/master/P10-C01-gan-intro.ipynb

这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

Tensorflow实现的全代码在 这里 ,另外还有conditional GAN的实现在 这里 。 其中还参考了 DCGAN的实现

Pytorch篇请看这里。

模型定义

class GAN(object):
    def __init__(self):     
        # input, output
        self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')
        self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')
        
        # define the network
        self.fake_x = self.netG(self.z)
        self.real_logits = self.netD(self.x, reuse=False)
        self.fake_logits = self.netD(self.fake_x, reuse=True)
        
        ......
        
    def netD(self, x, reuse=False):
        """3-layer fully connected network"""
        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()
            
            W1 = tf.get_variable(name="d_W1", shape=[2, 5],
                                initializer=tf.contrib.layers.xavier_initializer(),
                                trainable=True) 
        ......

Tensorflow 中对于变量(Variable)的创建有个需要特别主要的地方就是 变量域 (variable_scope)。当我们需要多次调用同一个网络(如上面的例子),或者多个网络共同分享其变量(比如权值共享)时,我们就可以把它们设成在同样的 variable_scope 下(当然变量名 name 也要是一样的),然后选择 reuse = True 实现变量的复用和共享。上面的代码就是一个例子,我们的 netD 只有一个,但是需要调用两次得到 real_logits 和 fake_logits。这时,我们第一次创建 netD 应该用 reuse=False,因为是创建一个新的网络和变量,而第二次调用时则选择 reuse=True,从而保证网络的变量是一样的,即用的同一个网络。

定义优化

class GAN(object):
    def __init__(self):
        
        ......
        
        # collect variables
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
        
        ......

gan = GAN()
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)

生成对抗网络的优化过程是一个生成器 G 和判别器 D 相互搏弈的过程,因此我们对 G 和 D 分别定义其对应的优化方法。它们分别对应不同的损失函数和不同的变量,其中我们运用变量命名的小技巧将生成器和判别器的变量分别存到不同的列表当中。

模型训练

这个部分就很常规了。注意以下几点就ok了:

  • tensorflow中要先做初始化
  • 定义好minibatch的生成方式
  • 每次迭代要生成一个随机数向量 z_batch

跟教程中的实现类似,我们在每个epoch结束后都将判别器的分类正确率打印出来看看。由于 GAN 的训练中,loss并不能作为很好的监测训练过程的标准,打印分类accuracy可以让我们比较直观地看到模型的训练效果(我们最终希望生成的数据能够fool判别器,所以希望准确率为50%)。


以上所述就是小编给大家介绍的《GAN入门实践(一)--Tensorflow实现》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Web前端开发最佳实践

Web前端开发最佳实践

党建 / 机械工业出版社 / 2015-1 / 59.00元

本书贴近Web前端标准来介绍前端开发相关最佳实践,目的在于让前端开发工程师提高编写代码的质量,重视代码的可维护性和执行性能,让初级工程师从入门开始就养成一个良好的编码习惯。本书总共分五个部分13章,第一部分包括第1章和第2章,介绍前端开发的基本范畴和现状,并综合介绍前端开发的一些最佳实践;第二部分为第3-5章,讲解HTML相关的最佳实践,并简单介绍HTML5中新标签的使用;第三部分为第6-8章,介......一起来看看 《Web前端开发最佳实践》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

html转js在线工具
html转js在线工具

html转js在线工具