内容简介:本文翻译自:为了简单起见,在之前的大多数示例中,我们都是手动创建一个会话(session),并不关心保存和加载检查点,但在实践中通常不是这样做的。在这我推荐你使用当利用神经网络训练模型进行实验时,通常需要分割训练集和测试集。你需要利用训练集训练你的模型,并在测试集中计算一些指标来评估模型的好坏。你还需要将模型参数存储为一个检查点(checkpoint),因为你需要可以随时停止并重启训练过程。TensorFlow的learn API旨在简化这项工作,使我们能够专注于开发实际模型。
本文翻译自: 《Building a neural network training framework with learn API》 , 如有侵权请联系删除,仅限于学术交流,请勿商用。如有谬误,请联系指出。
为了简单起见,在之前的大多数示例中,我们都是手动创建一个会话(session),并不关心保存和加载检查点,但在实践中通常不是这样做的。在这我推荐你使用 learn API
来进行会话管理和日志记录(session management and logging)。我们使用TensorFlow提供了一个简单而实用的 框架 来训练神经网络。在这一节中,我们将解释这个框架是如何工作的。
当利用神经网络训练模型进行实验时,通常需要分割训练集和测试集。你需要利用训练集训练你的模型,并在测试集中计算一些指标来评估模型的好坏。你还需要将模型参数存储为一个检查点(checkpoint),因为你需要可以随时停止并重启训练过程。TensorFlow的learn API旨在简化这项工作,使我们能够专注于开发实际模型。
使用 tf.learn
API的最简单的方式是直接使用 tf.Estimator
对象。你需要定义一个模型函数,该模型函数包含一个损失函数(loss function)、一个训练操作(train op)、一个或一组预测,以及一组可选的用于评估的度量操作:
import tensorflow as tf def model_fn(features, labels, mode, params): predictions = ... loss = ... train_op = ... metric_ops = ... return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions, loss=loss, train_op=train_op, eval_metric_ops=metric_ops) params = ... run_config = tf.contrib.learn.RunConfig(model_dir=FLAGS.output_dir) estimator = tf.estimator.Estimator( model_fn=model_fn, config=run_config, params=params) 复制代码
要训练模型,你只需调用 Estimator.train()
函数,同时提供一个输入函数来读取数据即可:
def input_fn(): features = ... labels = ... return features, labels estimator.train(input_fn=input_fn, max_steps=...) 复制代码
如果想要评估模型,只需要调用 Estimator.evaluate()
:
estimator.evaluate(input_fn=input_fn) 复制代码
对于一些简单的情况,Estimator对象就已经足够应付了,但是TensorFlow还提供了一个更高级别的对象,称为** Experiment
** ,它提供了一些额外的实用功能。创建一个experiment对象非常简单:
experiment = tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn) 复制代码
现在我们可以调用 train_and_evaluate
函数来计算训练时的指标:
experiment.train_and_evaluate() 复制代码
运行 experiment
的另一种更为高级的方法是使用 learn_runner.run()
函数。下面是我们在框架中提供的主要功能:
import tensorflow as tf tf.flags.DEFINE_string("output_dir", "", "Optional output dir.") tf.flags.DEFINE_string("schedule", "train_and_evaluate", "Schedule.") tf.flags.DEFINE_string("hparams", "", "Hyper parameters.") FLAGS = tf.flags.FLAGS def experiment_fn(run_config, hparams): estimator = tf.estimator.Estimator( model_fn=make_model_fn(), config=run_config, params=hparams) return tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=make_input_fn(tf.estimator.ModeKeys.TRAIN, hparams), eval_input_fn=make_input_fn(tf.estimator.ModeKeys.EVAL, hparams)) def main(unused_argv): run_config = tf.contrib.learn.RunConfig(model_dir=FLAGS.output_dir) hparams = tf.contrib.training.HParams() hparams.parse(FLAGS.hparams) estimator = tf.contrib.learn.learn_runner.run( experiment_fn=experiment_fn, run_config=run_config, schedule=FLAGS.schedule, hparams=hparams) if __name__ == "__main__": tf.app.run() 复制代码
调度标志(schedule flag)决定 Experiment
对象的哪个成员函数被调用。因此,如果你将schedule设置为 “train_and_evaluate”
, experiment.train_and_evaluate()
这个函数将会被调用。
def input_fn(): features = ... labels = ... return features, labels 复制代码
有关如何使用数据集API读取数据的示例,请参见 mnist .py 。要了解在TensorFlow中读取数据的各种方法,可以参考 这段代码 。
该框架还提供了一个简单的卷积网络分类器,详见 alexnet.py ,其中包括一个示例模型。
这就是开始使用TensorFlow learn API所需要的全部内容。我建议查看 框架源码 并查看官方python API,以了解更多关于learn API的信息。
以上所述就是小编给大家介绍的《【译】Effective TensorFlow Chapter13——在TensorFlow中利用learn API构建神经网络框架》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- 使用 Neuroph Java 框架创建人工神经网络
- 神经网络翻译(nmt)框架 Marian : MarianNMT
- 基于Keras框架对抗神经网络DCGAN实践
- 神经网络框架Chainer发布2.0正式版:CuPy独立
- TensorSpace.js:用于构建神经网络 3D 可视化应用的框架
- NYU、AWS联合推出:全新图神经网络框架DGL正式发布
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
Web API的设计与开发
[日] 水野贵明 / 盛荣 / 人民邮电出版社 / 2017-6 / 52.00元
本书结合丰富的实例,详细讲解了Web API的设计、开发与运维相关的知识。第1章介绍Web API的概要;第2章详述端点的设计与请求的形式;第3章介绍响应数据的设计;第4章介绍如何充分利用HTTP协议规范;第5章介绍如何开发方便更改设计的Web API;第6章介绍如何开发牢固的Web API。 本书不仅适合在工作中需要设计、开发或修改Web API的技术人员阅读,对想了解技术细节的产品经理、运维人......一起来看看 《Web API的设计与开发》 这本书的介绍吧!