TensorFlow 拆包(九):High Level APIs

栏目: 数据库 · 发布时间: 6年前

内容简介:前篇:这篇来研究一下 TF 中的一些高级 API。TensorFlow 由于一直是在开源的社区环境中发展起来的,早期的一些 API 都比较简单粗暴(更直白地说就是

前篇:

这篇来研究一下 TF 中的一些高级 API。

TensorFlow 由于一直是在开源的社区环境中发展起来的,早期的一些 API 都比较简单粗暴(更直白地说就是 不那么好用 ),以至于在它之上封装的更友好的 Keras 可能在大部分的使用者群体中会有更高的出现率。后来的 TensorFlow 中也有吸收 Keras 里面一些比较好的结构,有出现像 tf.layers 这样的更高层封装,可以期待一下 2.0 以后会不会大幅优化上层的编码 API 吧。

那这里说的高级 API 是什么呢?

官网的 guide 里面列了几个:

tf.keras

其他的还有包括像 StagingArea(构建软件流水,神器!!!)等等目前还在 tf.contrib 中处于实验阶段的很多东西,开源的力量太强大了,每隔一段时间就有很多新功能被社区添加进库中。

Estimator

tf.keras 和 Estimator 的设计都是为了让用户能够更方便地编写网络,话说简单看了下 Estimator 的用法,API 的设计方式应该大概率是从 Keras 里面借鉴的。

具体的使用这里就不多记了, 这里 写了个很小的例子,直接开始拆 Estimator 的实现吧。

核心是 tf.estimator.Estimator 这个类,先看初始化参数:

__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None,
    warm_start_from=None
)

第一个是用来构建网络的模型函数,具体后面详细分析; model_dir 用于指定保存的参数、checkpoint、log 信息等等存放的目录;再后面的几个都是一些额外的配置选项。

让我觉得非常怪的是官网的介绍页面说 Estimator 的优势是不需要用户建图……我真是一脸懵逼。或许对于 TF 内置的一些事先建好的 Estimator 是这样吧,但是如果想自定义呢?……写文档的人吹的有点过了吧。

创建 Estimator 时,首先初始化各项配置参数信息(值得一提的是 model_dir 是允许被 config 中的选项覆盖的),设置训练或者验证时用的 数据分布策略(DistributionStrategy,后面再详细分析) ,设置参数和图节点分布的 device_fn ;之后简单检查 model_fn 的参数是否符合规范,然后完处理完 warm_start 的一些设置就结束了。

传入的 model_fn 是用于构建 Estimator 代表的模型网络的核心函数,它能够接受的参数名有严格的规定:

  • features:网络的输入数据,即一个 batch 的数据;
  • labels:网络的标签数据,即一个 batch 的目标标签;
  • mode:可选,但是一般都必须要有,要不实现起来会很麻烦。这个值会根据执行的模式由 Estimator 传入,会有 3 种, tf.estimator.ModeKeys.PREDICTtf.estimator.ModeKeys.TRAINtf.estimator.ModeKeys.EVALUATE
  • params:可选,对应的是 Estimator 的初始化参数;
  • config:可选,对应的是 Estimator 的初始化参数

Run the Estimator

接下来是 Estimator 类的三个调用方法 evaluate、predict 和 train,从字面上就能够看出来各自对应的是什么功能了(Keras 里面对应的 API 应该是 evaluate、predict 和 fit)。

evaluate(
    input_fn,
    steps=None,
    hooks=None,
    checkpoint_path=None,
    name=None
)

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None,
    yield_single_examples=True
)

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

三个方法的共同参数是这个 input_fn ,这是类似前面 model_fn 一样,也需要 Estimator 的创建者写好的输出数据产生函数。这个函数的返回值是一个二元组 (features, labels) 对应了 model_fn 的前两个输入参数。

train 中的 steps 表示从哪里开始训练,Estimator 将首先从保存的 checkpoint 中找到最接近的保存点,然后开始这次的训练,max_steps 则简单地就是训练的 batch 数了。

estimator_spec = self._call_model_fn(
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)

model_fn 中传入输入数据以及 ModeKeys.TRAIN ,接下来实际的执行函数是:

def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                               global_step_tensor, saving_listeners)

添加 TensorBoard 中的 Summary、创建参数保存点、如果有 saving_listeners 则额外添加到运行的 hooks 中,之后:

with training.MonitoredTrainingSession(
    master=self._config.master,
    is_chief=self._config.is_chief,
    checkpoint_dir=self._model_dir,
    scaffold=estimator_spec.scaffold,
    hooks=worker_hooks,
    chief_only_hooks=(
        tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
    save_checkpoint_secs=0,  # Saving is handled by a hook.
    save_summaries_steps=self._config.save_summary_steps,
    config=self._session_config,
    log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
  loss = None
  while not mon_sess.should_stop():
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])

嗯……这段代码是不是很熟悉,没错,官方建议的常规 TensorFlow 训练代码就是要写成这个格式。

至此,train 部分基本上分析完了(带 DistributionStrategy 的版本后面再说),整个过程就是把一套常规的 TensorFlow 代码的各个部分做了几级封装,要说有什么特别的就是它把 Summary 和 Saver 都默认包含在内了。

如果按照这个格式解开成普通的 TensorFlow 代码的话,可以说是非常好的官方范例了。

EstimatorSpec

然后再注意到 model_fn 的返回值,前面也提到了 evaluate、predict 和 train 这三个实际执行的方法其实最终都是把 input_fn 中产生的数据传给 model_fn 来跑,这里的控制差别就需要配合对不同的 mode 选项的分支判断来做,所以一个 model_fn 函数写出来大概是这个样子的:

def model_fn(features, labels, mode):

    xxxxxx

    if (mode == tf.estimator.ModeKeys.PREDICT):
        ...
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    if (mode == tf.estimator.ModeKeys.EVAL):
        ...
        return tf.estimator.EstimatorSpec(mode=mode, loss=cross_entropy, eval_metric_ops=eval_metric_ops)

    if (mode == tf.estimator.ModeKeys.TRAIN):
        ...
        return tf.estimator.EstimatorSpec(mode=mode, loss=cross_entropy, train_op=train_step, eval_metric_ops=eval_metric_ops)

不管是哪种模式, model_fn 最终的返回值都需要通过 EstimatorSpec 这个结构来传出去,其属性有:

@staticmethod
__new__(
    cls,
    mode,
    predictions=None,
    loss=None,
    train_op=None,
    eval_metric_ops=None,
    export_outputs=None,
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None,
    evaluation_hooks=None,
    prediction_hooks=None
)
  • mode:对应三种不同的模式标识;
  • predictions:预测结果,要是一个 Tensor 或者 Tensor 组成的 dict;
  • loss:训练的损失函数值,必须是一个标量或者形状为 [1] 的 Tensor;
  • train_op:训练 step 的 op,一般是某个 Optimizer 的 minimize() 方法返回的那个;
  • eval_metric_ops:一个包含了验证结果的 dict,可以是 Metric 类,或者一个 (metric_tensor, update_op) 的元组;
  • 其他…略了

要求 ModeKeys.TRAIN 模式返回的必须包含 loss 和 train_op, ModeKeys.EVAL 模式返回的必须包含 lossModeKeys.PREDICT 模式返回的必须包含 predictions。


以上所述就是小编给大家介绍的《TensorFlow 拆包(九):High Level APIs》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

老二非死不可

老二非死不可

方三文 / 机械工业出版社 / 2013-12 / 39.00

关于投资 价值投资者为啥都买茅台? 怎样识别好公司与坏公司? 做空者真的罪大恶极吗? 国际板对A股会有什么影响? 波段操作,止损割肉到底靠不靠谱? IPO真的是A股萎靡不振的罪魁祸首吗? 关于商业 搜狐的再造战略有戏吗? 新浪如何焕发第二春? 百度的敌人为什么是它自己? 我为什么比巴菲特早两年投资比亚迪? 民族品牌这张牌还靠谱......一起来看看 《老二非死不可》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具