【译】Effective TensorFlow Chapter12——TensorFlow中的数值稳定性

栏目: 数据库 · 发布时间: 5年前

内容简介:本文翻译自:当使用任何数值计算库(如NumPy或TensorFlow)时,值得注意的是,编写出正确的数学计算代码对于计算出正确结果并不是必须的。你同样需要确保整个计算过程是稳定的。让我们从一个例子入手。小学的时候我们就知道,对于任意一个非0的数x,都有
【译】Effective TensorFlow Chapter12——TensorFlow中的数值稳定性

本文翻译自: 《Numerical stability in TensorFlow》 , 如有侵权请联系删除,仅限于学术交流,请勿商用。如有谬误,请联系指出。

当使用任何数值计算库(如NumPy或TensorFlow)时,值得注意的是,编写出正确的数学计算代码对于计算出正确结果并不是必须的。你同样需要确保整个计算过程是稳定的。

让我们从一个例子入手。小学的时候我们就知道,对于任意一个非0的数x,都有 x*y/y=x 。但是让我们在实践中看看是否如此:

import numpy as np

x = np.float32(1)

y = np.float32(1e-50)  # y would be stored as zero
z = x * y / y

print(z)  # prints nan
复制代码

错误的原因是:y是 float32 类型的数字,所能表示的数值太小。当y太大时会出现类似的问题:

y = np.float32(1e39)  # y would be stored as inf
z = x * y / y

print(z)  # prints 0
复制代码

float32类型可以表示的最小正值是1.4013e-45,任何低于该值的数都将存储为零。此外,任何超过3.40282e + 38的数都将存储为inf。

print(np.nextafter(np.float32(0), np.float32(1)))  # prints 1.4013e-45
print(np.finfo(np.float32).max)  # print 3.40282e+38
复制代码

为了保证计算的稳定性,你需要避免使用绝对值非常小或非常大的值。可能听起来这种问题比较低级,但这些问题可能会让程序变得难以调试,尤其是在TensorFlow中进行梯度下降时。这是因为你不仅需要确保正向传递中的所有值都在数据类型的有效范围内,而且反向传播时同样如此(在梯度运算期间)。

让我们看一个真实的例子。我们想要在logits向量上计算其softmax的值。一个too navie的实现方式就像这样:

import tensorflow as tf

def unstable_softmax(logits):
    exp = tf.exp(logits)
    return exp / tf.reduce_sum(exp)

tf.Session().run(unstable_softmax([1000., 0.]))  # prints [ nan, 0.]
复制代码

注意一下,计算相对较小的数的对数,将会得到一个超出float32范围的大数。对于我们的naive softmax实现来说,最大的有效对数是ln(3.40282e+38) = 88.7,如果超过这个值,就会导致nan结果。

但是我们怎样才能使它更稳定呢?解决办法相当简单。很容易看到,exp (x - c) /∑exp (x - c) = exp (x) /∑exp (x)。因此,我们可以从逻辑中减去任何常数,结果还是一样的。我们选择这个常数作为逻辑的最大值。这样,指数函数的定义域将被限制为[-inf, 0],因此其范围将为[0.0,1.0],这是可取的:

import tensorflow as tf

def softmax(logits):
    exp = tf.exp(logits - tf.reduce_max(logits))
    return exp / tf.reduce_sum(exp)

tf.Session().run(softmax([1000., 0.]))  # prints [ 1., 0.]
复制代码

让我们来看一个复杂点案例。假设我们有一个分类问题,并且使用softmax函数从我们的逻辑中产生概率。然后我们定义一个真实值和预测值之间的交叉熵损失函数。回想一下,交叉熵的分类分布可以简单地定义为 xe(p, q) = -∑ p_i log(q_i) ,所以一个简单的交叉熵代码是这样的:

def unstable_softmax_cross_entropy(labels, logits):
    logits = tf.log(softmax(logits))
    return -tf.reduce_sum(labels * logits)

labels = tf.constant([0.5, 0.5])
logits = tf.constant([1000., 0.])

xe = unstable_softmax_cross_entropy(labels, logits)

print(tf.Session().run(xe))  # prints inf
复制代码

请注意,在这个代码中,当softmax输出接近0时,输出将会接近无穷,这将导致我们的计算不稳定。我们可以通过扩展softmax函数并做一些简化来重写它:

def softmax_cross_entropy(labels, logits):
    scaled_logits = logits - tf.reduce_max(logits)
    normalized_logits = scaled_logits - tf.reduce_logsumexp(scaled_logits)
    return -tf.reduce_sum(labels * normalized_logits)

labels = tf.constant([0.5, 0.5])
logits = tf.constant([1000., 0.])

xe = softmax_cross_entropy(labels, logits)

print(tf.Session().run(xe))  # prints 500.0
复制代码

我们也可以验证梯度计算也是正确的:

g = tf.gradients(xe, logits)
print(tf.Session().run(g))  # prints [0.5, -0.5]
复制代码

再次提醒一下,在做梯度下降的时候务必格外小心,以确保函数以及每一层的梯度值都在一个有效的范围内。指数函数和对数函数在使用时也要格外的注意,因为它们可以将小数字映射为大数字,反之亦然。


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Tomcat与Java Web开发技术详解

Tomcat与Java Web开发技术详解

孙卫琴 / 电子工业出版社 / 2004-4-1 / 45.00元

《Tomcat与Java Web开发技术详解》编辑推荐:Jakarta Tomcat服务器是在SUN公司的JSWDK(JavaServer Web DevelopmentKit,SUN公司推出的小型Servlet/JSP调试工具)的基础上发展起来的一个优秀的Java Web应用容器,它是Apache-Jakarta的一个子项目。Tomcat被JavaWorld杂志的编辑选为2001年度最具创新的J......一起来看看 《Tomcat与Java Web开发技术详解》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具