内容简介:六月底,机器之心发布了「正如我们在第一部分讨论的,每次调用 tf.get_variable() 时,都需要为变量赋予一个新的唯一名称。实际上,图中的每个张量也需要一个唯一的名称。可以通过张量、操作和变量的 .name 属性访问该名称。绝大多数情况下,名称会自动创建;例如,一个常量节点会以 Const 命名,当创建更多常量节点时,其名称将是 Const_1,Const_2 等。还可以通过 name=的属性设置节点名称,列举后缀仍会自动添加:代码:
六月底,机器之心发布了「 令人困惑的TensorFlow! 」,讲述了初上手TensorFlow时会遇到的麻烦。日前,作者更新博客,对该文写了续篇,主要讲述了保存和加载TensorFlow模型以及上下文管理器的一些问题。
命名和域
命名变量和张量
正如我们在第一部分讨论的,每次调用 tf.get_variable() 时,都需要为变量赋予一个新的唯一名称。实际上,图中的每个张量也需要一个唯一的名称。可以通过张量、操作和变量的 .name 属性访问该名称。绝大多数情况下,名称会自动创建;例如,一个常量节点会以 Const 命名,当创建更多常量节点时,其名称将是 Const_1,Const_2 等。还可以通过 name=的属性设置节点名称,列举后缀仍会自动添加:
代码:
import tensorflow as tf a = tf.constant(0.) b = tf.constant(1.) c = tf.constant(2., name="cool_const") d = tf.constant(3., name="cool_const") print a.name, b.name, c.name, d.name
输出:
Const:0 Const_1:0 cool_const:0 cool_const_1:0
虽然节点命名并非必要,但在调试时非常有用。当 Tensorflow 代码崩溃时,error trace 将指向一个特定的操作。如果有很多同类型的操作,那么很难确定是哪一个出了问题。而通过明确命名每个节点,可以获得信息详细的 error trace,并更快地识别问题。
使用范围
随着图形越来越复杂,手动命名所有内容变得愈加困难。Tensorflow 提供 tf.variable_scope 对象,它通过将图形细分为更小的组块,使图形更易梳理。通过将一段图形创建代码封装在 with tf.variable_scope(scope_name):语句中,创建的所有节点名称都将自动以 scope_name 字符串作为前缀。此外,这些作用域堆栈,在另一个范围内创建的作用域会简单地将前缀链接在一起,用斜杠分隔。
代码:
import tensorflow as tf a = tf.constant(0.) b = tf.constant(1.) with tf.variable_scope("first_scope"): c = a + b d = tf.constant(2., name="cool_const") coef1 = tf.get_variable("coef", [], initializer=tf.constant_initializer(2.)) with tf.variable_scope("second_scope"): e = coef1 * d coef2 = tf.get_variable("coef", [], initializer=tf.constant_initializer(3.)) f = tf.constant(1.) g = coef2 * f print a.name, b.name print c.name, d.name print e.name, f.name, g.name print coef1.name print coef2.name
输出:
Const:0 Const_1:0 first_scope/add:0 first_scope/cool_const:0 first_scope/second_scope/mul:0 first_scope/second_scope/Const:0 first_scope/second_scope/mul_1:0 first_scope/coef:0 first_scope/second_scope/coef:0
我们能够使用代码 coef 创建两个名称相同的变量。这是因为作用域可以将名称转换为 first_scope/coef:0 和 first_scope/second_scope/coef:0,它们是不同的。
保存和加载
训练好的神经网络包括两个基本组成部分:
-
已经学习过某些任务优化的网络权重
-
说明如何利用权重获得结果的网络图
Tensorflow 将这两个组件分开,但很明显它们需要紧密匹配。如果没有图结构进行说明,那权重也无用,而带有随机权重的图也效果也不好。事实上,即使仅交换两个权重矩阵也可能完全破坏模型。这通常会让 Tensorflow 初学者感觉很挫败。使用预先训练好的模型作为神经网络的一个组成部分不失为加速训练的好方法,但是也有可能搞砸一切。
保存模型
当只有单个模型时,Tensorflow 用于保存和加载的内置 工具 使用很方便:只需创建一个 tf.train.Saver()。类似于 tf.train.Optimizer,tf.train.Saver 本身并不是一个节点,而是在已有图形上执行有用功能的更高级类别。你可能已经预料到 tf.train 的「有用功能」了,即保存和加载模型。
代码:
import tensorflow as tf a = tf.get_variable('a', []) b = tf.get_variable('b', []) init = tf.global_variables_initializer() saver = tf.train.Saver() sess = tf.Session() sess.run(init) saver.save(sess, './tftcp.model')
输出
四个新文件:
checkpoint tftcp.model.data-00000-of-00001 tftcp.model.index tftcp.model.meta
具体内容分析如下:
首先:当我们只保存一个模型时,为什么会输出四个文件?重建模型所需的信息被分散到它们当中。如果想复制或者备份模型,需要有四个文件(前缀为文件名)。下面简述答案:
-
tftcp.model.data-00000-of-00001 包含模型权重(上述第一个要点)。它可能这里最大的文件。
-
tftcp.model.meta 是模型的网络结构(上述第二个要点)。它包含重建图形所需的所有信息。
-
tftcp.model.index 是连接前两点的索引结构。用于在数据文件中找到对应节点的参数。
-
checkpoint 实际上不需要重建模型,但如果在整个训练过程中保存了多个版本的模型,那它会跟踪所有内容。
其次,我为什么一定要为该示例创建 tf.Session 和 tf.global_variables_initializer 呢?
因为,如果要保存一个模型,我们需要保存相关的内容。计算存于图中,但数值存于会话中。tf.train.Saver 可以通过指向图表的全局指针访问网络结构。但当我们保存变量的值(即网络权重)时,我们需要访问 tf.Session 来确定这些值;这就是为什么 sess 作为 save 函数的第一个参数传入。此外,尝试保存未初始化的变量会引发错误,因为尝试访问未初始化变量的值总是会引发错误。因此,我们需要一个会话和一个初始化程序(或等价的 tf.assign)。
加载模型
既然我们已经保存了模型,现在重新加载它。第一步是重新创建变量:我们希望变量的名称、形状和类型都与保存时一致。第二步是创建与之前一样的 tf.train.Saver,并调用 restore 函数。
代码:
import tensorflow as tf a = tf.get_variable('a', []) b = tf.get_variable('b', []) saver = tf.train.Saver() sess = tf.Session() saver.restore(sess, './tftcp.model') sess.run([a,b])
输出:
[1.3106428, 0.6413864]
在运行之前,我们不需要初始化 a 或 b!这是因为 restore 运算将值从文件移动到会话的变量中。由于会话不再包含任何空值变量,因此不再需要初始化。(如果不小心,会适得其反:还原后运行 init 会使随机初始化的值覆盖加载的值。)
选择变量
当一个 tf.train.Saver 程序初始化后,它会查看当前图形并获取变量列表;这是 saver「关心」的永久存储的变量列表。我们可以用._var_list 属性来检查:
代码:
import tensorflow as tf a = tf.get_variable('a', []) b = tf.get_variable('b', []) saver = tf.train.Saver() c = tf.get_variable('c', []) print saver._var_list
输出:
[<tf.Variable 'a:0' shape=() dtype=float32_ref>, <tf.Variable 'b:0' shape=() dtype=float32_ref>]
因为在创建 saver 时 c 还没有出现,所以它并没有成为函数的一部分。一般来说,你要在创建 saver 之前确保已经创建了所有的变量。
当然,在某些特定的情况下,可能只需保存变量的一个子集。当创建 var_list 以期望它跟踪可用变量子集时,tf.train.Saver 允许传递 var_list。
代码:
import tensorflow as tf a = tf.get_variable('a', []) b = tf.get_variable('b', []) c = tf.get_variable('c', []) saver = tf.train.Saver(var_list=[a,b]) print saver._var_list
输出:
[<tf.Variable 'a:0' shape=() dtype=float32_ref>, <tf.Variable 'b:0' shape=() dtype=float32_ref>]
加载修正模型
上面例子中涵盖的模型加载方案类似于物理中的「真空中无摩擦的完美球体」(perfect sphere in frictionless vacuum)场景。只要你使用自己的代码保存和加载模型,且不擅自更改二者,实现保存和加载轻而易举。但很多情况下,并不会有如此完美的场景。在这些情况下,我们需要多加思量。
让我们通过几个场景来说明这些问题。首先,如果我们想保存一个完整的模型,但只想加载其中的一部分怎么办?(在下面代码示例中,我依次运行两个脚本。)
代码:
import tensorflow as tf a = tf.get_variable('a', []) b = tf.get_variable('b', []) init = tf.global_variables_initializer() saver = tf.train.Saver() sess = tf.Session() sess.run(init) saver.save(sess, './tftcp.model')
import tensorflow as tf a = tf.get_variable('a', []) init = tf.global_variables_initializer() saver = tf.train.Saver() sess = tf.Session() sess.run(init) saver.restore(sess, './tftcp.model') sess.run(a)
输出:
1.1700551
OK。当我们在相反的场景里,就会出现失败的状况:我们希望将一个模型作为大型模型的组件加载。
代码:
import tensorflow as tf a = tf.get_variable('a', []) init = tf.global_variables_initializer() saver = tf.train.Saver() sess = tf.Session() sess.run(init) saver.save(sess, './tftcp.model') import tensorflow as tf a = tf.get_variable('a', []) d = tf.get_variable('d', []) init = tf.global_variables_initializer() saver = tf.train.Saver() sess = tf.Session() sess.run(init) saver.restore(sess, './tftcp.model')
输出:
Key d not found in checkpoint [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
我们只想加载 a,却忽略了新变量 b。我们犯了一个错误,却抱怨 d 没有出现在 checkpoint 中。
第三种情况是,我们想将一个模型的参数加载到另一个模型的计算图中。这也会引发一个错误,原因很明显:Tensorflow 不知道把加载的所有参数放置在何处。幸好有个方法可以给它点提示。
还记得 var_list 吗?或者更准确来说是「var_list_or_dictionary_mapping_names_to_vars」,但这个名字有点拗口,所以他们使用第一个。
保存模型是 Tensorflow 要求使用全局唯一变量名的关键原因之一。在保存-模型-文件中,每个保存变量的名称都与其形状和值有关。将其加载到新的计算图中与将想要加载的变量的原始名称映射到当前模型的变量中一样简单。示例如下:
代码:
import tensorflow as tf a = tf.get_variable('a', []) init = tf.global_variables_initializer() saver = tf.train.Saver() sess = tf.Session() sess.run(init) saver.save(sess, './tftcp.model') import tensorflow as tf d = tf.get_variable('d', []) init = tf.global_variables_initializer() saver = tf.train.Saver(var_list={'a': d}) sess = tf.Session() sess.run(init) saver.restore(sess, './tftcp.model') sess.run(d)
输出:
-0.9303965
这是一种关键机制,通过这个机制,可以将没有相同计算图的模型组合在一起。例如,你可能从网上获得了一个预训练好的语言模型,希望重用词嵌入。或者你可能在两次训练之间改变了模型的参数化,想让这个新版本在旧版本的基础上继续前进;但你又不想重新训练整个过程。在这两种情况下,你只需手动创建一个字典,将旧变量名称映射到新变量即可。
需要注意的是:你要明确地知道正在加载的参数是如何使用的。如果可以,你应该使用原作者用来构建模型的确切代码,以确保计算图的组件与训练时看起来一样。如果需要复现模型,务必记住,无论多微小的更改,都可能严重损害预训练网络的性能。所以始终要将复现结果和原来的结果进行对比。
模型检查
如果想加载的模型来源于网络或由自己创建(两个月前),那你很可能不知道原始变量是如何命名的。要检查保存的模型,需要使用官方 Tensorflow 库的一些工具。
代码:
import tensorflow as tf a = tf.get_variable('a', []) b = tf.get_variable('b', [10,20]) c = tf.get_variable('c', []) init = tf.global_variables_initializer() saver = tf.train.Saver() sess = tf.Session() sess.run(init) saver.save(sess, './tftcp.model') print tf.contrib.framework.list_variables('./tftcp.model')
输出:
[('a', []), ('b', [10, 20]), ('c', [])]
利用这些工具(结合原始代码库一起使用)通常可以找到你想要的变量名称。
结论
希望本文能帮你了解关于保存和加载 Tensorflow 模型的基础知识。还有其他一些高级技巧,比如自动 checkpoint 和保存/恢复元图,可能会在以后的文章中提到;但是根据我的经验,这些并不常用,特别是对于初学者来说。
原文链接: https://jacobbuckman.com/post/tensorflow-the-confusing-parts-2/
以上所述就是小编给大家介绍的《令人困惑的 TensorFlow!(II)》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!
猜你喜欢:- 令人困惑的strtotime
- ???? 令人愉快的 JavaScript 测试
- 毕业 2020:令人心碎的 offer
- 12 个令人惊叹的 CSS 项目
- 12个令人惊叹的CSS实验项目
- 令人心动的HTTP知识点大全
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
人人都是产品经理——写给产品新人
苏杰 / 电子工业出版社 / 2017-6 / 66.60
《人人都是产品经理——写给产品新人》为经典畅销书《人人都是产品经理》的内容升级版本,和《人人都是产品经理2.0——写给泛产品经理》相当于上下册的关系。对于大量成长起来的优秀互联网产品经理、众多想投身产品工作的其他岗位从业者,以及更多有志从事这一职业的学生而言,这《人人都是产品经理——写给产品新人》曾是他们记忆深刻的启蒙读物、思想基石和行动手册。作者以分享经历与体会为出发点,以“朋友间聊聊如何做产品......一起来看看 《人人都是产品经理——写给产品新人》 这本书的介绍吧!
图片转BASE64编码
在线图片转Base64编码工具
HEX CMYK 转换工具
HEX CMYK 互转工具