巧断梯度:单个loss实现GAN模型(附开源代码)

栏目: 数据库 · 发布时间: 5年前

内容简介:我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的,

我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。

现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。

如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的, 这也就是 GAN 的同步训练。

注:本文不是介绍新的 GAN,而是介绍 GAN 的新写法,这只是一道编程题,不是一道算法题。

如果在TF中

如果是在 TensorFlow 中,实现同步训练并不困难,因为我们定义好了判别器和生成器的训练算子了(假设为 D_solver 和 G_solver ),那么直接执行:

sess.run([D_solver, G_solver], feed_dict={x_in: x_train, z_in: z_train})

就行了。这建立在我们能分别获取判别器和生成器的参数、能直接操作 sess.run 的基础上。

更通用的方法

但是如果是 Keras 呢?Keras 中已经把流程封装好了,一般来说我们没法去操作得如此精细。

所以,下面我们介绍一个通用的技巧, 只需要定义单一一个 loss,然后扔给优化器,就能够实现 GAN 的训练。 同时,从这个技巧中,我们还可以学习到如何更加灵活地操作 loss 来控制梯度。

判别器的优化

我们以 GAN 的 hinge loss 为例子,它的形式是:

巧断梯度:单个loss实现GAN模型(附开源代码)

注意 巧断梯度:单个loss实现GAN模型(附开源代码) 意味着要固定 G,因为 G 本身也是有优化参数的,不固定的话就应该是 巧断梯度:单个loss实现GAN模型(附开源代码)

为了固定G,除了“把 G 的参数从优化器中去掉”这个方法之外,我们也可以利用 stop_gradient去手动固定:

巧断梯度:单个loss实现GAN模型(附开源代码)

这里:

巧断梯度:单个loss实现GAN模型(附开源代码)

这样一来,在式 (2) 中,我们虽然同时放开了 D,G 的权重,但是不断地优化式 (2),会变的只有 D,而 G 是不会变的,因为我们用的是基于梯度下降的优化器,而 G 的梯度已经被停止了,换句话说,我们可以理解为 G 的梯度被强行设置为 0,所以它的更新量一直都是 0。

生成器的优化

现在解决了 D 的优化,那么 G 呢? stop_gradient 可以很方便地放我们固定里边部分的梯度(比如 D(G(z)) 的 G(z)),但 G 的优化是要我们去固定外边的 D,没有函数实现它。但不要灰心,我们可以用一个数学技巧进行转化。

首先,我们要清楚,我们想要 D(G(z)) 里边的 G 的梯度,不想要 D 的梯度,如果直接对 D(G(z)) 求梯度,那么同时会得到 D,G 的梯度。如果直接求 巧断梯度:单个loss实现GAN模型(附开源代码) 的梯度呢?只能得到 D 的梯度,因为 G 已经被停止了。那么,重点来了,将这两个相减,不就得到单纯的 G 的梯度了吗!

巧断梯度:单个loss实现GAN模型(附开源代码)

现在优化式 (4) ,那么 D 是不会变的,改变的是 G。

值得一提的是,直接输出这个式子,结果是恒等于 0,因为两部分都是一样的,直接相减自然是 0,但它的梯度不是 0。也就是说,这是一个恒等于 0 的 loss,但是梯度却不恒等于 0。

合成单一loss 

好了,现在式 (2) 和式 (4) 都同时放开了 D,G,大家都是 arg min,所以可以将两步合成一个 loss:

巧断梯度:单个loss实现GAN模型(附开源代码)

写出这个 loss,就可以同时完成判别器和生成器的优化了,而不需要交替训练,但是效果基本上等效于 1:1 的交替训练。引入 λ 的作用,相当于让判别器和生成器的学习率之比为 1:λ。

参考代码:

https://github.com/bojone/gan/blob/master/gan_one_step_with_hinge_loss.py

文章小结

文章主要介绍了实现 GAN 的一个小技巧,允许我们只写单个模型、用单个 loss 就实现 GAN 的训练。它本质上就是用 stop_gradient 来手动控制梯度的技巧,在其他任务上也可能用得到它。

所以,以后我写 GAN 都用这种写法了,省力省时。当然,理论上这种写法需要多耗些显存,这也算是牺牲空间换时间吧。


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

鸟哥的Linux私房菜

鸟哥的Linux私房菜

鸟哥 / 机械工业出版社 / 2008-1 / 88.00元

《鸟哥的Linux私房菜:服务器架设篇(第2版)》是对连续三年蝉联畅销书排行榜前10名的《Linux鸟哥私房菜一服务器架设篇》的升级版,新版本根据目前服务器与网络环境做了大幅度修订与改写。 全书共3部分,第1部分为架站前的进修专区,包括在架设服务器前必须具备的网络基础知识、Linux常用网络命令、Linux网络侦错步骤,以及服务器架站流程:第2部分为主机的简易防火措施,包括限制Linux对......一起来看看 《鸟哥的Linux私房菜》 这本书的介绍吧!

随机密码生成器
随机密码生成器

多种字符组合密码

Markdown 在线编辑器
Markdown 在线编辑器

Markdown 在线编辑器

HEX CMYK 转换工具
HEX CMYK 转换工具

HEX CMYK 互转工具