巧断梯度:单个loss实现GAN模型(附开源代码)

栏目: 数据库 · 发布时间: 6年前

内容简介:我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的,

我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。

现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。

如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的, 这也就是 GAN 的同步训练。

注:本文不是介绍新的 GAN,而是介绍 GAN 的新写法,这只是一道编程题,不是一道算法题。

如果在TF中

如果是在 TensorFlow 中,实现同步训练并不困难,因为我们定义好了判别器和生成器的训练算子了(假设为 D_solver 和 G_solver ),那么直接执行:

sess.run([D_solver, G_solver], feed_dict={x_in: x_train, z_in: z_train})

就行了。这建立在我们能分别获取判别器和生成器的参数、能直接操作 sess.run 的基础上。

更通用的方法

但是如果是 Keras 呢?Keras 中已经把流程封装好了,一般来说我们没法去操作得如此精细。

所以,下面我们介绍一个通用的技巧, 只需要定义单一一个 loss,然后扔给优化器,就能够实现 GAN 的训练。 同时,从这个技巧中,我们还可以学习到如何更加灵活地操作 loss 来控制梯度。

判别器的优化

我们以 GAN 的 hinge loss 为例子,它的形式是:

巧断梯度:单个loss实现GAN模型(附开源代码)

注意 巧断梯度:单个loss实现GAN模型(附开源代码) 意味着要固定 G,因为 G 本身也是有优化参数的,不固定的话就应该是 巧断梯度:单个loss实现GAN模型(附开源代码)

为了固定G,除了“把 G 的参数从优化器中去掉”这个方法之外,我们也可以利用 stop_gradient去手动固定:

巧断梯度:单个loss实现GAN模型(附开源代码)

这里:

巧断梯度:单个loss实现GAN模型(附开源代码)

这样一来,在式 (2) 中,我们虽然同时放开了 D,G 的权重,但是不断地优化式 (2),会变的只有 D,而 G 是不会变的,因为我们用的是基于梯度下降的优化器,而 G 的梯度已经被停止了,换句话说,我们可以理解为 G 的梯度被强行设置为 0,所以它的更新量一直都是 0。

生成器的优化

现在解决了 D 的优化,那么 G 呢? stop_gradient 可以很方便地放我们固定里边部分的梯度(比如 D(G(z)) 的 G(z)),但 G 的优化是要我们去固定外边的 D,没有函数实现它。但不要灰心,我们可以用一个数学技巧进行转化。

首先,我们要清楚,我们想要 D(G(z)) 里边的 G 的梯度,不想要 D 的梯度,如果直接对 D(G(z)) 求梯度,那么同时会得到 D,G 的梯度。如果直接求 巧断梯度:单个loss实现GAN模型(附开源代码) 的梯度呢?只能得到 D 的梯度,因为 G 已经被停止了。那么,重点来了,将这两个相减,不就得到单纯的 G 的梯度了吗!

巧断梯度:单个loss实现GAN模型(附开源代码)

现在优化式 (4) ,那么 D 是不会变的,改变的是 G。

值得一提的是,直接输出这个式子,结果是恒等于 0,因为两部分都是一样的,直接相减自然是 0,但它的梯度不是 0。也就是说,这是一个恒等于 0 的 loss,但是梯度却不恒等于 0。

合成单一loss 

好了,现在式 (2) 和式 (4) 都同时放开了 D,G,大家都是 arg min,所以可以将两步合成一个 loss:

巧断梯度:单个loss实现GAN模型(附开源代码)

写出这个 loss,就可以同时完成判别器和生成器的优化了,而不需要交替训练,但是效果基本上等效于 1:1 的交替训练。引入 λ 的作用,相当于让判别器和生成器的学习率之比为 1:λ。

参考代码:

https://github.com/bojone/gan/blob/master/gan_one_step_with_hinge_loss.py

文章小结

文章主要介绍了实现 GAN 的一个小技巧,允许我们只写单个模型、用单个 loss 就实现 GAN 的训练。它本质上就是用 stop_gradient 来手动控制梯度的技巧,在其他任务上也可能用得到它。

所以,以后我写 GAN 都用这种写法了,省力省时。当然,理论上这种写法需要多耗些显存,这也算是牺牲空间换时间吧。


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

计算机程序设计艺术:第4卷 第4册(双语版)

计算机程序设计艺术:第4卷 第4册(双语版)

Donald E.Knuth / 苏运霖 / 机械工业出版社 / 2007-4 / 42.00元

关于算法分析的这多卷论著已经长期被公认为经典计算机科学的定义性描述。迄今已出版的完整的三卷组成了程序设计理论和实践的惟一的珍贵源泉,无数读者都赞扬Knuth的著作对个人的深远影响。科学家们为他的分析的美丽和优雅所惊叹,而从事实践的程序员们已经成功地应用他的“菜谱式”的解到日常问题上,所有人都由于Knuth在书中所表现出的博学、清晰、精确和高度幽默而对他无比敬仰。   为开始后续各卷的写作并更......一起来看看 《计算机程序设计艺术:第4卷 第4册(双语版)》 这本书的介绍吧!

CSS 压缩/解压工具
CSS 压缩/解压工具

在线压缩/解压 CSS 代码

Markdown 在线编辑器
Markdown 在线编辑器

Markdown 在线编辑器

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具