TensorFlow 加载多个模型的方法

栏目: 数据库 · 发布时间: 6年前

内容简介:采用 TensorFlow 的时候,有时候我们需要加载的不止是一个模型,那么如何加载多个模型呢?原文:关于 TensorFlow 可以有很多东西可以说。但这次我只介绍如何导入训练好的模型(图),因为我做不到导入第二个模型并将它和第一个模型一起使用。并且,这种导入非常慢,我也不想重复做第二次。另一方面,将一切东西都放到一个模型也不实际。

采用 TensorFlow 的时候,有时候我们需要加载的不止是一个模型,那么如何加载多个模型呢?

原文: bretahajek.com/2017/04/imp…

关于 TensorFlow 可以有很多东西可以说。但这次我只介绍如何导入训练好的模型(图),因为我做不到导入第二个模型并将它和第一个模型一起使用。并且,这种导入非常慢,我也不想重复做第二次。另一方面,将一切东西都放到一个模型也不实际。

在这个教程中,我会介绍如何保存和载入模型,更进一步,如何加载多个模型。

加载 TensorFlow 模型

在介绍加载多个模型之前,我们先介绍下如何加载单个模型,官方文档: www.tensorflow.org/programmers…

首先,我们需要创建一个模型,训练并保存它。这部分我不想过多介绍细节,只需要关注如何保存模型以及不要忘记给每个操作命名。

创建一个模型,训练并保存的代码如下:

import tensorflow as tf
### Linear Regression 线性回归###
# Input placeholders
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')
# Model parameters 定义模型的权值参数
W1 = tf.Variable([0.1], tf.float32)
W2 = tf.Variable([0.1], tf.float32)
W3 = tf.Variable([0.1], tf.float32)
b = tf.Variable([0.1], tf.float32)

# Output 模型的输出
linear_model = tf.identity(W1 * x + W2 * x**2 + W3 * x**3 + b,
                           name='activation_opt')

# Loss 定义损失函数
loss = tf.reduce_sum(tf.square(linear_model - y), name='loss')
# Optimizer and training step 定义优化器运算
optimizer = tf.train.AdamOptimizer(0.001)
train = optimizer.minimize(loss, name='train_step')

# Remember output operation for later aplication
# Adding it to a collections for easy acces
# This is not required if you NAME your output operation
# 记得将输出操作添加到一个集合中,但如何你命名了输出操作,这一步可以省略
tf.add_to_collection("activation", linear_model)

## Start the session ##
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#  CREATE SAVER
saver = tf.train.Saver()

# Training loop 训练
for i in range(10000):
    sess.run(train, {x: data, y: expected})
    if i % 1000 == 0:
        # You can also save checkpoints using global_step variable
        saver.save(sess, "models/model_name", global_step=i)

# SAVE TensorFlow graph into path models/model_name
# 保存模型到指定路径并命名模型文件名字
saver.save(sess, "models/model_name")
复制代码

注意,这里是第一个重点-- 对变量和运算命名 。这是为了在加载模型后可以使用指定的一些权值参数,如果不命名的话,这些变量会自动命名为类似“Placeholder_1”的名字。在复杂点的模型中,使用领域(scopes)是一个很好的做法,但这里不做展开。

总之,==重点就是为了在加载模型的时候能够调用权值参数或者某些运算操作,你必须给他们命名或者是放到一个集合中。==

当保存模型后,在指定保存模型的文件夹中就应该包含这些文件: model_name.indexmodel_name.meta 以及其他文件。如果是采用 checkpoints 后缀命名模型名字,还会有名字包含 model_name-1000 的文件,其中的数字是对应变量 global_step ,也就是当前训练迭代次数。

现在我们就可以开始加载模型了。加载模型其实很简单,我们需要的只是两个函数即可: tf.train.import_meta_graphsaver.restore() 。此外,就是提供正确的模型保存路径位置。另外,如果我们希望在不同机器使用模型,那么还需要设置参数: clear_device=True

接着,我们就可以通过之前命名的名字或者是保存到的集合名字来调用保存的运算或者是权值参数了。如果使用了领域,那么还需要包含领域的名字才行。而在实际调用这些运算的时候,还必须采用类似 {'PlaceholderName:0': data} 的输入占位符,否则会出现错误。

加载模型的代码如下:

sess = tf.Session()

# Import graph from the path and recover session
# 加载模型并恢复到会话中
saver = tf.train.import_meta_graph('models/model_name.meta', clear_devices=True)
saver.restore(sess, 'models/model_name')

# There are TWO options how to access the operation (choose one)
# 两种方法来调用指定的运算操作,选择其中一个都可以
  # FROM SAVED COLLECTION: 从保存的集合中调用
activation = tf.get_collection('activation')[0]
  # BY NAME: 采用命名的方式
activation = tf.get_default_graph.get_operation_by_name('activation_opt').outputs[0]

# Use imported graph for data
# You have to feed data as {'x:0': data}
# Don't forget on ':0' part!
# 采用加载的模型进行操作,不要忘记输入占位符
data = 50
result = sess.run(activation, {'x:0': data})
print(result)
复制代码

多个模型

上述介绍了如何加载单个模型的操作,但如何加载多个模型呢?

如果使用加载单个模型的方式去加载多个模型,那么就会出现变量冲突的错误,也无法工作。这个问题的原因是因为一个默认图的缘故。冲突的发生是因为我们将所有变量都加载到当前会话采用的默认图中。当我们采用会话的时候,我们可以通过 tf.Session(graph=MyGraph) 来指定采用不同的已经创建好的图。因此,如果我们希望加载多个模型,那么我们需要做的就是把他们加载在不同的图,然后在不同会话中使用它们。

这里,自定义一个类来完成加载指定路径的模型到一个局部图的操作。这个类还提供 run 函数来对输入数据使用加载的模型进行操作。这个类对于我是有用的,因为我总是将模型输出放到一个集合或者对它命名为 activation_opt ,并且将输入占位符命名为 x 。你可以根据自己实际应用需求对这个类进行修改和拓展。

代码如下:

import tensorflow as tf

class ImportGraph():
    """  Importing and running isolated TF graph """
    def __init__(self, loc):
        # Create local graph and use it in the session
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        with self.graph.as_default():
            # Import saved model from location 'loc' into local graph
            # 从指定路径加载模型到局部图中
            saver = tf.train.import_meta_graph(loc + '.meta',
                                               clear_devices=True)
            saver.restore(self.sess, loc)
            # There are TWO options how to get activation operation:
            # 两种方式来调用运算或者参数
              # FROM SAVED COLLECTION:            
            self.activation = tf.get_collection('activation')[0]
              # BY NAME:
            self.activation = self.graph.get_operation_by_name('activation_opt').outputs[0]

    def run(self, data):
        """ Running the activation operation previously imported """
        # The 'x' corresponds to name of input placeholder
        return self.sess.run(self.activation, feed_dict={"x:0": data})
      
      
### Using the class ###
# 测试样例
data = 50         # random data
model = ImportGraph('models/model_name')
result = model.run(data)
print(result)
复制代码

总结

如果你理解了 TensorFlow 的机制的话,加载多个模型并不是一件困难的事情。上述的解决方法可能不是完美的,但是它简单且快速。最后给出总结整个过程的样例代码,这是在 Jupyter notebook 上的,代码地址如下:

gist.github.com/Breta01/f20…

最后,给出文章中几个代码例子的 github 地址:

  1. Code for creating, training and saving TensorFlow model.
  2. Importing and using TensorFlow graph (model)
  3. Class for importing multiple TensorFlow graphs.
  4. Example of importing multiple TensorFlow modules

欢迎关注我的微信公众号--机器学习与计算机视觉或者扫描下方的二维码,在后台留言,和我分享你的建议和看法,指正文章中可能存在的错误,大家一起交流,学习和进步!

TensorFlow 加载多个模型的方法

以上所述就是小编给大家介绍的《TensorFlow 加载多个模型的方法》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Eric Meyer on CSS

Eric Meyer on CSS

Eric Meyer / New Riders Press / 2002-7-8 / USD 55.00

There are several other books on the market that serve as in-depth technical guides or reference books for CSS. None, however, take a more hands-on approach and use practical examples to teach readers......一起来看看 《Eric Meyer on CSS》 这本书的介绍吧!

CSS 压缩/解压工具
CSS 压缩/解压工具

在线压缩/解压 CSS 代码

XML 在线格式化
XML 在线格式化

在线 XML 格式化压缩工具

html转js在线工具
html转js在线工具

html转js在线工具