如何合并两个 TensorFlow 模型

栏目: 编程工具 · 发布时间: 6年前

内容简介:这是Tensorflow SavedModel模型系列文章的第三篇,也是终章。在《为什么需要合并两个模型?我们还是以《

这是Tensorflow SavedModel模型系列文章的第三篇,也是终章。在《 Tensorflow SavedModel模型的保存与加载 》中,我们谈到了Tensorflow模型如何保存为SavedModel格式,以及如何加载之。在《 如何查看tensorflow SavedModel格式模型的信息 》中,我们演示了如何查看模型的signature和计算图结构。在本文中,我们将探讨如何合并两个模型,简单的说,就是将第一个模型的输出,作为第二个模型的输入,串联起来形成一个新模型。

背景

为什么需要合并两个模型?

我们还是以《 Tensorflow SavedModel模型的保存与加载 》中的代码为例,这个手写数字识别模型接收的输入是shape为[?, 784],这里?代表可以批量接收输入,可以先忽略,就把它固定为1吧。784是28 x 28进行展开的结果,也就是28 x 28灰度图像展开的结果。

问题是,我们送给模型的通常是图片,可能来自文件、可能来自摄像头。让问题变得复杂的是,如果我们通过HTTP来调用部署到服务器端的模型,二进制数据实际上是不方便HTTP传输的,这时我们通常需要对图像数据进行base64编码。这样服务器端接收到的数据是一个base64字符串,可模型接受的是二进制向量。

很自然的,我们可以想到两种解决方法:

  1. 重新训练模型一个接收base64字符串的模型。

    这种解决方法的问题在于:重新训练模型很费时,甚至不可行。本文示例因为比较简单,重新训练也没啥。如果是那种很深的卷积神经网络,训练一次可能需要好几天,重新训练代价很大。更普遍的情况是,我们使用的是别人训练好的模型,比如图像识别中普遍使用的Mobilenet、InceptionV3等等,都是Google、微软这样的公司,耗费大量的资源训练出来的,我们没有那个条件重新训练。

  2. 在服务器端增加base64到二进制数据的转换

    这种解决方法实现起来不复杂,但如果我们使用的是Tensorflow model server之类的方案部署的呢?当然我们也可以再开启一个server,来接受客户端的base64图像数据,处理完毕之后再转发给Tensorflow model server,但这无疑增加了服务端的工作量,增加了服务端的复杂性。

在本文,我们将给出第三种方案:编写一个Tensorflow模型,接收base64的图像数据,输出二进制向量,然后将第一个模型的输出作为第二个模型的输入,串接起来,保存为一个新的模型,最后部署新的模型。

base64解码Tensorflow模型

Tensorflow包含了大量图像处理和数组处理的方法,所以实现这个模型比较简单,模型包含了base64解码、解码PNG图像、缩放到28 * 28、最后展开为(1, 784)的数组输出,符合手写数字识别模型的输入,代码如下:

with tf.Graph().as_default() as g1:
  base64_str = tf.placeholder(tf.string, name='input_string')
  input_str = tf.decode_base64(base64_str)
  decoded_image = tf.image.decode_png(input_str, channels=1)
  # Convert from full range of uint8 to range [0,1] of float32.
  decoded_image_as_float = tf.image.convert_image_dtype(decoded_image,
                                                        tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  resize_shape = tf.stack([28, 28])
  resize_shape_as_int = tf.cast(resize_shape, dtype=tf.int32)
  resized_image = tf.image.resize_bilinear(decoded_image_4d,
                                           resize_shape_as_int)
  # 展开为1维数组
  resized_image_1d = tf.reshape(resized_image, (-1, 28 * 28))
  print(resized_image_1d.shape)
  tf.identity(resized_image_1d, name="DecodeJPGOutput")

g1def = g1.as_graph_def()

在该模型中,并不存在变量,都是一些固定的操作,所以无需进行训练。

加载手写识别模型

手写识别模型参考《 Tensorflow SavedModel模型的保存与加载 》一文,模型保存在 “./model” 下,加载代码如下:

with tf.Graph().as_default() as g2:
  with tf.Session(graph=g2) as sess:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        "./model", tag_constants.SERVING).graph_def

    tf.saved_model.loader.load(sess, ["serve"], "./model")

    g2def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        ["myOutput"],
        variable_names_whitelist=None,
        variable_names_blacklist=None)

这里使用了g2定义了另外一个graph,和前面的模型的graph区分开来。注意这里调用了 graph_util.convert_variables_to_constants 将模型中的变量转化为常量,也就是所谓的冻结图(freeze graph)操作。

在研究如何连接两个模型时,我在这个问题上卡了很久。先的想法是合并模型之后,再加载变量值进来,但是尝试之后,怎么也不成功。后来的想法是遍历手写识别模型的变量,获取其变量值,将变量值复制到合并的模型的变量,但这样操作,使用模型时,总是提示有变量未初始化。

最后从Tensorflow模型到Tensorflow lite模型转换中获得了灵感,将模型中的变量固定下来,这样就不存在变量的加载问题,也不会出现模型变量未初始化的问题。

执行 convert_variables_to_constants 后,可以看到有两个变量转化为了常量操作,也就是手写数字识别模型中的 wb

Converted 2 variables to const ops.

连接两个模型

利用 tf.import_graph_def 方法,我们可以导入图到现有图中,注意第二个 import_graph_def ,其input是第一个graph_def的输出,通过这样的操作,就将两个计算图连接起来,最后保存起来。代码如下:

with tf.Graph().as_default() as g_combined:
  with tf.Session(graph=g_combined) as sess:

    x = tf.placeholder(tf.string, name="base64_input")

    y, = tf.import_graph_def(g1def, input_map={"input_string:0": x}, return_elements=["DecodeJPGOutput:0"])

    z, = tf.import_graph_def(g2def, input_map={"myInput:0": y}, return_elements=["myOutput:0"])
    tf.identity(z, "myOutput")

    tf.saved_model.simple_save(sess,
              "./modelbase64",
              inputs={"base64_input": x},
              outputs={"myOutput": z})

因为第一个模型不包含变量,第二个模型的变量转化为了常量操作,所以最后保存的模型文件并不包含变量:

modelbase64/
├── saved_model.pb
└── variables

1 directory, 1 file

测试

我们写一段测试代码,测试一下合并之后模型是否管用,代码如下:

with tf.Session(graph=tf.Graph()) as sess:
  sess.run(tf.global_variables_initializer())

  tf.saved_model.loader.load(sess, ["serve"], "./modelbase64")
  graph = tf.get_default_graph()

  with open("./5.png", "rb") as image_file:
    encoded_string = str(base64.urlsafe_b64encode(image_file.read()), "utf-8")

  x = sess.graph.get_tensor_by_name('base64_input:0')
  y = sess.graph.get_tensor_by_name('myOutput:0')

  scores = sess.run(y,
           feed_dict={x: encoded_string})
  print("predict: %d, actual: %d" % (np.argmax(scores, 1), 5))

这里模型的输入为 base64_input ,输出仍然是 myOutput ,使用两个图片测试,均工作正常。

小结

最近三篇文章其实都是在研究我的微信小程序时总结的,为了更好的说明问题,我使用了一个非常简单的模型来说明问题,但同样适用于复杂的模型。

本文的完整代码请参考:https://github.com/mogoweb/aiexamples/tree/master/tensorflow/saved_model

希望这篇文章对您有帮助,感谢阅读!同时敬请关注我的微信公众号:云水木石。

如何合并两个 TensorFlow 模型


以上所述就是小编给大家介绍的《如何合并两个 TensorFlow 模型》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

互联网的误读

互联网的误读

詹姆斯•柯兰(James Curran)、娜塔莉•芬顿(Natalie Fenton)、德 斯•弗里德曼(Des Freedman) / 何道宽 / 中国人民大学出版社 / 2014-7-1 / 45.00

互联网的发展蔚为壮观。如今,全球的互联网用户达到20亿之众,约占世界人口的30%。这无疑是一个新的现象,对于当代各国的经济、政治和社会生活意义重大。有关互联网的大量大众读物和学术著作鼓吹其潜力将从根本上被重新认识,这在20世纪90年代中期一片唱好时表现尤甚,那时许多论者都对互联网敬畏三分,惊叹有加。虽然敬畏和惊叹可能已成过去,然而它背后的技术中心主义——相信技术决定结果——却阴魂不散,与之伴生的则......一起来看看 《互联网的误读》 这本书的介绍吧!

图片转BASE64编码
图片转BASE64编码

在线图片转Base64编码工具

HTML 编码/解码
HTML 编码/解码

HTML 编码/解码

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具