- the code used to create the model, and
- the trained weights, or parameters, of the model
Caution: be careful with untrusted code, because TensorFlow models are code. See Using TensorFlow Securely for details.
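There are several ways to save TensorFlow models, depending on the API you are using. This guide uses tf.keras, a high-level API for building and training models in TensorFlow. For other approaches, see the TensorFlow Save and Restore guide or Saving in Eager.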
Import TensorFlow and the other dependencies. Saving a model to HDF5 later in this guide also requires the h5py package.

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

print(tf.__version__)

We'll train the model on the MNIST dataset to demonstrate saving weights. To speed up these runs, use only the first 1000 examples:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

# Flatten each 28x28 image into a 784-element vector and scale pixels to [0, 1]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
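Here is a minimal NumPy sketch of that preprocessing. It runs on a synthetic array that stands in for the real MNIST images, which are 28x28 uint8 arrays with values 0 to 255. `reshape(-1, 28 * 28)` turns each image into one 784-element row, and the -1 lets NumPy infer the number of rows. Dividing by 255.0 then scales the pixels into [0, 1].

```python
import numpy as np

# Synthetic stand-in for three MNIST images: uint8 pixels in 0..255, shape (3, 28, 28).
# This is NOT real data; it only mimics the dataset's dtype and shape.
images = (np.arange(3 * 28 * 28) % 256).astype(np.uint8).reshape(3, 28, 28)

# Same preprocessing as above: one 784-value row per image, scaled to [0, 1].
flat = images.reshape(-1, 28 * 28) / 255.0

print(flat.shape)             # (3, 784)
print(flat.min(), flat.max())  # 0.0 1.0
```

The division also converts the uint8 pixels to float64. The model then receives floating-point inputs in a small, consistent range.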
# Returns a short sequential model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ])

    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])

    return model


# Create a basic model instance
model = create_model()
model.summary()
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\keras\layers\core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
_________________________________________________________________
dropout (Dropout)            (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
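The Param # column follows directly from the layer sizes. A Dense layer with n inputs and k units has an n * k weight matrix plus k biases, and Dropout has no weights at all. A quick check in plain Python (dense_params is a hypothetical helper written for this illustration):

```python
def dense_params(n_inputs, n_units):
    """Trainable parameters of a fully connected layer:
    one weight per input/unit pair plus one bias per unit."""
    return n_inputs * n_units + n_units

hidden = dense_params(784, 512)  # first Dense layer
output = dense_params(512, 10)   # final Dense layer
dropout = 0                      # Dropout has no trainable parameters

total = hidden + output + dropout
print(hidden, output, total)  # 401920 5130 407050
```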
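The tf.keras.callbacks.ModelCheckpoint callback saves checkpoints automatically during training. It takes several arguments that configure how checkpoints are written.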
Train the model and pass it the ModelCheckpoint callback:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a checkpoint callback that saves only the weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

model = create_model()

model.fit(train_images, train_labels, epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # pass callback to training
Train on 1000 samples, validate on 1000 samples
Epoch 1/10
 864/1000 [========================>.....] - ETA: 0s - loss: 1.2590 - acc: 0.6354
Epoch 00001: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
WARNING:tensorflow:From e:\program files\python37\lib\site-packages\tensorflow\python\keras\engine\network.py:1436: update_checkpoint_state (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.train.CheckpointManager to manage checkpoints rather than manually editing the Checkpoint proto.
1000/1000 [==============================] - 1s 791us/sample - loss: 1.1675 - acc: 0.6650 - val_loss: 0.7683 - val_acc: 0.7550
Epoch 2/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.4623 - acc: 0.8750
Epoch 00002: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 351us/sample - loss: 0.4515 - acc: 0.8750 - val_loss: 0.5316 - val_acc: 0.8340
Epoch 3/10
 800/1000 [=======================>......] - ETA: 0s - loss: 0.2790 - acc: 0.9287
Epoch 00003: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 358us/sample - loss: 0.2834 - acc: 0.9270 - val_loss: 0.4607 - val_acc: 0.8520
Epoch 4/10
 928/1000 [==========================>...] - ETA: 0s - loss: 0.2077 - acc: 0.9515
Epoch 00004: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 339us/sample - loss: 0.2046 - acc: 0.9530 - val_loss: 0.4370 - val_acc: 0.8540
Epoch 5/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.1578 - acc: 0.9710
Epoch 00005: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 350us/sample - loss: 0.1526 - acc: 0.9720 - val_loss: 0.4047 - val_acc: 0.8670
Epoch 6/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.1055 - acc: 0.9815
Epoch 00006: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 380us/sample - loss: 0.1062 - acc: 0.9830 - val_loss: 0.4201 - val_acc: 0.8560
Epoch 7/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0826 - acc: 0.9850
Epoch 00007: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 351us/sample - loss: 0.0824 - acc: 0.9850 - val_loss: 0.4168 - val_acc: 0.8660
Epoch 8/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0662 - acc: 0.9919
Epoch 00008: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 357us/sample - loss: 0.0655 - acc: 0.9910 - val_loss: 0.4021 - val_acc: 0.8700
Epoch 9/10
 864/1000 [========================>.....] - ETA: 0s - loss: 0.0495 - acc: 0.9954
Epoch 00009: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 358us/sample - loss: 0.0491 - acc: 0.9950 - val_loss: 0.4168 - val_acc: 0.8640
Epoch 10/10
 896/1000 [=========================>....] - ETA: 0s - loss: 0.0401 - acc: 1.0000
Epoch 00010: saving model to training_1/cp.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x000000001357DD30>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 354us/sample - loss: 0.0397 - acc: 1.0000 - val_loss: 0.4091 - val_acc: 0.8770
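The code above creates a collection of TensorFlow checkpoint files that are updated at the end of each epoch. A directory listing of training_1 shows them: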
 Volume in drive C has no label.
 Volume Serial Number is CE2F-63AD

 Directory of C:\Users\Administrator\JupyterProject\training_1

2019/04/28  11:23    <DIR>          .
2019/04/28  11:23    <DIR>          ..
2019/04/28  11:23                71 checkpoint
2019/04/28  11:23         1,631,508 cp.ckpt.data-00000-of-00001
2019/04/28  11:23               648 cp.ckpt.index
               3 File(s)      1,632,227 bytes
               2 Dir(s)  23,484,948,480 bytes free
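The file sizes line up with the model. Assume every weight is stored as a 32-bit float (4 bytes each, Keras's default dtype). Then the 407,050 parameters reported by model.summary() account for nearly all of the 1,631,508-byte data file. This arithmetic does not break down the remaining few kilobytes.

```python
n_params = 407_050           # "Total params" from model.summary()
bytes_per_float32 = 4        # assumption: weights stored as float32
raw_weight_bytes = n_params * bytes_per_float32

data_file_bytes = 1_631_508  # size of cp.ckpt.data-00000-of-00001 in the listing

print(raw_weight_bytes)                    # 1628200
print(data_file_bytes - raw_weight_bytes)  # 3308
```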
Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model performs at roughly chance level (about 10% accuracy):
model = create_model()

loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))

1000/1000 [==============================] - 0s 81us/sample - loss: 2.3694 - acc: 0.0610
Untrained model, accuracy: 6.10%
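The roughly 10% figure comes from the 10 digit classes: a uniformly random guess is right about one time in ten. The sketch below simulates that, assuming the labels are spread evenly over the classes. An untrained network's predictions are not uniform, so its measured accuracy (6.10% above) can land somewhat away from 10%. chance_accuracy is a hypothetical helper.

```python
import random

def chance_accuracy(n_samples, n_classes=10, seed=0):
    """Accuracy of uniformly random guesses against uniformly random labels."""
    rng = random.Random(seed)
    hits = sum(rng.randrange(n_classes) == rng.randrange(n_classes)
               for _ in range(n_samples))
    return hits / n_samples

print(chance_accuracy(100_000))  # roughly 0.1
```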
Then load the weights from the checkpoint and evaluate again:

model.load_weights(checkpoint_path)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

1000/1000 [==============================] - 0s 46us/sample - loss: 0.4091 - acc: 0.8770
Restored model, accuracy: 87.70%
Train a new model that saves uniquely named checkpoints once every 5 epochs:
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights every 5 epochs
    period=5)

model = create_model()
model.fit(train_images, train_labels,
          epochs=50, callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
Epoch 00005: saving model to training_2/cp-0005.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00010: saving model to training_2/cp-0010.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00015: saving model to training_2/cp-0015.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00020: saving model to training_2/cp-0020.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00025: saving model to training_2/cp-0025.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00030: saving model to training_2/cp-0030.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00035: saving model to training_2/cp-0035.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00040: saving model to training_2/cp-0040.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00045: saving model to training_2/cp-0045.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
Epoch 00050: saving model to training_2/cp-0050.ckpt
WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000014CE31D0>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
 Volume in drive C has no label.
 Volume Serial Number is CE2F-63AD

 Directory of C:\Users\Administrator\JupyterProject\training_2

2019/04/28  11:24    <DIR>          .
2019/04/28  11:24    <DIR>          ..
2019/04/28  11:24                81 checkpoint
2019/04/28  11:24         1,631,508 cp-0005.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0005.ckpt.index
2019/04/28  11:24         1,631,508 cp-0010.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0010.ckpt.index
2019/04/28  11:24         1,631,508 cp-0015.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0015.ckpt.index
2019/04/28  11:24         1,631,508 cp-0020.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0020.ckpt.index
2019/04/28  11:24         1,631,508 cp-0025.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0025.ckpt.index
2019/04/28  11:24         1,631,508 cp-0030.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0030.ckpt.index
2019/04/28  11:24         1,631,508 cp-0035.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0035.ckpt.index
2019/04/28  11:24         1,631,508 cp-0040.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0040.ckpt.index
2019/04/28  11:24         1,631,508 cp-0045.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0045.ckpt.index
2019/04/28  11:24         1,631,508 cp-0050.ckpt.data-00000-of-00001
2019/04/28  11:24               648 cp-0050.ckpt.index
              21 File(s)     16,321,641 bytes
               2 Dir(s)  23,468,404,736 bytes free
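The file names in the log and listing follow from two details. First, str.format fills {epoch:04d} with a zero-padded, four-digit epoch number. Second, the callback counts the epochs since its last save. The function below is a simplified model of that counting, based on our reading of the TF 1.x ModelCheckpoint callback rather than copied from it. checkpoint_schedule is a hypothetical name.

```python
def checkpoint_schedule(template, epochs, period):
    """Paths a ModelCheckpoint-style callback would write when it saves
    every `period` epochs (simplified model of the callback's counter)."""
    paths = []
    since_last_save = 0
    for epoch in range(epochs):           # epochs are counted from 0 internally
        since_last_save += 1
        if since_last_save >= period:
            since_last_save = 0
            # File names use 1-based epoch numbers.
            paths.append(template.format(epoch=epoch + 1))
    return paths

paths = checkpoint_schedule("training_2/cp-{epoch:04d}.ckpt", epochs=50, period=5)
print(paths[0], paths[-1], len(paths))
# training_2/cp-0005.ckpt training_2/cp-0050.ckpt 10
```

Ten saves of one .data file and one .index file each, plus the checkpoint state file, give the 21 files in the listing.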
To find the most recent checkpoint, pass the checkpoint directory to tf.train.latest_checkpoint:

latest = tf.train.latest_checkpoint(checkpoint_dir)
print(latest)
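Note: checkpointing through tf.train.Saver keeps only the 5 most recent checkpoints by default (max_to_keep=5). The ModelCheckpoint callback used here does not delete older checkpoints. The training_2 listing shows all ten still present.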
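tf.train.latest_checkpoint finds the latest checkpoint by reading the small state file named checkpoint in the listings; it does not scan the directory. If that file is missing, one fallback is to pick the highest epoch among the .index files. The helper below is a hypothetical sketch of that idea, assuming the cp-{epoch:04d}.ckpt naming used above.

```python
import os
import re
import tempfile

def latest_by_epoch(checkpoint_dir, prefix="cp-"):
    """Return the checkpoint path with the highest epoch number, found by scanning
    for '<prefix><epoch>.ckpt.index' files. Hypothetical helper, not a TF API."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)\.ckpt\.index$")
    best = None  # (epoch, checkpoint name without ".index")
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match and (best is None or int(match.group(1)) > best[0]):
            best = (int(match.group(1)), name[: -len(".index")])
    return os.path.join(checkpoint_dir, best[1]) if best else None

# Demo on a temporary directory containing empty stand-in index files.
with tempfile.TemporaryDirectory() as d:
    for epoch in (5, 10, 50, 45):
        open(os.path.join(d, "cp-{:04d}.ckpt.index".format(epoch)), "w").close()
    print(os.path.basename(latest_by_epoch(d)))  # cp-0050.ckpt
```

The returned path is the checkpoint prefix without a file extension, the same form that model.load_weights accepts.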
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

1000/1000 [==============================] - 0s 86us/sample - loss: 0.4830 - acc: 0.8770
Restored model, accuracy: 87.70%
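If you train the model on a single machine, you get one shard with the suffix .data-00000-of-00001. The .index file beside it records which weights are stored in which shard.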
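The shard suffix is a zero-based shard index followed by the total number of shards. The small parser below assumes the five-digit numbering seen in the listings. parse_shard is a hypothetical helper, not a TensorFlow function.

```python
import re

def parse_shard(filename):
    """Split a checkpoint data file name such as 'cp.ckpt.data-00000-of-00001'
    into (checkpoint prefix, zero-based shard index, number of shards)."""
    match = re.fullmatch(r"(.+)\.data-(\d{5})-of-(\d{5})", filename)
    if match is None:
        raise ValueError("not a checkpoint data shard: " + filename)
    prefix, index, total = match.groups()
    return prefix, int(index), int(total)

print(parse_shard("cp.ckpt.data-00000-of-00001"))  # ('cp.ckpt', 0, 1)
```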
Saving weights manually is just as simple: call the Model.save_weights method.
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x0000000018D9D080>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved. Consider using a TensorFlow optimizer from `tf.train`.
1000/1000 [==============================] - 0s 88us/sample - loss: 0.4830 - acc: 0.8770
Restored model, accuracy: 87.70%
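Model.save_weights stores only the weight values. That is why the code above calls create_model() before load_weights: the architecture has to come from code. The sketch below shows the same idea with NumPy instead of the TensorFlow checkpoint format. It assumes the weights are a list of arrays, as Keras's Model.get_weights() returns. save_weight_list and load_weight_list are hypothetical helpers, and restoring into a rebuilt model would go through Model.set_weights.

```python
import os
import tempfile

import numpy as np

def save_weight_list(weights, path):
    """Store a list of arrays in one .npz file (arrays become arr_0, arr_1, ...)."""
    np.savez(path, *weights)

def load_weight_list(path):
    """Load the arrays back in their original order."""
    with np.load(path) as data:
        return [data["arr_{}".format(i)] for i in range(len(data.files))]

# Arrays shaped like this model's weights: kernel and bias for each Dense layer.
weights = [np.ones((784, 512), np.float32), np.zeros(512, np.float32),
           np.ones((512, 10), np.float32), np.zeros(10, np.float32)]

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "my_weights.npz")
    save_weight_list(weights, path)
    restored = load_weight_list(path)

print([w.shape for w in restored])
```

Nothing in the .npz file describes the layers, so the arrays are useful only to code that rebuilds the same architecture first. The TensorFlow checkpoint has the same property.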
Saving a fully functional model in Keras is very useful: you can load it in TensorFlow.js and then train and run it in a web browser.
Keras provides a basic save format based on the HDF5 standard. For our purposes, the saved model can be treated as a single binary blob.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to an HDF5 file
model.save('my_model.h5')

Epoch 1/5
1000/1000 [==============================] - 0s 322us/sample - loss: 1.1511 - acc: 0.6830
Epoch 2/5
1000/1000 [==============================] - 0s 235us/sample - loss: 0.4189 - acc: 0.8840
Epoch 3/5
1000/1000 [==============================] - 0s 235us/sample - loss: 0.2864 - acc: 0.9230
Epoch 4/5
1000/1000 [==============================] - 0s 233us/sample - loss: 0.2147 - acc: 0.9410
Epoch 5/5
1000/1000 [==============================] - 0s 224us/sample - loss: 0.1642 - acc: 0.9660
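Because the saved model is treated as an opaque HDF5 blob, the one thing you can check without h5py is the file signature. HDF5 files normally begin with the 8 bytes \x89HDF\r\n\x1a\n. A minimal check follows; it assumes the signature sits at offset 0. HDF5 also allows a user block before the signature, so a False result is not conclusive, and h5py.is_hdf5 is the more thorough option. looks_like_hdf5 is a hypothetical helper.

```python
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"  # 8-byte HDF5 format signature

def looks_like_hdf5(path):
    """Return True if the file starts with the HDF5 signature at offset 0."""
    with open(path, "rb") as f:
        return f.read(8) == HDF5_SIGNATURE

# Example, after model.save('my_model.h5') has run:
# looks_like_hdf5('my_model.h5')
```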
# Recreate the exact same model, including its weights and optimizer
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_14 (Dense)             (None, 512)               401920
_________________________________________________________________
dropout_7 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_15 (Dense)             (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

1000/1000 [==============================] - 0s 99us/sample - loss: 0.4258 - acc: 0.8530
Restored model, accuracy: 85.30%
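This technique saves everything:
- the weight values
- the model's configuration (architecture)
- the optimizer configuration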
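Keras saves models by inspecting their architecture. Currently it cannot save TensorFlow optimizers (from tf.train). If you use one, you must re-compile the model after loading it, and the optimizer's state is lost.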
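Here is what "the optimizer's state" means in practice, shown with a toy, pure-Python Adam for a single scalar parameter. It is illustrative only: this is not TensorFlow's implementation, and the hyperparameter defaults are assumptions. Adam keeps a running mean of gradients (m), a running mean of squared gradients (v), and a step counter (t). A freshly compiled optimizer starts again from m = v = 0 and t = 0, and that reset is what losing the state amounts to.

```python
import math

class TinyAdam:
    """Minimal single-parameter Adam (illustrative, not TensorFlow's code)."""

    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = 0.0  # running mean of gradients
        self.v = 0.0  # running mean of squared gradients
        self.t = 0    # step counter used for bias correction

    def step(self, param, grad):
        """Return the updated parameter and advance the optimizer state."""
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad * grad
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)
        return param - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

opt = TinyAdam()
p = opt.step(1.0, grad=2.0)
print(opt.m, opt.v, opt.t)  # state carried into the next step

fresh = TinyAdam()          # e.g. after re-compiling: state starts over
print(fresh.m, fresh.v, fresh.t)  # 0.0 0.0 0
```

Training continues either way, but the early steps after a reset behave like the start of training again rather than picking up from the saved moment estimates.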
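That concludes this quick guide to saving and loading models with tf.keras.
The tf.keras guide covers saving and loading models with tf.keras in more detail.
See Saving in Eager to learn how to save models during eager execution.
The Save and Restore guide covers the low-level details of TensorFlow saving.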