##### Copyright 2019 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 使用分布策略保存和加载模型

<table class="tfo-notebook-buttons" align="left">
  <td><a target="_blank" href="https://tensorflow.google.cn/tutorials/distribute/save_and_load"><img src="https://tensorflow.google.cn/images/tf_logo_32px.png">在 TensorFlow.org 上查看</a></td>
  <td><a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/save_and_load.ipynb"><img src="https://tensorflow.google.cn/images/colab_logo_32px.png">在 Google Colab 中运行</a></td>
  <td><a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/distribute/save_and_load.ipynb"><img src="https://tensorflow.google.cn/images/GitHub-Mark-32px.png">在 GitHub 上查看源代码</a></td>
  <td><a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/distribute/save_and_load.ipynb"><img src="https://tensorflow.google.cn/images/download_logo_32px.png">下载笔记本</a></td>
</table>

## 概述

本教程演示了如何在训练期间或训练之后使用 `tf.distribute.Strategy` 以 SavedModel 格式保存和加载模型。有两种用于保存和加载 Keras 模型的 API：高级（`tf.keras.Model.save` 和 `tf.keras.models.load_model`）和低级（`tf.saved_model.save` 和 `tf.saved_model.load`）。

要全面了解 SavedModel 和序列化，请阅读[已保存模型指南](../../guide/saved_model.ipynb)和 [Keras 模型序列化指南](https://tensorflow.google.cn/guide/keras/save_and_serialize)。我们从一个简单的示例开始。

小心：TensorFlow 模型是代码，对于不受信任的代码，一定要小心。请参阅[安全使用 TensorFlow](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) 以了解详情。


导入依赖项：

In [2]:
import tensorflow_datasets as tfds

import tensorflow as tf


2023-11-07 23:17:51.828961: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 23:17:51.829033: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 23:17:51.830761: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


使用 TensorFlow Datasets 和 `tf.data` 加载和准备数据，并使用 `tf.distribute.MirroredStrategy` 创建模型：

In [3]:
# Created once at notebook level: `get_data` reads it to compute the global
# batch size and `get_model` builds and compiles the model inside its scope.
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  """Load MNIST through TFDS and return `(train_dataset, eval_dataset)`.

  Images are cast to float32 and divided by 255. Both datasets are batched
  with the global batch size, i.e. 64 examples per replica multiplied by
  `mirrored_strategy.num_replicas_in_sync`. The training dataset is also
  cached and shuffled (buffer of 10000) before batching.
  """
  splits = tfds.load(name='mnist', as_supervised=True)

  shuffle_buffer = 10000
  per_replica_batch = 64
  global_batch = per_replica_batch * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    # Convert to float32 and rescale by 255; the label passes through.
    return tf.cast(image, tf.float32) / 255, label

  train_dataset = splits['train'].map(scale)
  train_dataset = train_dataset.cache()
  train_dataset = train_dataset.shuffle(shuffle_buffer)
  train_dataset = train_dataset.batch(global_batch)

  eval_dataset = splits['test'].map(scale).batch(global_batch)

  return train_dataset, eval_dataset

def get_model():
  """Build and compile the MNIST classifier under `mirrored_strategy.scope()`.

  Returns:
    A compiled `tf.keras.Sequential` model (Conv2D -> MaxPooling2D -> Flatten
    -> Dense(64) -> Dense(10)). The last layer has no activation, so the loss
    is configured with `from_logits=True`.
  """
  with mirrored_strategy.scope():
    # Both model construction and `compile` happen inside the strategy scope.
    layer_stack = [
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10),
    ]
    model = tf.keras.Sequential(layer_stack)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()
    accuracy = tf.metrics.SparseCategoricalAccuracy()
    model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy])

  return model

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


使用 `tf.keras.Model.fit` 训练模型： 

In [4]:
# Build the compiled model and the MNIST datasets, then train for two epochs.
# `model`, `train_dataset` and `eval_dataset` are reused by later cells.
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)

2023-11-07 23:17:58.325808: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


Epoch 1/2


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


I0000 00:00:1699399085.009935  522218 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


  1/235 [..............................] - ETA: 26:27 - loss: 2.3098 - sparse_categorical_accuracy: 0.1016

  7/235 [..............................] - ETA: 2s - loss: 1.8855 - sparse_categorical_accuracy: 0.4805   

 14/235 [>.............................] - ETA: 1s - loss: 1.4637 - sparse_categorical_accuracy: 0.6303

 20/235 [=>............................] - ETA: 1s - loss: 1.2098 - sparse_categorical_accuracy: 0.6912

 26/235 [==>...........................] - ETA: 1s - loss: 1.0419 - sparse_categorical_accuracy: 0.7300

 33/235 [===>..........................] - ETA: 1s - loss: 0.9096 - sparse_categorical_accuracy: 0.7627

 39/235 [===>..........................] - ETA: 1s - loss: 0.8239 - sparse_categorical_accuracy: 0.7827

 45/235 [====>.........................] - ETA: 1s - loss: 0.7641 - sparse_categorical_accuracy: 0.7958

 52/235 [=====>........................] - ETA: 1s - loss: 0.7112 - sparse_categorical_accuracy: 0.8082























































INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).




Epoch 2/2


  1/235 [..............................] - ETA: 9s - loss: 0.1225 - sparse_categorical_accuracy: 0.9766

  8/235 [>.............................] - ETA: 1s - loss: 0.1251 - sparse_categorical_accuracy: 0.9688

 15/235 [>.............................] - ETA: 1s - loss: 0.1288 - sparse_categorical_accuracy: 0.9651

 23/235 [=>............................] - ETA: 1s - loss: 0.1264 - sparse_categorical_accuracy: 0.9660

 30/235 [==>...........................] - ETA: 1s - loss: 0.1247 - sparse_categorical_accuracy: 0.9658

 38/235 [===>..........................] - ETA: 1s - loss: 0.1215 - sparse_categorical_accuracy: 0.9655

 46/235 [====>.........................] - ETA: 1s - loss: 0.1195 - sparse_categorical_accuracy: 0.9657

 54/235 [=====>........................] - ETA: 1s - loss: 0.1155 - sparse_categorical_accuracy: 0.9669



















































<keras.src.callbacks.History at 0x7f7a0c273700>

## 保存和加载模型

现在，您已经有一个简单的模型可供使用，让我们探索保存/加载 API。有两种可用的 API：

- 高级 (Keras)：`Model.save` 和 `tf.keras.models.load_model`（`.keras` zip 存档格式）
- 低级：`tf.saved_model.save` 和 `tf.saved_model.load`（TF SavedModel 格式）


### Keras API

以下为使用 Keras API 保存和加载模型的示例：

In [5]:
# Save with the high-level Keras API. The `.keras` extension marks the Keras
# zip archive format.
keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)

在不使用 `tf.distribute.Strategy` 的情况下恢复模型：

In [6]:
# Load outside of any strategy scope. The model was compiled before it was
# saved, so training resumes without calling `Model.compile` again.
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

Epoch 1/2


  1/235 [..............................] - ETA: 3:57 - loss: 0.0419 - sparse_categorical_accuracy: 0.9883

 12/235 [>.............................] - ETA: 1s - loss: 0.0697 - sparse_categorical_accuracy: 0.9808  

 23/235 [=>............................] - ETA: 1s - loss: 0.0730 - sparse_categorical_accuracy: 0.9808

 34/235 [===>..........................] - ETA: 0s - loss: 0.0757 - sparse_categorical_accuracy: 0.9792

 45/235 [====>.........................] - ETA: 0s - loss: 0.0738 - sparse_categorical_accuracy: 0.9793

































Epoch 2/2


  1/235 [..............................] - ETA: 7s - loss: 0.0286 - sparse_categorical_accuracy: 0.9961

 14/235 [>.............................] - ETA: 0s - loss: 0.0594 - sparse_categorical_accuracy: 0.9819

 27/235 [==>...........................] - ETA: 0s - loss: 0.0567 - sparse_categorical_accuracy: 0.9828

 40/235 [====>.........................] - ETA: 0s - loss: 0.0566 - sparse_categorical_accuracy: 0.9823

 53/235 [=====>........................] - ETA: 0s - loss: 0.0536 - sparse_categorical_accuracy: 0.9838





























<keras.src.callbacks.History at 0x7f7a6b052490>

恢复模型后，您可以继续在它上面训练，甚至不需要再次调用 `Model.compile`，因为它在保存之前已经编译。模型以 Keras zip 存档格式保存，由 `.keras` 扩展名标记。有关详情，请参阅 [Keras 保存指南](https://tensorflow.google.cn/guide/keras/save_and_serialize)。

现在，恢复模型并使用 `tf.distribute.Strategy` 对其进行训练：

In [7]:
# Load and train under a different strategy than the one used when saving:
# here a single-device strategy placed on '/cpu:0'.
another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

Epoch 1/2


2023-11-07 23:18:13.264179: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2023-11-07 23:18:13.328327: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


  1/235 [..............................] - ETA: 2:05 - loss: 0.0732 - sparse_categorical_accuracy: 0.9844

  5/235 [..............................] - ETA: 2s - loss: 0.0701 - sparse_categorical_accuracy: 0.9852  

 10/235 [>.............................] - ETA: 2s - loss: 0.0726 - sparse_categorical_accuracy: 0.9793

 15/235 [>.............................] - ETA: 2s - loss: 0.0740 - sparse_categorical_accuracy: 0.9812

2023-11-07 23:18:13.954467: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2023-11-07 23:18:13.988420: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.
2023-11-07 23:18:14.002096: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.


 20/235 [=>............................] - ETA: 2s - loss: 0.0744 - sparse_categorical_accuracy: 0.9795

 25/235 [==>...........................] - ETA: 2s - loss: 0.0766 - sparse_categorical_accuracy: 0.9787

 29/235 [==>...........................] - ETA: 2s - loss: 0.0751 - sparse_categorical_accuracy: 0.9795

 34/235 [===>..........................] - ETA: 2s - loss: 0.0763 - sparse_categorical_accuracy: 0.9790

 39/235 [===>..........................] - ETA: 2s - loss: 0.0762 - sparse_categorical_accuracy: 0.9789

 43/235 [====>.........................] - ETA: 2s - loss: 0.0743 - sparse_categorical_accuracy: 0.9793

 48/235 [=====>........................] - ETA: 2s - loss: 0.0725 - sparse_categorical_accuracy: 0.9799

 52/235 [=====>........................] - ETA: 2s - loss: 0.0704 - sparse_categorical_accuracy: 0.9803













































































Epoch 2/2


  1/235 [..............................] - ETA: 10s - loss: 0.0766 - sparse_categorical_accuracy: 0.9727

  5/235 [..............................] - ETA: 3s - loss: 0.0519 - sparse_categorical_accuracy: 0.9828 

 10/235 [>.............................] - ETA: 2s - loss: 0.0627 - sparse_categorical_accuracy: 0.9816

 15/235 [>.............................] - ETA: 2s - loss: 0.0659 - sparse_categorical_accuracy: 0.9810

 20/235 [=>............................] - ETA: 2s - loss: 0.0625 - sparse_categorical_accuracy: 0.9814

 24/235 [==>...........................] - ETA: 2s - loss: 0.0604 - sparse_categorical_accuracy: 0.9818

 29/235 [==>...........................] - ETA: 2s - loss: 0.0583 - sparse_categorical_accuracy: 0.9824

 34/235 [===>..........................] - ETA: 2s - loss: 0.0578 - sparse_categorical_accuracy: 0.9823

 39/235 [===>..........................] - ETA: 2s - loss: 0.0568 - sparse_categorical_accuracy: 0.9824

 44/235 [====>.........................] - ETA: 2s - loss: 0.0563 - sparse_categorical_accuracy: 0.9826

 49/235 [=====>........................] - ETA: 2s - loss: 0.0570 - sparse_categorical_accuracy: 0.9833

 54/235 [=====>........................] - ETA: 2s - loss: 0.0566 - sparse_categorical_accuracy: 0.9835











































































正如 `Model.fit` 的输出所示，使用 `tf.distribute.Strategy` 进行加载可以按预期工作。此处使用的策略不必与保存前所用策略相同。

### `tf.saved_model` API

使用较低级别的 API 保存模型类似于 Keras API：

In [8]:
# Save with the lower-level API; this writes a SavedModel directory at
# `saved_model_path` (reused by the following cells).
model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


可以使用 `tf.saved_model.load` 进行加载。但是，由于它是一个较低级别的 API（因此用例范围更广泛），不会返回 Keras 模型。相反，它会返回一个对象，其中包含可用于进行推断的函数。例如：

In [9]:
# Key of the default inference function for a saved Keras model.
DEFAULT_FUNCTION_KEY = 'serving_default'
# `tf.saved_model.load` does not return a Keras model; the inference
# functions are looked up by key in `loaded.signatures`.
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

加载的对象可能包含多个函数，每个函数与一个键关联。`"serving_default"` 键是使用已保存的 Keras 模型的推断函数的默认键。要使用此函数进行推断，请运行以下代码： 

In [10]:
# Drop the labels so each dataset element is only the image batch.
predict_dataset = eval_dataset.map(lambda image, label: image)
# Run the loaded signature on a single batch and print what it returns.
for batch in predict_dataset.take(1):
  print(inference_func(batch))

{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[ 0.1147381 , -0.40301684,  0.04525661, ..., -0.00361927,
        -0.16612433, -0.0153662 ],
       [ 0.23993656, -0.2951342 ,  0.11567082, ..., -0.018293  ,
        -0.25642908, -0.12756145],
       [ 0.12482558, -0.1498732 , -0.02880868, ...,  0.10482813,
        -0.04954109,  0.03868026],
       ...,
       [ 0.24344432, -0.15270862,  0.07865744, ...,  0.07004395,
        -0.13115454, -0.06379548],
       [ 0.18582651, -0.06589739,  0.0035552 , ...,  0.05252723,
        -0.15672804, -0.01999891],
       [ 0.10272254, -0.19479191,  0.0195781 , ..., -0.00804898,
        -0.19664931, -0.12797697]], dtype=float32)>}


2023-11-07 23:18:20.522652: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


您还可以采用分布式方式加载和进行推断：

In [11]:
# Load the SavedModel and run its serving signature under a new
# MirroredStrategy.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  # Distribute the label-free dataset across the strategy's replicas.
  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    # Only the first distributed batch is needed for the demonstration.
    break

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


2023-11-07 23:18:20.756679: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.






{'dense_3': PerReplica:{
  0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 1.14738092e-01, -4.03016806e-01,  4.52566147e-02,
         4.16727774e-02, -1.74182177e-01, -1.42247193e-02,
         1.34800702e-01, -3.61925364e-03, -1.66124314e-01,
        -1.53662562e-02],
       [ 2.39936545e-01, -2.95134246e-01,  1.15670875e-01,
         6.21562228e-02, -5.74412942e-03,  9.67399850e-02,
         6.38412088e-02, -1.82929561e-02, -2.56429106e-01,
        -1.27561510e-01],
       [ 1.24825627e-01, -1.49873167e-01, -2.88086776e-02,
         5.14284968e-02, -1.49795488e-01,  5.63685335e-02,
        -1.25008225e-02,  1.04828164e-01, -4.95411083e-02,
         3.86802554e-02],
       [-3.14452238e-02, -1.13862000e-01,  3.59135307e-02,
         3.37418914e-03, -9.56010595e-02,  8.71438980e-02,
         3.11656222e-02, -8.25512223e-03, -2.66045481e-02,
         4.20068875e-02],
       [ 3.10148541e-02, -1.27727151e-01,  2.81632431e-02,
         1.84415001e-02, -5.98692782e-02,  3.2342

2023-11-07 23:18:21.480277: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


调用已恢复的函数只是基于已保存模型的前向传递 (`tf.keras.Model.predict`)。如果您想继续训练加载的函数，或者将加载的函数嵌入到更大的模型中，应如何操作？通常的做法是将此加载对象封装到 Keras 层以实现此目的。幸运的是，[TF Hub](https://tensorflow.google.cn/hub) 为此提供了 [`hub.KerasLayer`](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py)，如下所示：

In [12]:
import tensorflow_hub as hub

def build_model(loaded):
  """Wrap a loaded SavedModel object in a trainable Keras model.

  Args:
    loaded: The object returned by `tf.saved_model.load`.

  Returns:
    A `tf.keras.Model` that maps a (28, 28, 1) input named 'input_x' through
    a trainable `hub.KerasLayer` built around `loaded`.
  """
  inputs = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  wrapped_layer = hub.KerasLayer(loaded, trainable=True)
  outputs = wrapped_layer(inputs)
  return tf.keras.Model(inputs, outputs)

# Load the SavedModel under a new strategy, wrap it with `build_model`, then
# compile and train the resulting Keras model.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  # The model returned by `build_model` is a new, uncompiled `tf.keras.Model`,
  # so it has to be compiled here before `fit`.
  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


2023-11-07 23:18:22.447273: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


Epoch 1/2


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1


  1/235 [..............................] - ETA: 12:20 - loss: 2.2659 - sparse_categorical_accuracy: 0.1875

  8/235 [>.............................] - ETA: 1s - loss: 1.8580 - sparse_categorical_accuracy: 0.4658   

 16/235 [=>............................] - ETA: 1s - loss: 1.4582 - sparse_categorical_accuracy: 0.6265

 24/235 [==>...........................] - ETA: 1s - loss: 1.1840 - sparse_categorical_accuracy: 0.6973

 32/235 [===>..........................] - ETA: 1s - loss: 1.0027 - sparse_categorical_accuracy: 0.7411

 39/235 [===>..........................] - ETA: 1s - loss: 0.8938 - sparse_categorical_accuracy: 0.7672

 46/235 [====>.........................] - ETA: 1s - loss: 0.8114 - sparse_categorical_accuracy: 0.7873

 54/235 [=====>........................] - ETA: 1s - loss: 0.7371 - sparse_categorical_accuracy: 0.8056





















































Epoch 2/2


  1/235 [..............................] - ETA: 9s - loss: 0.0917 - sparse_categorical_accuracy: 0.9805

  8/235 [>.............................] - ETA: 1s - loss: 0.1196 - sparse_categorical_accuracy: 0.9658

 15/235 [>.............................] - ETA: 1s - loss: 0.1244 - sparse_categorical_accuracy: 0.9654

 22/235 [=>............................] - ETA: 1s - loss: 0.1203 - sparse_categorical_accuracy: 0.9664

 29/235 [==>...........................] - ETA: 1s - loss: 0.1248 - sparse_categorical_accuracy: 0.9654

 36/235 [===>..........................] - ETA: 1s - loss: 0.1237 - sparse_categorical_accuracy: 0.9659

 43/235 [====>.........................] - ETA: 1s - loss: 0.1209 - sparse_categorical_accuracy: 0.9666

 51/235 [=====>........................] - ETA: 1s - loss: 0.1179 - sparse_categorical_accuracy: 0.9678





















































在上面的示例中，TensorFlow Hub 的 `hub.KerasLayer` 可将从 `tf.saved_model.load` 加载回的结果封装到可用于构建其他模型的 Keras 层。这对于迁移学习非常实用。 

### 我应使用哪种 API？

对于保存，如果您使用的是 Keras 模型，请使用 Keras `Model.save` API，除非您需要低级 API 允许的额外控制。如果您保存的不是 Keras 模型，那么您只能选择使用较低级的 API `tf.saved_model.save`。

对于加载，您的 API 选择取决于您要从加载 API 中获得什么。如果您无法（或不想）获取 Keras 模型，请使用 `tf.saved_model.load`。否则，请使用 `tf.keras.models.load_model`。请注意，只有保存 Keras 模型后，才能恢复 Keras 模型。

可以搭配使用 API。您可以使用 `model.save` 保存 Keras 模型，并使用低级 API `tf.saved_model.load` 加载非 Keras 模型。 

In [13]:
# Mixing the APIs: save with the high-level Keras API, load with the
# low-level one.
model = get_model()

# Saving the model using Keras `Model.save`
model.save(saved_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


### 从本地设备保存/加载

在远程设备上训练的过程中从本地 I/O 设备保存和加载时（例如，使用 Cloud TPU 时），必须使用 `tf.saved_model.SaveOptions` 和 `tf.saved_model.LoadOptions` 中的选项 `experimental_io_device` 将 I/O 设备设置为 `localhost`。例如：

In [14]:
# Save and load through an I/O device on localhost by passing
# `experimental_io_device` in the save and load options.
model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)

INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


### 警告

一种特殊情况是当您以某种方式创建 Keras 模型，然后在训练之前保存它们。例如：

In [15]:
class SubclassedModel(tf.keras.Model):
  """Minimal model built by subclassing `tf.keras.Model`.

  The model owns a single `Dense` layer with 5 units and float32 dtype; the
  forward pass simply applies that layer to the inputs.
  """

  # Name assigned to the dense layer created in `__init__`.
  output_name = 'output_layer'

  def __init__(self):
    """Creates the single dense layer used by `call`."""
    super().__init__()
    self._dense_layer = tf.keras.layers.Dense(
        units=5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    """Forward pass: returns the dense layer applied to `inputs`."""
    outputs = self._dense_layer(inputs)
    return outputs

# Instantiate the model; it is not called on any data before the save below.
my_model = SubclassedModel()
try:
  # Attempt a high-level Keras save of the freshly created model.
  my_model.save(saved_model_path)
except ValueError as err:
  # Report the exception type followed by its message arguments instead of
  # letting the cell fail.
  print(f'{type(err).__name__}: ', *err.args)





ValueError:  Model <__main__.SubclassedModel object at 0x7f797645bf10> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.


SavedModel 保存的是跟踪 `tf.function` 时生成的 `tf.types.experimental.ConcreteFunction` 对象（请查看[计算图和 tf.function 简介](../../guide/intro_to_graphs.ipynb)指南中的 *函数何时执行跟踪？* 了解更多信息）。如果您收到像这样的 `ValueError`，那是因为 `Model.save` 无法找到或创建跟踪的 `ConcreteFunction`。

**小心**：您不应在一个 `ConcreteFunction` 都没有的情况下保存模型，因为如果这样做，低级 API 将生成一个没有 `ConcreteFunction` 签名的 SavedModel（[详细了解](../../guide/saved_model.ipynb) SavedModel 格式）。例如：

In [16]:
# Save `my_model` with the low-level API; at this point the model has not been
# called on any data in the cells shown above.
tf.saved_model.save(my_model, saved_model_path)
# Reload the SavedModel with the low-level API and display its signatures.
# The recorded output of this cell is an empty `_SignatureMap({})`.
x = tf.saved_model.load(saved_model_path)
x.signatures









INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


_SignatureMap({})

一般而言，模型的前向传递（`call` 方法）会在第一次调用模型时被自动跟踪，通常是通过 Keras `Model.fit` 方法。如果您设置了输入形状，例如通过将第一层设为 `tf.keras.layers.InputLayer` 或其他层类型，并将 `input_shape` 关键字参数传递给它，Keras [序贯](https://tensorflow.google.cn/guide/keras/sequential_model)和[函数式](https://tensorflow.google.cn/guide/keras/functional) API 也可以生成 `ConcreteFunction`。

要验证您的模型是否有任何跟踪的 `ConcreteFunction`，请检查 `Model.save_spec()` 的返回值是否为 `None`：

In [17]:
# Print whether `save_spec()` returns None for the not-yet-trained model
# (the recorded output of this cell is `True`).
print(my_model.save_spec() is None)

True


下面使用 `tf.keras.Model.fit` 来训练模型。可以注意到，训练之后 `save_spec` 已被定义，模型也能够成功保存：

In [18]:
# Global batch size = per-replica batch size times the number of replicas in
# sync of `mirrored_strategy` (a strategy object defined outside this cell).
BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

# Build the training data from one pair of float32 `tf.range(5)` tensors,
# repeated `dataset_size` times and then grouped into batches of BATCH_SIZE.
dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

# Compile with the Adam optimizer and mean-squared-error loss, then train for
# two epochs.
my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

# After fitting, `save_spec()` is no longer None (the recorded output is
# `False`), and the high-level save is attempted again on the same path.
print(my_model.save_spec() is None)
my_model.save(saved_model_path)

Epoch 1/2


1/7 [===>..........................] - ETA: 5s - loss: 13.1256



Epoch 2/2


1/7 [===>..........................] - ETA: 0s - loss: 12.6170



False
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797651b7c0>, 140159649852656), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f797646f370>, 140159650883888), {}).


INFO:tensorflow:Assets written to: /tmp/tf_save/assets


INFO:tensorflow:Assets written to: /tmp/tf_save/assets
