GAN入门实践(一)--Tensorflow实现

栏目: 数据库 · 发布时间: 6年前

内容简介:最近在网上看到一个据说是 Alex Smola 写的关于这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

最近在网上看到一个据说是 Alex Smola 写的关于 生成对抗网络 (Generative Adversarial Network, GAN)的入门教程,目的是从实践的角度讲解 GAN 的基本思想和实现过程。这个教程没有采用最普遍的用 GAN 生成图像作为例子,而是用生成一个特定的高斯分布作为切入点来讲解 GAN 的工作原理和效果。对 GAN 感兴趣的朋友不妨去看看这个简短的教程,直观易懂而不失有趣,链接如下:

https://github.com/zackchase/mxnet-the-straight-dope/blob/master/P10-C01-gan-intro.ipynb

这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

Tensorflow实现的全代码在 这里 ,另外还有conditional GAN的实现在 这里 。 其中还参考了 DCGAN的实现

Pytorch篇请看这里。

模型定义

class GAN(object):
    def __init__(self):     
        # input, output
        self.z = tf.placeholder(tf.float32, shape=[None, 2], name='z')
        self.x = tf.placeholder(tf.float32, shape=[None, 2], name='real_x')
        
        # define the network
        self.fake_x = self.netG(self.z)
        self.real_logits = self.netD(self.x, reuse=False)
        self.fake_logits = self.netD(self.fake_x, reuse=True)
        
        ......
        
    def netD(self, x, reuse=False):
        """3-layer fully connected network"""
        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()
            
            W1 = tf.get_variable(name="d_W1", shape=[2, 5],
                                initializer=tf.contrib.layers.xavier_initializer(),
                                trainable=True) 
        ......

Tensorflow 中对于变量(Variable)的创建有个需要特别主要的地方就是 变量域 (variable_scope)。当我们需要多次调用同一个网络(如上面的例子),或者多个网络共同分享其变量(比如权值共享)时,我们就可以把它们设成在同样的 variable_scope 下(当然变量名 name 也要是一样的),然后选择 reuse = True 实现变量的复用和共享。上面的代码就是一个例子,我们的 netD 只有一个,但是需要调用两次得到 real_logits 和 fake_logits。这时,我们第一次创建 netD 应该用 reuse=False,因为是创建一个新的网络和变量,而第二次调用时则选择 reuse=True,从而保证网络的变量是一样的,即用的同一个网络。

定义优化

class GAN(object):
    def __init__(self):
        
        ......
        
        # collect variables
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
        
        ......

gan = GAN()
d_optim = tf.train.AdamOptimizer(learning_rate=0.05).minimize(gan.loss_D, var_list=gan.d_vars)
g_optim = tf.train.AdamOptimizer(learning_rate=0.01).minimize(gan.loss_G, var_list=gan.g_vars)

生成对抗网络的优化过程是一个生成器 G 和判别器 D 相互搏弈的过程,因此我们对 G 和 D 分别定义其对应的优化方法。它们分别对应不同的损失函数和不同的变量,其中我们运用变量命名的小技巧将生成器和判别器的变量分别存到不同的列表当中。

模型训练

这个部分就很常规了。注意以下几点就ok了:

  • tensorflow中要先做初始化
  • 定义好minibatch的生成方式
  • 每次迭代要生成一个随机数向量 z_batch

跟教程中的实现类似,我们在每个epoch结束后都将判别器的分类正确率打印出来看看。由于 GAN 的训练中,loss并不能作为很好的监测训练过程的标准,打印分类accuracy可以让我们比较直观地看到模型的训练效果(我们最终希望生成的数据能够fool判别器,所以希望准确率为50%)。


以上所述就是小编给大家介绍的《GAN入门实践(一)--Tensorflow实现》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Docker从入门到实战

Docker从入门到实战

黄靖钧 / 机械工业出版社 / 2017-6 / 69.00元

本书从Docker的相关概念与基础知识讲起,结合实际应用,通过不同开发环境的实战例子,详细介绍了Docker的基础知识与进阶实战的相关内容,以引领读者快速入门并提高。 本书共19章,分3篇。第1篇容器技术与Docker概念,涵盖的内容有容器技术、Docker简介、安装Docker等。第2篇Docker基础知识,涵盖的内容有Docker基础、Docker镜像、Dockerfile文件、Dock......一起来看看 《Docker从入门到实战》 这本书的介绍吧!

RGB转16进制工具
RGB转16进制工具

RGB HEX 互转工具

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具