GAN入门实践(一)--Tensorflow实现

栏目: 数据库 · 发布时间: 6年前

内容简介:最近在网上看到一个据说是 Alex Smola 写的关于这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

最近在网上看到一个据说是 Alex Smola 写的关于 生成对抗网络 (Generative Adversarial Network, GAN)的入门教程,目的是从实践的角度讲解 GAN 的基本思想和实现过程。这个教程没有采用最普遍的用 GAN 生成图像作为例子,而是用生成一个特定的高斯分布作为切入点来讲解 GAN 的工作原理和效果。对 GAN 感兴趣的朋友不妨去看看这个简短的教程,直观易懂而不失有趣,链接如下:

https://github.com/zackchase/mxnet-the-straight-dope/blob/master/P10-C01-gan-intro.ipynb

这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

Tensorflow实现的全代码在 这里 ,另外还有conditional GAN的实现在 这里 。 其中还参考了 DCGAN的实现

Pytorch篇请看这里。

模型定义

class GAN(object):
    def __init__(self):     
        # input, output
        self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')
        self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')
        
        # define the network
        self.fake_x = self.netG(self.z)
        self.real_logits = self.netD(self.x, reuse=False)
        self.fake_logits = self.netD(self.fake_x, reuse=True)
        
        ......
        
    def netD(self, x, reuse=False):
        """3-layer fully connected network"""
        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()
            
            W1 = tf.get_variable(name="d_W1", shape=[2, 5],
                                initializer=tf.contrib.layers.xavier_initializer(),
                                trainable=True) 
        ......

Tensorflow 中对于变量(Variable)的创建有个需要特别主要的地方就是 变量域 (variable_scope)。当我们需要多次调用同一个网络(如上面的例子),或者多个网络共同分享其变量(比如权值共享)时,我们就可以把它们设成在同样的 variable_scope 下(当然变量名 name 也要是一样的),然后选择 reuse = True 实现变量的复用和共享。上面的代码就是一个例子,我们的 netD 只有一个,但是需要调用两次得到 real_logits 和 fake_logits。这时,我们第一次创建 netD 应该用 reuse=False,因为是创建一个新的网络和变量,而第二次调用时则选择 reuse=True,从而保证网络的变量是一样的,即用的同一个网络。

定义优化

class GAN(object):
    def __init__(self):
        
        ......
        
        # collect variables
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
        
        ......

gan = GAN()
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)

生成对抗网络的优化过程是一个生成器 G 和判别器 D 相互搏弈的过程,因此我们对 G 和 D 分别定义其对应的优化方法。它们分别对应不同的损失函数和不同的变量,其中我们运用变量命名的小技巧将生成器和判别器的变量分别存到不同的列表当中。

模型训练

这个部分就很常规了。注意以下几点就ok了:

  • tensorflow中要先做初始化
  • 定义好minibatch的生成方式
  • 每次迭代要生成一个随机数向量 z_batch

跟教程中的实现类似,我们在每个epoch结束后都将判别器的分类正确率打印出来看看。由于 GAN 的训练中,loss并不能作为很好的监测训练过程的标准,打印分类accuracy可以让我们比较直观地看到模型的训练效果(我们最终希望生成的数据能够fool判别器,所以希望准确率为50%)。


以上所述就是小编给大家介绍的《GAN入门实践(一)--Tensorflow实现》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

共享经济

共享经济

[美] 罗宾•蔡斯 / 王芮 / 浙江人民出版社 / 2015-9-25 / 59.90元

[内容简介]  在当今这个稀缺的世界里,人人共享组织可以创造出富足。通过利用已有的资源,如有形资产、技术、网络、设备、数据、经验和流程等,这些组织可以以指数级成长。人人共享重新定义了我们对于资产的理解:它是专属于个人的还是大众的;是私有的还是公有的;是商业的还是个人的,并且也让我们对监管、保险以及管理有了重新的思索。  在这本书中,罗宾与大家分享了以下观点:  如何利用过剩......一起来看看 《共享经济》 这本书的介绍吧!

随机密码生成器
随机密码生成器

多种字符组合密码

Base64 编码/解码
Base64 编码/解码

Base64 编码/解码

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具