Tensorflow SavedModel 模型的保存与加载

栏目: 数据库 · 发布时间: 6年前

内容简介:这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Se

这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。

为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用 python 语言训练模型,然后在 Java 中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Serving server来部署模型,必须选择SavedModel格式。

SavedModel包含啥?

一个比较完整的SavedModel模型包含以下内容:

assets/
assets.extra/
variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

saved_model.pb是MetaGraphDef,它包含图形结构。variables文件夹保存训练所习得的权重。assets文件夹可以添加可能需要的外部文件,assets.extra是一个库可以添加其特定assets的地方。

MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。

assets和assets.extra是可选的,比如本文示例代码保存的模型只包含以下的内容:

variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

保存

为了简单起见,我们使用一个非常简单的手写识别代码作为示例,代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

这段代码很简单,一个简单的梯度递减回归模型。要保存该模型,我们还需要对代码作一点小小的改动。

添加命名

在输入和输出Ops中添加名称,这样我们在加载时可以方便的按名称引用操作。将上述的x赋值语句修改为:

x = tf.placeholder(tf.float32, [None, 784], name="myInput")

当然你也可以不给名称,系统会默认给一个名称,比如上面的x系统会给一个”Placeholder”,当我们需要引用多个op的时候,给每个op一个命名,确实方便给我们后面使用。

你也可以使用tf.identity给tensor命名,比如在上述代码上添加一行:

tf.identity(y, name="myOutput")

给输出也命一个名。

保存到文件

最简单的保存方法是使用tf.saved_model.simple_save函数,代码如下:

tf.saved_model.simple_save(sess,
            "./model",
            inputs={"myInput": x},
            outputs={"myOutput": y})

这段代码将模型保存在 ./model 目录。

当然你也可以采用比较复杂的写法:

builder = tf.saved_model.builder.SavedModelBuilder("./model")

signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': y})
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=[tag_constants.SERVING],
                                     signature_def_map={'predict': signature})
builder.save()

看起来新的代码差别不大,区别就在于可以自己定义tag,在签名的定义上更加灵活。这里说说tag的用途吧。

一个模型可以包含不同的MetaGraphDef,什么时候需要多个MetaGraphDef呢?也许你想保存图形的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。

在simple_save方法中,系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。

加载

对不同语言而言,加载过程有些类似,这里还是以python为例:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, ["serve"], "./model")
  graph = tf.get_default_graph()

  input = np.expand_dims(mnist.test.images[0], 0)
  x = sess.graph.get_tensor_by_name('myInput:0')
  y = sess.graph.get_tensor_by_name('myOutput:0')
  batch_xs, batch_ys = mnist.test.next_batch(1)
  scores = sess.run(y,
           feed_dict={x: batch_xs})
  print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))

需要注意,load函数中第二个参数是tag,需要和保存模型时的参数一致,第三个参数是模型保存的文件夹。

调用load函数后,不仅加载了计算图,还加载了训练中习得的变量值,有了这两者,我们就可以调用其进行推断新给的测试数据。

小结

将过程捋顺了之后,你会发觉保存和加载SavedModel其实很简单。但在摸索过程中,也走了不少的弯路,主要原因是现在搜索到的大部分资料还是用tf.train.Saver()来保存模型,还有的是用tf.gfile.FastGFile来序列化模型图。

本文的完整代码请参考:https://github.com/mogoweb/aiexamples/tree/master/tensorflow/saved_model

希望这篇文章对您有帮助,感谢阅读!

Tensorflow SavedModel 模型的保存与加载


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

精通Java并发编程(第2版)

精通Java并发编程(第2版)

[西] 哈维尔·费尔南德斯·冈萨雷斯 / 唐富年 / 人民邮电出版社 / 2018-10 / 89.00元

Java 提供了一套非常强大的并发API,可以轻松实现任何类型的并发应用程序。本书讲述Java 并发API 最重要的元素,包括执行器框架、Phaser 类、Fork/Join 框架、流API、并发数据结构、同步机制,并展示如何在实际开发中使用它们。此外,本书还介绍了设计并发应用程序的方法论、设计模式、实现良好并发应用程序的提示和技巧、测试并发应用程序的工具和方法,以及如何使用面向Java 虚拟机的......一起来看看 《精通Java并发编程(第2版)》 这本书的介绍吧!

HTML 编码/解码
HTML 编码/解码

HTML 编码/解码

html转js在线工具
html转js在线工具

html转js在线工具

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试