内容简介:这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Se
这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。
为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用 python 语言训练模型,然后在 Java 中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Serving server来部署模型,必须选择SavedModel格式。
SavedModel包含啥?
一个比较完整的SavedModel模型包含以下内容:
assets/ assets.extra/ variables/ variables.data-*****-of-***** variables.index saved_model.pb
saved_model.pb是MetaGraphDef,它包含图形结构。variables文件夹保存训练所习得的权重。assets文件夹可以添加可能需要的外部文件,assets.extra是一个库可以添加其特定assets的地方。
MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。
assets和assets.extra是可选的,比如本文示例代码保存的模型只包含以下的内容:
variables/ variables.data-*****-of-***** variables.index saved_model.pb
保存
为了简单起见,我们使用一个非常简单的手写识别代码作为示例,代码如下:
from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) sess = tf.InteractiveSession() x = tf.placeholder(tf.float32, [None, 784]) W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b) y_ = tf.placeholder(tf.float32, [None, 10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1)) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) tf.global_variables_initializer().run() for i in range(1000): batch_xs, batch_ys = mnist.train.next_batch(100) train_step.run({x: batch_xs, y_: batch_ys}) correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
这段代码很简单,一个简单的梯度递减回归模型。要保存该模型,我们还需要对代码作一点小小的改动。
添加命名
在输入和输出Ops中添加名称,这样我们在加载时可以方便的按名称引用操作。将上述的x赋值语句修改为:
x = tf.placeholder(tf.float32, [None, 784], name="myInput")
当然你也可以不给名称,系统会默认给一个名称,比如上面的x系统会给一个”Placeholder”,当我们需要引用多个op的时候,给每个op一个命名,确实方便给我们后面使用。
你也可以使用tf.identity给tensor命名,比如在上述代码上添加一行:
tf.identity(y, name="myOutput")
给输出也命一个名。
保存到文件
最简单的保存方法是使用tf.saved_model.simple_save函数,代码如下:
tf.saved_model.simple_save(sess, "./model", inputs={"myInput": x}, outputs={"myOutput": y})
这段代码将模型保存在 ./model 目录。
当然你也可以采用比较复杂的写法:
builder = tf.saved_model.builder.SavedModelBuilder("./model") signature = predict_signature_def(inputs={'myInput': x}, outputs={'myOutput': y}) builder.add_meta_graph_and_variables(sess=sess, tags=[tag_constants.SERVING], signature_def_map={'predict': signature}) builder.save()
看起来新的代码差别不大,区别就在于可以自己定义tag,在签名的定义上更加灵活。这里说说tag的用途吧。
一个模型可以包含不同的MetaGraphDef,什么时候需要多个MetaGraphDef呢?也许你想保存图形的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。
在simple_save方法中,系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。
加载
对不同语言而言,加载过程有些类似,这里还是以python为例:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) with tf.Session(graph=tf.Graph()) as sess: tf.saved_model.loader.load(sess, ["serve"], "./model") graph = tf.get_default_graph() input = np.expand_dims(mnist.test.images[0], 0) x = sess.graph.get_tensor_by_name('myInput:0') y = sess.graph.get_tensor_by_name('myOutput:0') batch_xs, batch_ys = mnist.test.next_batch(1) scores = sess.run(y, feed_dict={x: batch_xs}) print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))
需要注意,load函数中第二个参数是tag,需要和保存模型时的参数一致,第三个参数是模型保存的文件夹。
调用load函数后,不仅加载了计算图,还加载了训练中习得的变量值,有了这两者,我们就可以调用其进行推断新给的测试数据。
小结
将过程捋顺了之后,你会发觉保存和加载SavedModel其实很简单。但在摸索过程中,也走了不少的弯路,主要原因是现在搜索到的大部分资料还是用tf.train.Saver()来保存模型,还有的是用tf.gfile.FastGFile来序列化模型图。
本文的完整代码请参考:https://github.com/mogoweb/aiexamples/tree/master/tensorflow/saved_model
希望这篇文章对您有帮助,感谢阅读!
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网
猜你喜欢:- OpenGL 3D 模型加载和渲染
- TensorFlow 加载多个模型的方法
- Viewer模型加载本地缓存实战
- [译] 保存和加载模型教程(PyTorch)
- ember.js – 使用EmberData手动加载模型
- Laravel Database——Eloquent Model 关联模型加载与查询
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
你必须知道的213个C语言问题
范立锋、李世欣 / 人民邮电出版社 / 2010-6 / 45.00元
《你必须知道的213个C语言问题》精选了213个在C语言程序设计中经常遇到的问题,目的是帮助读者解决在C语言学习和开发中遇到的实际困难,提高读者学习和开发的效率。这些问题涵盖了C语言与软件开发、C语言基础、编译预处理、字符串、函数、键盘操作、文件、目录和磁盘、数组、指针和结构、DOS服务和BIOS服务、日期和时间、重定向I/O和进程命令、C语言开发常见错误及程序调试等内容,均是作者经过充分的调研,......一起来看看 《你必须知道的213个C语言问题》 这本书的介绍吧!