GAN入门实践(一)--Tensorflow实现

栏目: 数据库 · 发布时间: 6年前

内容简介:最近在网上看到一个据说是 Alex Smola 写的关于这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

最近在网上看到一个据说是 Alex Smola 写的关于 生成对抗网络 (Generative Adversarial Network, GAN)的入门教程,目的是从实践的角度讲解 GAN 的基本思想和实现过程。这个教程没有采用最普遍的用 GAN 生成图像作为例子,而是用生成一个特定的高斯分布作为切入点来讲解 GAN 的工作原理和效果。对 GAN 感兴趣的朋友不妨去看看这个简短的教程,直观易懂而不失有趣,链接如下:

https://github.com/zackchase/mxnet-the-straight-dope/blob/master/P10-C01-gan-intro.ipynb

这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

Tensorflow实现的全代码在 这里 ,另外还有conditional GAN的实现在 这里 。 其中还参考了 DCGAN的实现

Pytorch篇请看这里。

模型定义

class GAN(object):
    def __init__(self):     
        # input, output
        self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')
        self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')
        
        # define the network
        self.fake_x = self.netG(self.z)
        self.real_logits = self.netD(self.x, reuse=False)
        self.fake_logits = self.netD(self.fake_x, reuse=True)
        
        ......
        
    def netD(self, x, reuse=False):
        """3-layer fully connected network"""
        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()
            
            W1 = tf.get_variable(name="d_W1", shape=[2, 5],
                                initializer=tf.contrib.layers.xavier_initializer(),
                                trainable=True) 
        ......

Tensorflow 中对于变量(Variable)的创建有个需要特别主要的地方就是 变量域 (variable_scope)。当我们需要多次调用同一个网络(如上面的例子),或者多个网络共同分享其变量(比如权值共享)时,我们就可以把它们设成在同样的 variable_scope 下(当然变量名 name 也要是一样的),然后选择 reuse = True 实现变量的复用和共享。上面的代码就是一个例子,我们的 netD 只有一个,但是需要调用两次得到 real_logits 和 fake_logits。这时,我们第一次创建 netD 应该用 reuse=False,因为是创建一个新的网络和变量,而第二次调用时则选择 reuse=True,从而保证网络的变量是一样的,即用的同一个网络。

定义优化

class GAN(object):
    def __init__(self):
        
        ......
        
        # collect variables
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
        
        ......

gan = GAN()
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)

生成对抗网络的优化过程是一个生成器 G 和判别器 D 相互搏弈的过程,因此我们对 G 和 D 分别定义其对应的优化方法。它们分别对应不同的损失函数和不同的变量,其中我们运用变量命名的小技巧将生成器和判别器的变量分别存到不同的列表当中。

模型训练

这个部分就很常规了。注意以下几点就ok了:

  • tensorflow中要先做初始化
  • 定义好minibatch的生成方式
  • 每次迭代要生成一个随机数向量 z_batch

跟教程中的实现类似,我们在每个epoch结束后都将判别器的分类正确率打印出来看看。由于 GAN 的训练中,loss并不能作为很好的监测训练过程的标准,打印分类accuracy可以让我们比较直观地看到模型的训练效果(我们最终希望生成的数据能够fool判别器,所以希望准确率为50%)。


以上所述就是小编给大家介绍的《GAN入门实践(一)--Tensorflow实现》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

大数据预测

大数据预测

【美】埃里克·西格尔 / 周昕 / 中信出版社 / 2014-3 / 58.00

360公司董事长周鸿祎、《罗辑思维》主讲人罗振宇郑重推荐 2020年的一天,在你驱车前往公司的路上,导航系统通过预测交通流量,会自动帮你选择一条最合适的交通路线;车内推荐系统会根据你的饮食习惯预测你可能会喜欢吃什么,并推荐沿途的早餐店;你的电子社交助理已经为你自动选择了你可能感兴趣的社交网信息;当车内系统预测到你驾车有些分心时,座椅会自动震动进行提醒…… 以上这些情景不是科幻大片独有的......一起来看看 《大数据预测》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

随机密码生成器
随机密码生成器

多种字符组合密码

HTML 编码/解码
HTML 编码/解码

HTML 编码/解码