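Recently, while pushing a project forward, I needed to turn offline Python code into code that runs on our production cluster. Because the cluster environment is Scala, the code had to be translated, and the biggest obstacle turned out to be the math functions. Python has many powerful packages for mathematical statistics and analysis, so most math functions can be found in some Python package, whereas Scala's ecosystem is comparatively weak in this area. The difficulties this time came down to the following three functions: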
```python
from scipy.special import gammaln
from scipy.special import hyp2f1
from scipy.special import logsumexp
```
Logarithm of the gamma function: gammaln
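This function returns the logarithm of the gamma function, i.e. gammaln(A) = log(gamma(A)). The input A must be real and non-negative (scipy's gammaln actually computes log|Γ(x)|, so it also accepts negative non-integer x). Using gammaln avoids the overflow and underflow that can occur when computing log(gamma(A)) directly: Γ(A) itself exceeds the double-precision range once A is larger than about 171.6, even though its logarithm is a modest number.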
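To make the overflow concrete, here is a minimal Scala sketch restricted to positive integers n, where Γ(n) = (n - 1)!. The object and method names are hypothetical and exist only for this illustration: forming the product first and taking the log afterwards overflows to infinity once n - 1 exceeds 170, while summing logarithms stays finite.

```scala
object LogGammaOverflowDemo {
  // Hypothetical helper: log Gamma(n) = log((n - 1)!) computed the naive way.
  // The product exceeds Double.MaxValue (about 1.8e308) once n - 1 > 170.
  def naiveLogGammaInt(n: Int): Double = {
    require(n >= 1, "n must be a positive integer")
    math.log((1 until n).foldLeft(1.0)(_ * _))
  }

  // Hypothetical helper: the same quantity in log space,
  // log Gamma(n) = sum_{k=1}^{n-1} log k, which never forms the huge product.
  def logSpaceLogGammaInt(n: Int): Double = {
    require(n >= 1, "n must be a positive integer")
    (1 until n).map(k => math.log(k.toDouble)).sum
  }

  def main(args: Array[String]): Unit = {
    println(naiveLogGammaInt(100))    // finite; agrees with the log-space value up to rounding
    println(logSpaceLogGammaInt(100))
    println(naiveLogGammaInt(200))    // Infinity, because 199! overflows
    println(logSpaceLogGammaInt(200)) // still finite
  }
}
```

For non-integer arguments a real gammaln needs an approximation such as the ones used in the Scala port later in this article; the log-space principle is the same.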
In scipy, gammaln is backed by a C implementation. Starting from that code, the possible approaches are:
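- Extract the relevant C code, compile it into a shared library (.so), and call its functions directly from Scala (for example through JNI or JNA).
- Read the relevant C code, understand its logic, and rewrite it in Scala.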
Besides these approaches, you can also check whether ready-made Scala code already exists:
```scala
import scala.annotation.tailrec

// Adapted from http://www.johndcook.com/stand_alone_code.html
// All bugs are however likely my fault
class Gamma {
  // Entry points
  def gamma(x: Double): Double = {
    val v = hoboTrampoline(x, false, (y: Double) => y)
    v
  }

  def logGamma(x: Double): Double = {
    val v = hoboTrampoline(x, true, (y: Double) => y)
    v
  }

  // Since Scala doesn't support optimizing co-recursive tail calls,
  // we manually make a trampoline and make it tail recursive
  @tailrec
  private def hoboTrampoline(x: Double, log: Boolean, todo: Double => Double): Double = {
    if (!log) {
      if (x <= 0.0) {
        val msg = "Invalid input argument " + x + ". Argument must be positive."
        throw new IllegalArgumentException(msg)
      }
      // Split the function domain into intervals:
      // (0, 0.001), [0.001, 12), [12, 171.624] and (171.624, infinity)

      ///////////////////////////////////////////////////////////////////////////
      // First interval: (0, 0.001)
      //
      // For small x, 1/Gamma(x) has power series x + gamma x^2 - ...
      // So in this range, 1/Gamma(x) = x + gamma x^2 with error on the order of x^3.
      // The relative error over this interval is less than 6e-7.
      val gamma: Double = 0.577215664901532860606512090 // Euler's gamma constant

      if (x < 0.001) {
        todo(1.0 / (x * (1.0 + gamma * x)))
      } else if (x < 12.0) {
        ///////////////////////////////////////////////////////////////////////////
        // Second interval: [0.001, 12)
        // The algorithm directly approximates gamma over (1,2) and uses
        // reduction identities to reduce other arguments to this interval.
        val arg_was_less_than_one: Boolean = x < 1.0

        // Add or subtract integers as necessary to bring y into (1,2)
        // Will correct for this below
        val (n, y): (Int, Double) =
          if (arg_was_less_than_one) (0, x + 1.0)
          else {
            val k = x.floor.toInt - 1
            (k, x - k)
          }

        // numerator coefficients for approximation over the interval (1,2)
        val p: Array[Double] = Array(
          -1.71618513886549492533811E+0, 2.47656508055759199108314E+1,
          -3.79804256470945635097577E+2, 6.29331155312818442661052E+2,
          8.66966202790413211295064E+2, -3.14512729688483675254357E+4,
          -3.61444134186911729807069E+4, 6.64561438202405440627855E+4
        )
        // denominator coefficients for approximation over the interval (1,2)
        val q: Array[Double] = Array(
          -3.08402300119738975254353E+1, 3.15350626979604161529144E+2,
          -1.01515636749021914166146E+3, -3.10777167157231109440444E+3,
          2.25381184209801510330112E+4, 4.75584627752788110767815E+3,
          -1.34659959864969306392456E+5, -1.15132259675553483497211E+5
        )

        // Horner-style evaluation of the rational approximation num/den
        val z: Double = y - 1
        val num = p.foldLeft(0.0)((a, b) => (b + a) * z)
        val den = q.foldLeft(1.0)((a, b) => a * z + b)
        val result = num / den + 1.0

        // Apply correction if argument was not initially in (1,2)
        if (arg_was_less_than_one) {
          // Use identity gamma(z) = gamma(z+1)/z
          // The variable "result" now holds gamma of the original y + 1
          // Thus we use y-1 to get back the original y.
          todo(result / (y - 1.0))
        } else {
          // Use the identity gamma(z+n) = z*(z+1)* ... *(z+n-1)*gamma(z)
          todo((0 until n).foldLeft(result)((a, b) => a * (y + b)))
        }
      } else if (x <= 171.624) {
        ///////////////////////////////////////////////////////////////////////////
        // Third interval: [12, 171.624]
        hoboTrampoline(x, true, (a: Double) => todo(math.exp(a)))
      } else {
        ///////////////////////////////////////////////////////////////////////////
        // Fourth interval: (171.624, INFINITY)
        // Correct answer too large to display.
        todo(Double.PositiveInfinity)
      }
    } else {
      // log implementation
      if (x <= 0.0) {
        val msg = "Invalid input argument " + x + ". Argument must be positive."
        throw new IllegalArgumentException(msg)
      }

      if (x < 12.0) {
        hoboTrampoline(x, false, (a: Double) => todo(math.log(math.abs(a))))
      } else {
        // Abramowitz and Stegun 6.1.41
        // Asymptotic series should be good to at least 11 or 12 figures
        // For error analysis, see Whittaker and Watson
        // A Course in Modern Analysis (1927), page 252
        val c: Array[Double] = Array(
          1.0 / 12.0, -1.0 / 360.0, 1.0 / 1260.0, -1.0 / 1680.0,
          1.0 / 1188.0, -691.0 / 360360.0, 1.0 / 156.0
        )
        val z: Double = 1.0 / (x * x)
        val sum: Double = c.foldRight(-3617.0 / 122400.0)((a, b) => b * z + a)
        val series: Double = sum / x

        val halfLogTwoPi: Double = 0.91893853320467274178032973640562
        val logGamma: Double = (x - 0.5) * math.log(x) - x + halfLogTwoPi + series
        todo(logGamma)
      }
    }
  }
}

object Gamma extends Gamma
```
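For x >= 12, the log branch of the code above evaluates the Stirling asymptotic series cited in its comments (Abramowitz and Stegun 6.1.41). Written out with the Bernoulli numbers B_{2k}, the coefficients in the array c, plus the initial foldRight value -3617/122400, are exactly B_{2k}/(2k(2k-1)) for k = 1, ..., 8, and halfLogTwoPi is ½ ln(2π):

$$\ln\Gamma(x)\approx\left(x-\tfrac12\right)\ln x-x+\tfrac12\ln(2\pi)+\sum_{k=1}^{8}\frac{B_{2k}}{2k(2k-1)\,x^{2k-1}}=\left(x-\tfrac12\right)\ln x-x+\tfrac12\ln(2\pi)+\frac{1}{12x}-\frac{1}{360x^{3}}+\frac{1}{1260x^{5}}-\cdots$$

Because z = 1/x², the Horner fold computes the sum of c_k z^{k-1}, and dividing by x turns each term into c_k / x^{2k-1}.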
Gauss hypergeometric function 2F1(a, b; c; z)
hyp2f1 is scipy's implementation of the Gauss hypergeometric function. I tracked down the C source it is based on, and also found a Scala implementation:
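```scala
import scala.math.{asin, log, sqrt}

// `~^` is not standard Scala: in the original project it is a custom
// exponentiation operator (see the explanation that follows this code).
def hyp2f1(a: Double, b: Double, c: Double, z: Double): Double = {
  val MAX_ITER = 35 // for 9 sig-digits in t-dist with 10 dof
  if (z == 0.0) 1.0 // 2F1(a, b; c; 0) = 1; also avoids 0/0 in the closed forms
  else (a, b, c) match {
    case (_, 1.0, 1.0) => (1.0 / (1.0 - z)) ~^ a // (1 - z)^(-a)
    case (0.5, 0.5, 1.5) if z > 0.0 && z < 1.0 =>
      asin(sqrt(z)) / sqrt(z)                   // arcsin(sqrt z) / sqrt z
    case (1.0, 1.0, 2.0) => log(1.0 - z) / -z
    case (1.0, 2.0, 1.0) => 1.0 / ((1.0 - z) * (1.0 - z))
    case (1.0, 2.0, 2.0) => 1.0 / (1.0 - z)
    case _ =>
      // Truncated power series; only meaningful for |z| < 1
      var sum = 0.0
      var prod = 1.0
      for (k <- 0 until MAX_ITER) {
        sum += prod
        prod *= z * ((a + k) * (b + k)) / ((c + k) * (k + 1.0))
      } // for
      sum
  } // match
} // hyp2f1
```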
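For reference, the series that the fallback branch truncates is the standard definition of 2F1. The ratio of consecutive terms t_{k+1}/t_k is exactly the factor multiplied into prod at each iteration, and the series converges only for |z| < 1:

$${}_2F_1(a,b;c;z)=\sum_{k=0}^{\infty}\frac{(a)_k\,(b)_k}{(c)_k}\,\frac{z^k}{k!},\qquad (q)_0=1,\quad (q)_k=q(q+1)\cdots(q+k-1)$$

$$\frac{t_{k+1}}{t_k}=\frac{(a+k)(b+k)}{(c+k)(k+1)}\,z$$

The closed-form special cases can be checked against these standard identities (the cases (1, 2; 1) and (1, 2; 2) are instances of the first one, since 2F1 is symmetric in a and b):

$${}_2F_1(a,b;b;z)=(1-z)^{-a},\qquad {}_2F_1(1,1;2;z)=-\frac{\ln(1-z)}{z},\qquad {}_2F_1\!\left(\tfrac12,\tfrac12;\tfrac32;z\right)=\frac{\arcsin\sqrt{z}}{\sqrt{z}}\quad(0<z<1)$$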
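The code above does not compile as is, mainly because of the ~^ operator. In standard Scala, ~ is bitwise NOT and ^ is bitwise XOR (both defined only for integer types), so the expression makes no sense here. Looking into the original project shows that ~^ was defined there as an exponentiation operator, similar to ** in Python. A similar implementation: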
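```scala
object PowerOps {
  // Implicit classes must be defined inside an object, class or trait.
  implicit class PowerOp[T](value: T)(implicit num: Numeric[T]) {
    // Exponentiation, similar to Python's **.
    // Note: Scala gives ** the same precedence as *, so write 2 * (x ** 2), not 2 * x ** 2.
    def **(power: Double): Double = math.pow(num.toDouble(value), power)

    // Same operation under the operator name used in the hyp2f1 code above.
    def ~^(power: Double): Double = math.pow(num.toDouble(value), power)
  }
}

// Usage: import PowerOps._ and then write, e.g., 2.0 ** 10 or (1.0 - z) ~^ a
```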
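Because MAX_ITER = 35 was tuned for one specific use (the comment mentions a t-distribution with 10 degrees of freedom), a ported version is worth checking against the closed forms. The sketch below is a hypothetical variant of the fallback branch, not the original project's code: it stops on a relative tolerance instead of a fixed iteration count, assumes |z| < 1 and that c is not zero or a negative integer, and uses plain arithmetic so it does not depend on ~^ or **.

```scala
object Hyp2f1Check {
  // Hypothetical helper: Gauss series for 2F1(a, b; c; z).
  // Assumes |z| < 1 and c not in {0, -1, -2, ...}.
  // Stops once the next term is negligible relative to the running sum.
  def hyp2f1Series(a: Double, b: Double, c: Double, z: Double,
                   tol: Double = 1e-15, maxIter: Int = 10000): Double = {
    require(math.abs(z) < 1.0, "the power series only converges for |z| < 1")
    var sum = 0.0
    var term = 1.0 // t_0 = 1
    var k = 0
    while (k < maxIter && math.abs(term) > tol * math.abs(sum)) {
      sum += term
      // t_{k+1} = t_k * (a + k)(b + k) z / ((c + k)(k + 1))
      term *= (a + k) * (b + k) / ((c + k) * (k + 1.0)) * z
      k += 1
    }
    sum
  }

  def main(args: Array[String]): Unit = {
    val z = 0.5
    // Compare the series with the closed form 2F1(1, 1; 2; z) = -ln(1 - z) / z
    println(hyp2f1Series(1.0, 1.0, 2.0, z))
    println(-math.log(1.0 - z) / z)
  }
}
```

A fixed iteration count gives accuracy that varies with z, whereas a relative tolerance ties the loop length to the requested accuracy; close to |z| = 1 both converge slowly, and maxIter caps the cost.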
Log of the sum of exponentials: logsumexp
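Before explaining logsumexp, we need to understand what it is for. Suppose we have N real numbers x_1, ..., x_N and want to compute:

$$\log\sum_{i=1}^{N}\exp(x_i)$$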
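If some x_i is very large or very small, computing this directly can overflow or underflow and cause serious problems. For example, for [0, 1, 0] direct evaluation works and gives about 1.55. For [1000, 1001, 1000] it fails and returns inf; for [-1000, -999, -1000] it also fails and returns -inf. The cause is that a 64-bit double cannot hold these intermediate values: exp(1000) overflows and exp(-1000) underflows to 0. Mathematically the result is obviously finite, yet in floating-point arithmetic it cannot be computed this way. The fix is simple:

$$\log\sum_{i=1}^{N}\exp(x_i)=a+\log\sum_{i=1}^{N}\exp(x_i-a)$$

This identity holds for any a, which means we are free to shift the exponents. A typical choice is a = max_i x_i: every shifted exponent is then at most 0, so nothing overflows, and the largest term equals exp(0) = 1, so the sum cannot underflow to 0 either. scipy's Python implementation: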
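The identity can be tried out directly. The following minimal Scala sketch (object and method names are hypothetical; it assumes a non-empty input with a finite maximum) compares direct evaluation with the max-shifted form on the three inputs from the example above:

```scala
object LogSumExpDemo {
  // Direct evaluation: exp(x) overflows to Infinity for x above roughly 709.8
  // and underflows to 0 for very negative x.
  def naive(xs: Seq[Double]): Double = math.log(xs.map(math.exp).sum)

  // Shifted evaluation with a = max_i x_i.
  // Assumes xs is non-empty and its maximum is finite.
  def shifted(xs: Seq[Double]): Double = {
    val a = xs.reduce((p, q) => math.max(p, q))
    // Every exponent x - a is <= 0 and the largest term is exp(0) = 1
    a + math.log(xs.map(x => math.exp(x - a)).sum)
  }

  def main(args: Array[String]): Unit = {
    val inputs = Seq(
      Seq(0.0, 1.0, 0.0),
      Seq(1000.0, 1001.0, 1000.0),
      Seq(-1000.0, -999.0, -1000.0)
    )
    for (xs <- inputs)
      println(s"$xs naive=${naive(xs)} shifted=${shifted(xs)}")
  }
}
```

The naive column shows about 1.55, Infinity and -Infinity; the shifted column shows about 1.55, 1001.55 and -998.45, because the last two inputs are just [0, 1, 0] shifted by 1000 and -1000.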
```python
import numpy as np
# scipy-internal helper (not public API); its location may change between scipy versions
from scipy._lib._util import _asarray_validated


def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute the log of the sum of exponentials of input elements.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes over which the sum is taken. By default `axis` is None,
        and all elements are summed.

        .. versionadded:: 0.11.0
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the original array.

        .. versionadded:: 0.15.0
    b : array-like, optional
        Scaling factor for exp(`a`) must be of the same shape as `a` or
        broadcastable to `a`. These values may be negative in order to
        implement subtraction.

        .. versionadded:: 0.12.0
    return_sign : bool, optional
        If this is set to True, the result will be a pair containing sign
        information; if False, results that are negative will be returned
        as NaN. Default is False (no sign information).

        .. versionadded:: 0.16.0

    Returns
    -------
    res : ndarray
        The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically
        more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))``
        is returned.
    sgn : ndarray
        If return_sign is True, this will be an array of floating-point
        numbers matching res and +1, 0, or -1 depending on the sign
        of the result. If False, only one result is returned.

    See Also
    --------
    numpy.logaddexp, numpy.logaddexp2

    Notes
    -----
    Numpy has a logaddexp function which is very similar to `logsumexp`, but
    only handles two arguments. `logaddexp.reduce` is similar to this
    function, but may be less stable.

    Examples
    --------
    >>> from scipy.special import logsumexp
    >>> a = np.arange(10)
    >>> np.log(np.sum(np.exp(a)))
    9.4586297444267107
    >>> logsumexp(a)
    9.4586297444267107

    With weights

    >>> a = np.arange(10)
    >>> b = np.arange(10, 0, -1)
    >>> logsumexp(a, b=b)
    9.9170178533034665
    >>> np.log(np.sum(b*np.exp(a)))
    9.9170178533034647

    Returning a sign flag

    >>> logsumexp([1,2],b=[1,-1],return_sign=True)
    (1.5413248546129181, -1.0)

    Notice that `logsumexp` does not directly support masked arrays. To use it
    on a masked array, convert the mask into zero weights:

    >>> a = np.ma.array([np.log(2), 2, np.log(3)],
    ...                  mask=[False, True, False])
    >>> b = (~a.mask).astype(int)
    >>> logsumexp(a.data, b=b), np.log(5)
    1.6094379124341005, 1.6094379124341005

    """
    a = _asarray_validated(a, check_finite=False)
    if b is not None:
        a, b = np.broadcast_arrays(a, b)
        if np.any(b == 0):
            a = a + 0.  # promote to at least float
            a[b == 0] = -np.inf

    a_max = np.amax(a, axis=axis, keepdims=True)

    if a_max.ndim > 0:
        a_max[~np.isfinite(a_max)] = 0
    elif not np.isfinite(a_max):
        a_max = 0

    if b is not None:
        b = np.asarray(b)
        tmp = b * np.exp(a - a_max)
    else:
        tmp = np.exp(a - a_max)

    # suppress warnings about log of zero
    with np.errstate(divide='ignore'):
        s = np.sum(tmp, axis=axis, keepdims=keepdims)
        if return_sign:
            sgn = np.sign(s)
            s *= sgn  # /= makes more sense but we need zero -> zero
        out = np.log(s)

    if not keepdims:
        a_max = np.squeeze(a_max, axis=axis)
    out += a_max

    if return_sign:
        return out, sgn
    else:
        return out
```
The logic is simple enough to translate directly into Scala, so I won't go into further detail here.
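As an illustration of that translation, here is a minimal Scala sketch of the core of the Python code above for 1-D input only. axis and keepdims are not supported, and the object name and the Option-based weights parameter are assumptions made for this sketch rather than any existing Scala API. It mirrors three details of the scipy version: entries with zero weight are set to -inf, a non-finite maximum is replaced by 0 before shifting, and without sign information a negative weighted sum yields NaN.

```scala
object LogSumExp {
  /** Hypothetical 1-D port of scipy.special.logsumexp with return_sign = true.
    * Returns (log |sum_i b_i * exp(a_i)|, sign of that sum). */
  def logSumExpWithSign(a: Array[Double], b: Option[Array[Double]] = None): (Double, Double) = {
    require(a.nonEmpty, "input must be non-empty")
    b.foreach(w => require(w.length == a.length, "a and b must have the same length"))

    // Zero weights contribute nothing: treat those entries as exp(-inf) = 0
    val aa: Array[Double] = b match {
      case Some(w) => a.zip(w).map { case (x, wi) => if (wi == 0.0) Double.NegativeInfinity else x }
      case None    => a
    }

    // Shift by the maximum; a non-finite maximum (all -inf, +inf or NaN) becomes 0, as in scipy
    val rawMax = aa.foldLeft(Double.NegativeInfinity)((m, x) => math.max(m, x))
    val aMax = if (rawMax.isNaN || rawMax.isInfinite) 0.0 else rawMax

    val s: Double = b match {
      case Some(w) => aa.indices.map(i => w(i) * math.exp(aa(i) - aMax)).sum
      case None    => aa.map(x => math.exp(x - aMax)).sum
    }
    val sign = math.signum(s)
    (math.log(math.abs(s)) + aMax, sign)
  }

  /** Without sign information: a negative weighted sum gives NaN, like scipy's default. */
  def logSumExp(a: Array[Double], b: Option[Array[Double]] = None): Double = {
    val (value, sign) = logSumExpWithSign(a, b)
    if (sign < 0) Double.NaN else value
  }
}
```

For example, LogSumExp.logSumExpWithSign(Array(1.0, 2.0), Some(Array(1.0, -1.0))) reproduces the sign-flag example from the scipy docstring, and the masked-array example corresponds to passing the weights Some(Array(1.0, 0.0, 1.0)).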
That's all for this article. I hope it helps, and thank you for supporting 码农网.