模型保存和恢复
本文档介绍了如何保存和恢复变量和模型。
保存和恢复变量
TensorFlow 变量是表示程序处理中共享的持久状态的最佳方法(参阅变量了解详细信息)。此节介绍如何保存和恢复变量。请注意,Estimator 会自动保存和恢复变量(在 model_dir
中)。
tf.train.Saver
类别提供了保存和恢复模型的方法。tf.train.Saver
构造函数针对图中所有变量或指定列表的变量将 save
和 restore
op 添加到图中。Saver
对象提供了运行这些 op 的方法,指定了写入或读取检查点文件的路径。
Saver 将恢复已经在模型中定义的所有变量。如果您在不知道如何构建图的情况下加载模型(例如,如果您正在编写用于加载各种模型的通用程序),那么请阅读本文档后半部分的保存和恢复模型概述部分。
TensorFlow 将变量保存在二进制检查点文件中,简略而言,这类文件将变量名称映射到张量值。
保存变量
使用 tf.train.Saver()
创建 Saver
来管理模型中的所有变量。例如,以下代码片段展示了如何调用 tf.train.Saver.save
方法以将变量保存到检查点文件中:
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
inc_v1.op.run()
dec_v2.op.run()
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in path: %s" % save_path)
恢复变量
tf.train.Saver
对象不仅将变量保存到检查点文件中,还将恢复变量。请注意,当您恢复变量时,您不必事先将其初始化。例如,以下代码片段展示了如何调用 tf.train.Saver.restore
方法以从检查点文件中恢复变量:
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Check the values of the variables
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
注意:
- 并没有名为“/tmp/model.ckpt”的实体文件。它是为检查点创建的文件名的前缀。用户只需使用前缀(而非检查点实体文件)交互。
选择要保存和恢复的变量
如果您没有向 tf.train.Saver()
传递任何参数,则 Saver 会处理图中的所有变量。每个变量都保存在创建变量时所传递的名称下。
在检查点文件中明确指定变量名称的这种做法有时会非常有用。例如,您可能已经使用名为"weights"
的变量训练了一个模型,而您想要将该变量的值恢复到名为"params"
的变量中。
有时候,仅保存或恢复模型使用的变量子集也会很有裨益。例如,您可能已经训练了一个五层的神经网络,现在您想要训练一个六层的新模型,并重用该五层的现有权重。您可以使用 Saver 只恢复这前五层的权重。
您可以向 tf.train.Saver()
的构造函数传递以下任一内容来轻松指定要保存或加载的名称和变量:
- 变量列表(将以其本身的名称保存)。
- Python 字典,其中,键是要使用的名称,键值是要管理的变量。
继续前面所示的保存/恢复示例:
tf.reset_default_graph()
# Create some variables.
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
# Add ops to save and restore only `v2` using the name "v2"
saver = tf.train.Saver({"v2": v2})
# Use the saver object normally after that.
with tf.Session() as sess:
# Initialize v1 since the saver will not.
v1.initializer.run()
saver.restore(sess, "/tmp/model.ckpt")
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
注意:
如果需要保存和恢复模型变量的不同子集,您可以根据需要创建任意数量的 Saver
对象。同一个变量可以列在多个 Saver 对象中,变量的值只有在 Saver.restore()
方法运行时才会更改。
如果您仅在会话开始时恢复模型变量的子集,则必须为其他变量运行初始化 op。有关详情,请参阅 tf.variables_initializer
。
要检查某个检查点的变量,您可以使用 inspect_checkpoint
库,尤其是 print_tensors_in_checkpoint_file
函数。
默认情况下,Saver
会为每个变量使用 tf.Variable.name
属性的值。但是,当您创建一个 Saver
对象时,您可以选择为检查点文件中的变量选择名称(此为可选操作)。
检查某个检查点的变量
我们可以使用 inspect_checkpoint
库快速检查某个检查点的变量。
继续前面所示的保存/恢复示例:
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
# tensor_name: v1
# [ 1. 1. 1.]
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False)
# tensor_name: v1
# [ 1. 1. 1.]
# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False)
# tensor_name: v2
# [-1. -1. -1. -1. -1.]
保存和恢复模型概述
如果您想保存和加载变量、图,以及图的元数据 - 简而言之,如果您想保存或恢复模型 - 我们推荐使用 SavedModel。 SavedModel 是一种与语言无关,可恢复的密封式序列化格式。SavedModel 可让较高级别的系统和工具创建、使用和变换 TensorFlow 模型。TensorFlow 提供了多种与 SavedModel 交互的机制,如 tf.saved_model API、Estimator API 和 CLI。
用于构建和加载 SavedModel 的 API
本节重点介绍用于构建和加载 SavedModel 的 API,尤其是使用较低级别的 TensorFlow API 的情形。
构建 SavedModel
我们提供了 SavedModel 构建器的 Python 实现方法。SavedModelBuilder
类别提供了保存多个 MetaGraphDef
的功能。MetaGraph 是一种数据流图,加上相关变量、资源和签名。MetaGraphDef
是 MetaGraph 的协议缓冲区的表示法。签名是一组与图有关的输入和输出。
如果需要将资源保存并写入或复制到磁盘,则可以在首次添加 MetaGraphDef
时提供这些资源。如果多个 MetaGraphDef
与同名资源相关联,则只保留首个版本。
必须使用用户指定的标签对每个添加到 SavedModel 的 MetaGraphDef
进行标注。这些标签提供了一种方法来识别要加载和恢复的特定 MetaGraphDef
,以及共享的变量和资源子集。这些标签一般会标注 MetaGraphDef
的功能(例如服务或训练),有时也会标注特定的硬件方面的信息(如 GPU)。
例如,以下代码提示了使用 SavedModelBuilder
构建 SavedModel 的典型方法:
export_dir = ...
...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph_and_variables(sess,
[tag_constants.TRAINING],
signature_def_map=foo_signatures,
assets_collection=foo_assets)
...
# Add a second MetaGraphDef for inference.
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph([tag_constants.SERVING])
...
builder.save()
在 Python 中加载 SavedModel
Python 版的 SavedModel 加载器为 SavedModel 提供加载和恢复功能。load
指令需要以下信息:
- 要在其中恢复图定义和变量的会话。
- 用于标识要加载的 MetaGraphDef 的标签。
- SavedModel 的位置(目录)。
加载后,作为特定 MetaGraphDef 的一部分,所提供的变量、资源和签名子集将恢复到提供的会话中。
export_dir = ...
...
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)
...
在 C++ 中加载 SavedModel
C++ 版的 SavedModel 加载器提供了一个 API,可通过某个路径加载 SavedModel,同时允许 SessionOptions
和 RunOptions
。您必须指定与要加载的图相关联的标签。SavedModel 加载后的版本被称为 SavedModelBundle
,它包含 MetaGraphDef 和加载时所在的会话。
const string export_dir = ...
SavedModelBundle bundle;
...
LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},
&bundle);
在 TensorFlow Serving 中加载和提供 SavedModel
您可以使用 TensorFlow Serving Model Server 二进制文件轻松加载和提供 SavedModel。参阅此处的说明,了解如何安装服务器,或根据需要创建服务器。
一旦您的 Model Server 就绪,请运行以下内容:
tensorflow_model_server --port=port-numbers --model_name=your-model-name --model_base_path=your_model_base_path
将 port 和 model_name 标记设为您所选的值。model_base_path 标记按预期应为基本目录,每个版本的模型都放置于以数字命名的子目录中。如果您的模型只有一个版本,只需将其放在如下所示的子目录中:将模型放入 /tmp/model/0001,将 model_base_path 设为 /tmp/model
将模型的不同版本存储在共用基本目录的以数字命名的子目录中。例如,假设基本目录是 /tmp/model
。如果您的模型只有一个版本,请将其存储在 /tmp/model/0001
中。如果您的模型有两个版本,请将第二个版本存储在 /tmp/model/0002
中,以此类推。将 --model-base_path
标记设为基本目录(在本例中为 /tmp/model
)。TensorFlow Model Server 将在该基本目录的最大编号的子目录中提供模型。
标准常量
SavedModel 为各种用例构建和加载 TensorFlow 图提供了灵活性。对于最常见的用例,SavedModel 的 API 在 Python 和 C++ 中提供了一组易于重复使用且便于在各种工具中共享的常量。
标准 MetaGraphDef 标签
您可以使用标签组唯一标识保存在 SavedModel 中的 MetaGraphDef
。常用标签的子集规定如下:
标准 SignatureDef 常量
SignatureDef 是一个协议缓冲区,用于定义图所支持的计算的签名。常用的输入键、输出键和方法名称定义如下:
搭配 Estimator 使用 SavedModel
训练 Estimator
模型之后,您可能想使用该模型创建服务来接收请求并返回结果。您可以本机运行此服务,或在云端对其进行可扩展部署。
要准备一个训练完好的 Estimator 用于提供服务,您必须以标准 SavedModel 格式导出它。本节介绍如何进行以下操作:
- 指定可以提供相关服务(Classify、Regress 或 Predict)的输出节点和相应的 API。
- 将您的模型导出为 SavedModel 格式。
- 从本地服务器提供模型并请求预测。
为输入做准备
在训练期间,input_fn()
提取数据,并准备好供模型使用。在提供服务期间,类似地,serving_input_receiver_fn()
接受推理请求,并为模型做好准备。该函数具有以下用途:
- 在服务系统将使用推理请求提供的图中添加占位符。
- 添加将数据从输入格式转换为模型所预期的特征
Tensor
所需的任何额外 op。
该函数返回一个 tf.estimator.export.ServingInputReceiver
对象,该对象将占位符和生成的特征 Tensor
组合在一起。
典型的模式是推理请求以序列化 tf.Example
的形式到达,为此,serving_input_receiver_fn()
创建单个字符串占位符来接收它们。serving_input_receiver_fn()
接着也负责解析 tf.Example
(通过向图中添加 tf.parse_example
op)。
在编写 serving_input_receiver_fn()
这样的代码时,您必须将解析规范传递给 tf.parse_example
,告诉解析器哪些特征名称可能将出现以及如何将它们映射到 Tensor
。解析规范采用字典的形式,从特征名称映射到 tf.FixedLenFeature
、tf.VarLenFeature
和 tf.SparseFeature
。请注意,此解析规范不应包含任何标签或权重列,因为这些列在服务时间将不可用(与 input_fn()
在训练时使用的解析规范相反)。
然后结合如下:
feature_spec = {'foo': tf.FixedLenFeature(...),
'bar': tf.VarLenFeature(...)}
def serving_input_receiver_fn():
"""An input receiver that expects a serialized tf.Example."""
serialized_tf_example = tf.placeholder(dtype=tf.string,
shape=[default_batch_size],
name='input_example_tensor')
receiver_tensors = {'examples': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
tf.estimator.export.build_parsing_serving_input_receiver_fn
效用函数提供了普遍情况下的输入接收器。
注意: 在使用 Predict API 和本地服务器训练要提供的模型时,并不需要解析步骤,因为该模型将接收原始特征数据。
即使您不需要解析或其他输入处理,也就是说,如果服务系统直接提供特征 Tensor
,您仍然必须提供一个 serving_input_receiver_fn()
来为特征 Tensor
创建占位符并在其中传递占位符。tf.estimator.export.build_raw_serving_input_receiver_fn
效用函数实现了这一功能。
如果这些效用函数不能满足您的需求,您可以自由编写 serving_input_receiver_fn()
。可能需要此方法的一种情况是,如果您训练的 input_fn()
包含某些必须在服务时间重演的预处理逻辑。为了减轻训练服务倾斜的风险,我们建议将这种处理封装在一个函数内,此函数随后将从 input_fn()
和 serving_input_receiver_fn()
两者中被调用。
请注意,serving_input_receiver_fn()
也决定了签名的输入部分。也就是说,在编写 serving_input_receiver_fn()
时,必须告诉解析器哪些有哪些签名可能出现,以及如何将它们映射到模型的预期输入。相反,签名的输出部分由模型决定。
执行导出
要导出已训练的 Estimator,请用出口基本路径和 serving_input_receiver_fn
调用 tf.estimator.Estimator.export_savedmodel
。
estimator.export_savedmodel(export_dir_base, serving_input_receiver_fn)
这个方法通过首先调用 serving_input_receiver_fn()
创建一个新的图,以获得特征 Tensor
,然后调用此 Estimator
的 model_fn()
,以基于这些特征生成模型图。它会重新启动 Session
,并且默认情况下会将最近的检查点恢复到它(如果需要,可以传递不同的检查点)。最后,它在给定的 export_dir_base
(即 export_dir_base/<timestamp>
)下面创建一个带时间戳的导出目录,并将 SavedModel 写入其中,其中包含从此会话中保存的单个 MetaGraphDef
。
注意:
您有责任收集旧输出的无效信息。否则,连续输出将累积在 export_dir_base
下。
指定自定义模型的输出
编写自定义 model_fn
时,必须填充 tf.estimator.EstimatorSpec
返回值的 export_outputs
元素。这是 {name: output}
描述在服务期间输出和使用的输出签名的词典。
在进行单一预测的通常情况下,该词典包含一个元素,而且 name
不重要。在一个多头模型中,每个头部都由这个词典中的一个条目表示。在这种情况下,name
是一个您所选择的字符串,用于在服务时间内请求特定头部。
每个 output
值必须是一个 ExportOutput
对象,如 tf.estimator.export.ClassificationOutput
、tf.estimator.export.RegressionOutput
或 tf.estimator.export.PredictOutput
。
这些输出类型直接映射到 TensorFlow Serving API,并确定将支持哪些请求类型。
注意: 在多头情况下,系统将为从 model_fn 返回的export_outputs
字典的每个元素生成 SignatureDef
,这些元素都以相同的键命名。这些 SignatureDef
仅在它们的输出方面有所不同,这是由于相应 ExportOutput
条目所提供的内容不同。输入始终是由 serving_input_receiver_fn
提供的。推理请求可以按名称指定头部。一个头部必须使用 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
命名,其指示在推理请求没有指定 SignatureDef
时,哪一个将被提供。
在本地提供导出的模型
对于本地部署,您可以使用 TensorFlow Serving 来提供模型,这是一个开源项目,用于加载 SavedModel 并将其作为 gRPC 服务公开。
然后构建并运行本地模型服务器,用上面导出的指向 SavedModel 的路径替换 $export_dir_base
:
bazel build //tensorflow_serving/model_servers:tensorflow_model_server
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_base_path=$export_dir_base
现在您有一台服务器在端口 9000 上通过 gRPC 监听推理请求!
从本地服务器请求预测
服务器响应 gRPC 请求时遵循 PredictionService gRPC API 服务定义(嵌套协议缓冲区在各种相邻文件中定义)。
根据 API 服务定义,gRPC 框架以各种语言生成客户端库,提供对 API 的远程访问。在使用 Bazel 构建工具的项目中,这些库是自动构建的,并通过以下关联项提供(以 Python 为例):
deps = [
"//tensorflow_serving/apis:classification_proto_py_pb2",
"//tensorflow_serving/apis:regression_proto_py_pb2",
"//tensorflow_serving/apis:predict_proto_py_pb2",
"//tensorflow_serving/apis:prediction_service_proto_py_pb2"
]
Python 客户端代码可以导入这些库:
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import regression_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
注意:prediction_service_pb2
将服务定义为一个整体,因此始终是必需的。然而一个典型的客户端只需要classification_pb2
、regression_pb2
和predict_pb2
中的一个,取决于所做请求的类型。
通过在协议缓冲区聚集请求数据并将其传递给服务存根即可完成 gRPC 请求的发送。请注意观察请求协议缓冲区是如何创建为空区的,然后是如何通过生成的协议缓冲区 API 填充的。
from grpc.beta import implementations
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
request = classification_pb2.ClassificationRequest()
example = request.input.example_list.examples.add()
example.features.feature['x'].float_list.value.extend(image[0].astype(float))
result = stub.Classify(request, 10.0) # 10 secs timeout
本例中返回的结果是一个 ClassificationResponse
协议缓冲区。
这是一个总体示例;请参阅 Tensorflow Serving 文档和示例以获取更多详细信息。
注意:ClassificationRequest
和RegressionRequest
包含一个tensorflow.serving.Input
协议缓冲区,而该缓冲区又包含多达一列表的tensorflow.Example
协议缓冲区。而与之不同的是,PredictRequest
包含从特征名称到用TensorProto
进行编码的值的映射。相应地,当使用Classify
和Regress
API 时,TensorFlow Serving 将提供序列化的tf.Example
到图中,所以您的serving_input_receiver_fn()
应该包含一个tf.parse_example()
Op。但是,当使用通用Predict
API 时,TensorFlow Serving 会将原始特征数据提供给图,因此应该使用serving_input_receiver_fn()
进行传递。
使用 CLI 检查并执行 SavedModel
您可以使用 SavedModel 命令行界面 (CLI) 检查并执行 SavedModel。例如,您可以使用 CLI 检查模型的 SignatureDef
。CLI 让您能够快速确认输入的 Tensor dtype 和形状是否与模型匹配。此外,如果您想测试模型,可以使用 CLI 通过以各种格式(如 Python 表达式)传递示例输入,然后获取输出来进行健全性检查。
安装 SavedModel CLI
一般来说,您可以通过以下任一方式安装 TensorFlow:
- 通过安装预构建的 TensorFlow 二进制文件。
- 通过从源代码构建 TensorFlow。
如果您通过预构建的 TensorFlow 二进制文件安装了 TensorFlow,则您的系统上已经安装 SavedModel CLI(路径名称为:bin\saved_model_cli
)。
如果您从源代码构建 TensorFlow,则必须运行以下附加命令来构建 saved_model_cli
:
$ bazel build tensorflow/python/tools:saved_model_cli
命令概述
SavedModel CLI 在 SavedModel 中 MetaGraphDef
上支持以下两个命令:
show
,显示在 SavedModel 中MetaGraphDef
上的计算。run
,在MetaGraphDef
上运行计算。
show
命令
SavedModel 包含一个或多个 MetaGraphDef
,由其标签集进行标识。要提供模型,您可能想知道每个模型中的 SignatureDef
是什么类型的,它们的输入和输出是什么。show
命令可让您按层次顺序检查 SavedModel 的内容。语法如下:
usage: saved_model_cli show [-h] --dir DIR [--all]
[--tag_set TAG_SET] [--signature_def SIGNATURE_DEF_KEY]
例如,以下命令显示 SavedModel 中所有可用的 MetaGraphDef 的标签集:
$ saved_model_cli show --dir /tmp/saved_model_dir
The given SavedModel contains the following tag-sets:
serve
serve, gpu
以下命令会显示 MetaGraphDef
中所有可用的 SignatureDef
键:
$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve
The given SavedModel `MetaGraphDef` contains `SignatureDefs` with the
following keys:
SignatureDef key: "classify_x2_to_y3"
SignatureDef key: "classify_x_to_y"
SignatureDef key: "regress_x2_to_y3"
SignatureDef key: "regress_x_to_y"
SignatureDef key: "regress_x_to_y2"
SignatureDef key: "serving_default"
如果 MetaGraphDef
有标签集中的多个标签,则您必须指定所有标签,用英文逗号将每个标签分隔开来。例如:
$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve,gpu
要显示特定 SignatureDef
的所有输入和输出 TensorInfo,请将 SignatureDef
键传递给 signature_def
选项。当您想知道输入张量的键值、dtype 和形状以便后续执行计算图时,这会非常有用。例如:
$ saved_model_cli show --dir \
/tmp/saved_model_dir --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['x'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: x:0
The given SavedModel SignatureDef contains the following output(s):
outputs['y'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: y:0
Method name is: tensorflow/serving/predict
要显示 SavedModel 中的所有可用信息,请使用 --all
选项。例如:
$ saved_model_cli show --dir /tmp/saved_model_dir --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['classify_x2_to_y3']:
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: x2:0
The given SavedModel SignatureDef contains the following output(s):
outputs['scores'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: y3:0
Method name is: tensorflow/serving/classify
...
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['x'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: x:0
The given SavedModel SignatureDef contains the following output(s):
outputs['y'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: y:0
Method name is: tensorflow/serving/predict
run
命令
调用 run
命令以运行图计算、传递输入,然后显示(并可选地保存)输出。语法如下:
usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def
SIGNATURE_DEF_KEY [--inputs INPUTS]
[--input_exprs INPUT_EXPRS] [--outdir OUTDIR]
[--overwrite] [--tf_debug]
run
命令提供以下两种方式将输入传递给模型:
--inputs
选项可让您在文件中传递多维数组 (numpy ndarray)。--input_exprs
选项可让您传递 Python 表达式。--input_examples
选项可让您传递tf.train.Example
。
--inputs
要在文件中传递输入数据,请指定 --inputs
选项,该选项采用以下通用格式:
--inputs <INPUTS>
INPUT 采用以下格式之一:
<input_key>=<filename>
<input_key>=<filename>[<variable_name>]
您可能会传递多个 INPUT。如果您确实要传递多个输入,请使用分号分隔每个 INPUT。
saved_model_cli
使用 numpy.load
加载文件名。文件名可以是以下任何一种格式:
.npy
.npz
- pickle 格式
.npy
文件总是包含多维数组 (numpy ndarray)。因此,当从 .npy
文件加载时,内容将直接分配给指定的输入张量。如果用该 .npy
文件指定 variable_name,则 variable_name 将被忽略,并且系统会发出警告。
从 .npz
(zip) 文件加载时,您可以选择指定一个 variable_name 来标识 zip 文件中用于加载输入张量键的变量。如果您未指定 variable_name,则 SavedModel CLI 将检查 zip 文件中是否只包含一个文件,并将为指定的输入张量键加载。
从 pickle 文件加载时,如果方括号中没有指定 variable_name
,那么 pickle 文件中的任何内容都将传递到指定的输入张量键。否则,SavedModel CLI 会假设在 pickle 文件中存储了字典,并且与 variable_name对应的值将被使用。
--inputs_exprs
要通过 Python 表达式传递输入,请指定 --input_exprs
选项。这对于您目前没有数据文件的情形而言非常有用,但最好还是用一些与模型的 SignatureDef
的 dtype 和形状匹配的简单输入来检查模型。例如:
`<input_key>=[[1],[2],[3]]`
除了 Python 表达式之外,您还可以传递 numpy 函数。例如:
`<input_key>=np.ones((32,32,3))`
(请注意,numpy
模块已可作为 np
提供。)
--inputs_examples
要将 tf.train.Example
作为输入传递,请指定 --input_examples
选项。对于每个输入键,它都基于一个字典列表,其中每个字典都是 tf.train.Example
的一个实例。不同的字典键代表不同的特征,而相应的值则是每个特征的值列表。例如:
`<input_key>=[{"age":[22,24],"education":["BS","MS"]}]`
保存输出
默认情况下,SavedModel CLI 将输出写入 stdout。如果目录传递给 --outdir
选项,则输出将被保存为在指定目录下以输出张量键命名的 npy 文件。
使用 --overwrite
覆盖现有的输出文件。
TensorFlow Debugger (tfdbg) 集成
如果设置了 --tf_debug
选项,则 SavedModel CLI 将使用 TensorFlow Debugger (tfdbg) 在运行 SavedModel 时观察中间张量和运行时图或子图。
run
的完整示例
假设:
- 您的模型只需添加
x1
和x2
即可获得输出y
。 - 模型中的所有张量都具有形状
(-1, 1)
。 - 您有两个
npy
文件: /tmp/my_data1.npy
,其中包含多维数组[[1], [2], [3]]
。/tmp/my_data2.npy
,其中包含另一个多维数组[[0.5], [0.5], [0.5]]
。
要使用模型运行这两个 npy
文件以获得输出 y
,请发出以下命令:
$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def x1_x2_to_y --inputs x1=/tmp/my_data1.npy;x2=/tmp/my_data2.npy \
--outdir /tmp/out
Result for output key y:
[[ 1.5]
[ 2.5]
[ 3.5]]
让我们稍微调整一下前面的例子。这一次,不是两个 .npy
文件,而是一个 .npz
文件和一个 pickle 文件。此外,您要覆盖任何现有的输出文件。命令如下:
$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def x1_x2_to_y \
--inputs x1=/tmp/my_data1.npz[x];x2=/tmp/my_data2.pkl --outdir /tmp/out \
--overwrite
Result for output key y:
[[ 1.5]
[ 2.5]
[ 3.5]]
您可以指定 python 表达式,取代输入文件。例如,以下命令用 Python 表达式替换输入 x2
:
$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def x1_x2_to_y --inputs x1=/tmp/my_data1.npz[x] \
--input_exprs 'x2=np.ones((3,1))'
Result for output key y:
[[ 2]
[ 3]
[ 4]]
要在开启 TensorFlow Debugger 的情况下运行模型,请发出以下命令:
$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def serving_default --inputs x=/tmp/data.npz[x] --tf_debug
SavedModel 目录的结构
当您以 SavedModel 格式保存模型时,TensorFlow 会创建一个由以下子目录和文件组成的 SavedModel 目录:
assets/
assets.extra/
variables/
variables.data-?????-of-?????
variables.index
saved_model.pb|saved_model.pbtxt
其中:
assets
是包含辅助(外部)文件(如词汇表)的子文件夹。资源被复制到 SavedModel 的位置,并且可以在加载特定的MetaGraphDef
时读取。assets.extra
是一个子文件夹,其中较高级别的库和用户可以添加自己的资源,该资源与模型共存,但不会被图加载。此子文件夹不由 SavedModel 库管理。variables
是包含tf.train.Saver
的输出的子文件夹。saved_model.pb
或saved_model.pbtxt
是 SavedModel 协议缓冲区。它包含作为MetaGraphDef
协议缓冲区的图定义。
单个 SavedModel 可以表示多个图。在这种情况下,SavedModel 中所有图共享一组检查点(变量)和资源。例如,下图显示了一个包含三个 MetaGraphDef
的 SavedModel,它们三个都共享同一组检查点和资源:
每个图都与一组特定的标签相关联,可在加载或恢复操作期间方便您进行识别。