Tensorflow SavedModel 模型的保存与加载

栏目: 数据库 · 发布时间: 6年前

内容简介:这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Se

这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。

为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用 python 语言训练模型,然后在 Java 中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Serving server来部署模型,必须选择SavedModel格式。

SavedModel包含啥?

一个比较完整的SavedModel模型包含以下内容:

assets/
assets.extra/
variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

saved_model.pb是MetaGraphDef,它包含图形结构。variables文件夹保存训练所习得的权重。assets文件夹可以添加可能需要的外部文件,assets.extra是一个库可以添加其特定assets的地方。

MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。

assets和assets.extra是可选的,比如本文示例代码保存的模型只包含以下的内容:

variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

保存

为了简单起见,我们使用一个非常简单的手写识别代码作为示例,代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

这段代码很简单,一个简单的梯度递减回归模型。要保存该模型,我们还需要对代码作一点小小的改动。

添加命名

在输入和输出Ops中添加名称,这样我们在加载时可以方便的按名称引用操作。将上述的x赋值语句修改为:

x = tf.placeholder(tf.float32, [None, 784], name="myInput")

当然你也可以不给名称,系统会默认给一个名称,比如上面的x系统会给一个”Placeholder”,当我们需要引用多个op的时候,给每个op一个命名,确实方便给我们后面使用。

你也可以使用tf.identity给tensor命名,比如在上述代码上添加一行:

tf.identity(y, name="myOutput")

给输出也命一个名。

保存到文件

最简单的保存方法是使用tf.saved_model.simple_save函数,代码如下:

tf.saved_model.simple_save(sess,
            "./model",
            inputs={"myInput": x},
            outputs={"myOutput": y})

这段代码将模型保存在 ./model 目录。

当然你也可以采用比较复杂的写法:

builder = tf.saved_model.builder.SavedModelBuilder("./model")

signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': y})
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=[tag_constants.SERVING],
                                     signature_def_map={'predict': signature})
builder.save()

看起来新的代码差别不大,区别就在于可以自己定义tag,在签名的定义上更加灵活。这里说说tag的用途吧。

一个模型可以包含不同的MetaGraphDef,什么时候需要多个MetaGraphDef呢?也许你想保存图形的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。

在simple_save方法中,系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。

加载

对不同语言而言,加载过程有些类似,这里还是以python为例:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, ["serve"], "./model")
  graph = tf.get_default_graph()

  input = np.expand_dims(mnist.test.images[0], 0)
  x = sess.graph.get_tensor_by_name('myInput:0')
  y = sess.graph.get_tensor_by_name('myOutput:0')
  batch_xs, batch_ys = mnist.test.next_batch(1)
  scores = sess.run(y,
           feed_dict={x: batch_xs})
  print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))

需要注意,load函数中第二个参数是tag,需要和保存模型时的参数一致,第三个参数是模型保存的文件夹。

调用load函数后,不仅加载了计算图,还加载了训练中习得的变量值,有了这两者,我们就可以调用其进行推断新给的测试数据。

小结

将过程捋顺了之后,你会发觉保存和加载SavedModel其实很简单。但在摸索过程中,也走了不少的弯路,主要原因是现在搜索到的大部分资料还是用tf.train.Saver()来保存模型,还有的是用tf.gfile.FastGFile来序列化模型图。

本文的完整代码请参考:https://github.com/mogoweb/aiexamples/tree/master/tensorflow/saved_model

希望这篇文章对您有帮助,感谢阅读!

Tensorflow SavedModel 模型的保存与加载


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Creative Selection

Creative Selection

Ken Kocienda / St. Martin's Press / 2018-9-4 / USD 28.99

Hundreds of millions of people use Apple products every day; several thousand work on Apple's campus in Cupertino, California; but only a handful sit at the drawing board. Creative Selection recounts ......一起来看看 《Creative Selection》 这本书的介绍吧!

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具