谷歌开源 TensorFlow 的简化库 JAX

栏目: 数据库 · 发布时间: 5年前

内容简介:谷歌开源了一个 TensorFlow 的简化库 JAX。JAX 结合了 Autograd 和 XLA,专门用于高性能机器学习研究。

谷歌开源了一个 TensorFlow 的简化库 JAX。

谷歌开源 TensorFlow 的简化库 JAX

JAX 结合了 Autograd 和 XLA,专门用于高性能机器学习研究。

凭借 Autograd,JAX 可以求导循环、分支、递归和闭包函数,并且它可以进行三阶求导。通过 grad,它支持自动模式反向求导(反向传播)和正向求导,且二者可以任何顺序任意组合。

得力于 XLA,可以在 GPU 和 TPU 上编译和运行 NumPy 程序。默认情况下,编译发生在底层,库调用实时编译和执行。但是 JAX 还允许使用单一函数 API jit 将 Python 函数及时编译为 XLA 优化的内核。编译和自动求导可以任意组合,因此可以在 Python 环境下实现复杂的算法并获得最大的性能。

demo:

import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial

def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads

更深入地看,JAX 实际上是一个可扩展的可组合函数转换系统,grad 和 jit 都是这种转换的实例。

项目地址:https://github.com/google/JAX


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

引爆点

引爆点

[美] 马尔科姆·格拉德威尔 / 钱清、覃爱冬 / 中信出版社 / 2009-8 / 27.00元

我们的世界看上去很坚固,但在《纽约客》怪才格拉德威尔的眼里,只要你找到那个点,轻轻一触,这个世界就会动起来:一位满意而归的顾客能让新开张的餐馆座无虚席,一位涂鸦爱好者能在地铁掀起犯罪浪潮,一位精明小伙传递的信息拉开了美国独立战争的序幕——这个看起来不起眼的点,却是任何人都不能忽视的引爆点。 《引爆点》是一本谈论怎样让产品发起流行潮的专门性著作。书中将产品爆发流行的现象归因为三种模式:个别人物......一起来看看 《引爆点》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

Base64 编码/解码
Base64 编码/解码

Base64 编码/解码

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试