TensorFlow 调用预训练好的模型—— Python 实现

栏目: Python · 发布时间: 7年前

内容简介:获取更多精彩,请关注「seniusen」!
  • TensorFlow 预训练好的模型被保存为以下四个文件
TensorFlow 调用预训练好的模型—— Python 实现
  • data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如下所示,为简单起见只保留了一个模型。
model_checkpoint_path: "/home/senius/python/c_python/test/model-40"
all_model_checkpoint_paths: "/home/senius/python/c_python/test/model-40"
复制代码

2. 导入模型图、参数值和相关变量

import tensorflow as tf
import numpy as np

sess = tf.Session()
X = None # input
yhat = None # output

def load_model():
    """
        Loading the pre-trained model and parameters.
    """
    global X, yhat
    modelpath = r'/home/senius/python/c_python/test/'
    saver = tf.train.import_meta_graph(modelpath + 'model-40.meta')
    saver.restore(sess, tf.train.latest_checkpoint(modelpath))
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    yhat = graph.get_tensor_by_name("tanh:0")
    print('Successfully load the pre-trained model!')

复制代码
  • 通过 saver.restore 我们可以得到预训练的所有参数值,然后再通过 graph.get_tensor_by_name 得到模型的输入张量和我们想要的输出张量。

3. 运行前向传播过程得到预测值

def predict(txtdata):
    """
        Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3).
        Test a single example.
        Arg:
                txtdata: Array in C.
        Returns:
            Three coordinates of a face normal.
    """
    global X, yhat

    data = np.array(txtdata)
    data = data.reshape(-1, 41, 41, 41, 3)
    output = sess.run(yhat, feed_dict={X: data})  # (-1, 3)
    output = output.reshape(-1, 1)
    ret = output.tolist()
    return ret

复制代码
  • 通过 feed_dict 喂入测试数据,然后 run 输出的张量我们就可以得到预测值。

4. 测试

load_model()
testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32)
testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3)
testdata = testdata[0:2, ...] # the first two examples
txtdata = testdata.tolist()
output = predict(txtdata)
print(output)
#  [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523], 
# [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]]
复制代码
  • 本例输入是一个三维网格模型处理后的 [41, 41, 41, 3] 的数据,输出一个表面法向量坐标 (x, y, z)。

获取更多精彩,请关注「seniusen」!

TensorFlow 调用预训练好的模型—— Python 实现

以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Open Data Structures

Open Data Structures

Pat Morin / AU Press / 2013-6 / USD 29.66

Offered as an introduction to the field of data structures and algorithms, Open Data Structures covers the implementation and analysis of data structures for sequences (lists), queues, priority queues......一起来看看 《Open Data Structures》 这本书的介绍吧!

图片转BASE64编码
图片转BASE64编码

在线图片转Base64编码工具

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具