本文整理汇总了Python中tensorflow.python.saved_model.loader.load函数的典型用法代码示例。如果您正苦于以下问题：Python load函数的具体用法？Python load怎么用？Python load使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了load函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testSaveAsText
def testSaveAsText(self):
    """Round-trip a text-format SavedModel that holds two tagged meta graphs.

    Tag "foo" is added together with its variable values, while tag "bar"
    only contributes a meta graph. After saving with `as_text=True`, both
    tags are reloaded and the first entry of the VARIABLES collection is
    expected to evaluate to 42 in each case.
    """
    temp_root = compat.as_bytes(tf.test.get_temp_dir())
    export_dir = os.path.join(temp_root, compat.as_bytes("astext"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # First graph: variable "v" is 42 and is exported along with its weights.
    with self.test_session(graph=tf.Graph()) as sess:
        saved_var = tf.Variable(42, name="v")
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        self.assertEqual(42, saved_var.eval())
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Second graph: same variable name with value 43; only the meta graph is
    # added, so the stored weights are not updated.
    with self.test_session(graph=tf.Graph()) as sess:
        unsaved_var = tf.Variable(43, name="v")
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        self.assertEqual(43, unsaved_var.eval())
        builder.add_meta_graph(["bar"])

    # Write the SavedModel to disk in text format.
    builder.save(as_text=True)

    # "foo" had its variables saved; "bar" did not. Both are expected to
    # read 42 after loading.
    for tag in ("foo", "bar"):
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [tag], export_dir)
            restored_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
            self.assertEqual(42, restored_vars[0].eval())
示例2: testLegacyInitOp
def testLegacyInitOp(self):
    """Check that a `legacy_init_op` stored in a SavedModel runs on load.

    Three variables are placed in the custom "v" collection. The legacy init
    op assigns `v1 + v2` to `v3`, so after loading the third collection entry
    is expected to evaluate to 3 rather than its initial value of 42.
    """
    export_dir = self._get_export_dir("test_legacy_init_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # `v1` and `v2` are ordinary variables tracked in collection "v".
        first = variables.Variable(1, name="v1")
        ops.add_to_collection("v", first)
        second = variables.Variable(2, name="v2")
        ops.add_to_collection("v", second)

        # `v3` starts at 42 and is created with an empty `collections` list;
        # it is only registered in the custom "v" collection.
        third = variables.Variable(
            42, name="v3", trainable=False, collections=[])
        ops.add_to_collection("v", third)

        # The assignment below is what the legacy_init_op executes.
        total = math_ops.add(first, second)
        assign_third = state_ops.assign(third, total)
        legacy_init_op = control_flow_ops.group(
            assign_third, name="legacy_init_op")

        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(
            sess, ["foo"], legacy_init_op=legacy_init_op)

    # Write the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        # The last value (3) is the sum of the first two, assigned by the
        # legacy_init_op following the restore.
        for index, expected in enumerate((1, 2, 3)):
            self.assertEqual(expected, restored[index].eval())
示例3: testSaveAsText
def testSaveAsText(self):
    """Round-trip a text-format SavedModel that holds two tagged meta graphs.

    Tag "foo" is added with its variable values and tag "bar" is added as a
    meta graph only. After `save(as_text=True)`, loading either tag is
    expected to yield 42 for the first global variable.
    """
    export_dir = self._get_export_dir("test_astext")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph holding variable "v" == 42; exported together with its weights.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Graph holding variable "v" == 43; only the meta graph is added, so the
    # stored weights are not updated.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 43)
        builder.add_meta_graph(["bar"])

    # Write the SavedModel to disk in text format.
    builder.save(as_text=True)

    # "foo" had its variables saved, "bar" did not; both are expected to
    # read 42 once loaded.
    for tag in ("foo", "bar"):
        with self.test_session(graph=ops.Graph()) as sess:
            loader.load(sess, [tag], export_dir)
            global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            self.assertEqual(42, global_vars[0].eval())
示例4: testCustomMainOp
def testCustomMainOp(self):
    """Check that a custom `main_op` stored in a SavedModel runs on load.

    Three variables are tracked in the custom "v" collection. The custom
    main op assigns `v1 + v2` to `v3`, so after loading the third collection
    entry is expected to evaluate to 3.
    """
    export_dir = self._get_export_dir("test_main_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # `v1` and `v2` are ordinary variables tracked in collection "v".
        first = variables.Variable(1, name="v1")
        ops.add_to_collection("v", first)
        second = variables.Variable(2, name="v2")
        ops.add_to_collection("v", second)

        # `v3` starts out as 42.
        third = variables.Variable(42, name="v3")
        ops.add_to_collection("v", third)

        # Build the assignment that runs as part of the main_op.
        # NOTE(review): the source's indentation was lost; the group op is
        # placed inside the control-dependency scope here - confirm against
        # the upstream test.
        with ops.control_dependencies([main_op.main_op()]):
            total = math_ops.add(first._ref(), second._ref())
            assign_third = state_ops.assign(third, total)
            custom_main_op = control_flow_ops.group(assign_third)

        sess.run(custom_main_op)
        builder.add_meta_graph_and_variables(
            sess, ["foo"], main_op=custom_main_op)

    # Write the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        # The last value (3) is the sum of the first two, assigned by the
        # main_op following the restore.
        for index, expected in enumerate((1, 2, 3)):
            self.assertEqual(expected, restored[index].eval())
示例5: testTrainOpGroup
def testTrainOpGroup(self):
    """Check that a grouped (no-output) train op is restored as an Operation.

    An empty `control_flow_ops.group()` is registered as the train op before
    the meta graph is added. After loading, the first entry of the
    TRAIN_OP_KEY collection must be an `ops.Operation`.
    """
    export_dir = self._get_export_dir("test_train_op_group")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # Create `v1` and `v2` and track them in collection "v".
        for initial_value, var_name in ((1, "v1"), (2, "v2")):
            created = variables.Variable(initial_value, name=var_name)
            ops.add_to_collection("v", created)

        sess.run(variables.global_variables_initializer())

        train_op = control_flow_ops.group()
        sess.run(train_op)
        # TODO(karmel): remove explicit call when in the public method.
        builder._add_train_op(train_op)
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Write the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        self.assertEqual(1, restored[0].eval())
        self.assertEqual(2, restored[1].eval())
        restored_train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
        self.assertIsInstance(restored_train_ops[0], ops.Operation)
示例6: testTrainOpAfterVariables
def testTrainOpAfterVariables(self):
    """Check that a train op added after the variables only affects later tags.

    Tag "pre_foo" is exported before any train op exists; the train op is
    then registered and tag "foo" is added. On load, "foo" must expose a
    Tensor under TRAIN_OP_KEY while "pre_foo" must have an empty collection.
    """
    export_dir = self._get_export_dir("test_train_op_after_variables")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # Create `v1` and `v2` and track them in collection "v".
        first = variables.Variable(1, name="v1")
        ops.add_to_collection("v", first)
        second = variables.Variable(2, name="v2")
        ops.add_to_collection("v", second)

        sess.run(variables.global_variables_initializer())
        # This tag is exported before the train op is created.
        builder.add_meta_graph_and_variables(sess, ["pre_foo"])

        train_op = state_ops.assign_add(first, second)
        sess.run(train_op)
        # TODO(karmel): remove explicit call when in the public method.
        builder._add_train_op(train_op)
        builder.add_meta_graph(["foo"])

    # Write the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        train_ops_after = ops.get_collection(constants.TRAIN_OP_KEY)
        self.assertIsInstance(train_ops_after[0], ops.Tensor)

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["pre_foo"], export_dir)
        train_ops_before = ops.get_collection(constants.TRAIN_OP_KEY)
        self.assertFalse(train_ops_before)
def testCustomSaveable(self):
    """Check that a custom saveable (a checkpointed table) survives a reload.

    A `CheckpointedOp` table receives the entry "k1" -> 3.0 and its table
    reference is stored in the "table_ref" collection. After loading the
    SavedModel, a wrapper rebuilt from that reference must return the same
    key and value.
    """
    export_dir = self._get_export_dir("custom_saveable")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    export_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with session.Session(graph=ops.Graph(), config=export_config) as sess:
        # CheckpointedOp is a key-value table that can be saved across
        # sessions. The table registers itself in the SAVEABLE_OBJECTS
        # collection.
        table = saver_test_utils.CheckpointedOp(name="v1")
        variables.global_variables_initializer().run()
        table.insert("k1", 3.0).run()
        # Keep a handle so the restored table can be reached after loading.
        ops.add_to_collection("table_ref", table.table_ref)
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Write the SavedModel to disk.
    builder.save()

    import_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with session.Session(graph=ops.Graph(), config=import_config) as sess:
        loader.load(sess, ["foo"], export_dir)
        # Rebuild a wrapper object around the checkpointed reference.
        restored_ref = ops.get_collection("table_ref")[0]
        restored_table = saver_test_utils.CheckpointedOp(
            name="v1", table_ref=restored_ref)
        self.assertEqual(b"k1", restored_table.keys().eval())
        self.assertEqual(3.0, restored_table.values().eval())
def testGraphWithoutVariables(self):
    """Check that graphs containing only constants can be saved and loaded.

    Two variable-free graphs (a constant 5.0 under tag "foo" and a constant
    6.0 under tag "bar") are exported. Each is reloaded, its constant is
    looked up by tensor name and multiplied by the other value; the product
    must be 30.0 in both cases.
    """
    export_dir = self._get_export_dir("test_graph_has_variables")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # First graph: no variables, just the constant 5.0.
    with self.test_session(graph=ops.Graph()) as sess:
        five = constant_op.constant(5.0)
        constant_5_name = five.name
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Second graph: no variables, just the constant 6.0.
    with self.test_session(graph=ops.Graph()) as sess:
        six = constant_op.constant(6.0)
        constant_6_name = six.name
        builder.add_meta_graph(["bar"])

    # Write the SavedModel to disk.
    builder.save()

    # For each tag: (tensor name stored in the graph, value to multiply by).
    cases = (
        ("foo", constant_5_name, 6.0),
        ("bar", constant_6_name, 5.0),
    )
    for tag, stored_name, multiplier in cases:
        with self.test_session(graph=ops.Graph()) as sess:
            loader.load(sess, [tag], export_dir)
            # Fetch the stored constant from the freshly loaded graph.
            stored = ops.get_default_graph().get_tensor_by_name(stored_name)
            other = constant_op.constant(multiplier)
            product = stored * other
            self.assertEqual(30.0, sess.run(product))
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
    """Export a SavedModel, then derive extra artifacts from its tree ensemble.

    The base strategy performs the export. The result is loaded back into a
    fresh graph, the serialized tree ensemble is parsed, an optional
    converter is invoked, and a "feature_importances" file is written under
    `<result_dir>/assets.extra`.

    NOTE(review): `base_strategy`, `convert_fn`, `sorted_feature_names`,
    `dense_floats`, `sparse_float_indices` and `sparse_int_indices` are not
    defined in this block; presumably they come from an enclosing scope.

    Args:
      estimator: Forwarded to `base_strategy.export`.
      export_dir: Forwarded to `base_strategy.export`.
      checkpoint_path: Forwarded to `base_strategy.export`.
      eval_result: Forwarded to `base_strategy.export` and to `convert_fn`.

    Returns:
      The directory returned by `base_strategy.export`.
    """
    result_dir = base_strategy.export(
        estimator, export_dir, checkpoint_path, eval_result)

    with ops.Graph().as_default() as graph:
        with tf_session.Session(graph=graph) as sess:
            saved_model_loader.load(sess, [tag_constants.SERVING], result_dir)
            # Note: This is GTFlow internal API and might change.
            serialize_op = graph.get_operation_by_name(
                "ensemble_model/TreeEnsembleSerialize")
            _, serialized_ensemble = sess.run(serialize_op.outputs)
            dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
            dtec.ParseFromString(serialized_ensemble)

            num_dense = len(dense_floats)
            num_sparse_float = len(sparse_float_indices)
            num_sparse_int = len(sparse_int_indices)

            # Export the result in the same folder as the saved model.
            if convert_fn:
                convert_fn(dtec, sorted_feature_names, num_dense,
                           num_sparse_float, num_sparse_int, result_dir,
                           eval_result)

            feature_importances = _get_feature_importances(
                dtec, sorted_feature_names, num_dense, num_sparse_float,
                num_sparse_int)
            # Most important feature first.
            ranked = sorted(
                feature_importances.items(), key=lambda item: -item[1])

            assets_dir = os.path.join(result_dir, "assets.extra")
            gfile.MakeDirs(assets_dir)
            importances_path = os.path.join(assets_dir, "feature_importances")
            report = "\n".join("%s, %f" % (k, v) for k, v in ranked)
            with gfile.GFile(importances_path, "w") as f:
                f.write(report)

    return result_dir
示例10: testVariables
def testVariables(self):
    """Check how saved weights map onto meta graphs with differing variables.

    Tag "foo" saves `v1` and `v2`. Tag "bar" (only `v2`) and tag "baz" (only
    `v3`) are added without weights. Loading "foo" restores both variables,
    loading "bar" restores `v2` from the checkpoint, and loading "baz" must
    raise `NotFoundError` because `v3` was never saved.
    """
    temp_root = compat.as_bytes(tf.test.get_temp_dir())
    export_dir = os.path.join(temp_root, compat.as_bytes("variables"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with two variables; this is the only one exported with weights.
    with self.test_session(graph=tf.Graph()) as sess:
        first = tf.Variable(1, name="v1")
        second = tf.Variable(2, name="v2")
        sess.run(tf.initialize_all_variables())
        self.assertEqual(1, first.eval())
        self.assertEqual(2, second.eval())
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Two single-variable graphs added as meta graphs only (weights are not
    # updated): "bar" reuses a saved variable name, "baz" uses a name that
    # is disjoint from the saved set.
    meta_graph_only = (
        (3, "v2", "bar"),
        (4, "v3", "baz"),
    )
    for initial_value, var_name, tag in meta_graph_only:
        with self.test_session(graph=tf.Graph()) as sess:
            lone_var = tf.Variable(initial_value, name=var_name)
            sess.run(tf.initialize_all_variables())
            self.assertEqual(initial_value, lone_var.eval())
            builder.add_meta_graph([tag])

    # Write the SavedModel to disk.
    builder.save()

    # "foo" restores both saved variables. "bar" never saved weights, so its
    # single variable is restored with the checkpointed value (2, not 3).
    expectations = (
        ("foo", (1, 2)),
        ("bar", (2,)),
    )
    for tag, expected_values in expectations:
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [tag], export_dir)
            collection_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
            self.assertEqual(len(collection_vars), len(expected_values))
            for index, expected in enumerate(expected_values):
                self.assertEqual(expected, collection_vars[index].eval())

    # "baz" has a disjoint set of variables from the set that was saved, so
    # loading it should raise an error.
    with self.test_session(graph=tf.Graph()) as sess:
        with self.assertRaises(errors.NotFoundError):
            loader.load(sess, ["baz"], export_dir)
def testClearExtraneousSavers(self):
    """Check that extra Savers in the graph are dropped when exporting.

    Two Savers are added to the SAVERS collection before export. After the
    SavedModel is reloaded, only one SaverDef and one SaveV2/RestoreV2 pair
    (named under "save_2") must remain.
    """
    export_dir = os.path.join(
        test.get_temp_dir(), "test_clear_extraneous_savers")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    def names_of_type(graph, op_type):
        # Names of every op in `graph` whose type equals `op_type`.
        return {op.name for op in graph.get_operations()
                if op.type == op_type}

    # Create a variable and the extra Savers.
    with ops.Graph().as_default() as graph:
        with session.Session(
                target="",
                config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
            self._init_and_validate_variable(sess, "v", 42)

            # Add two Savers, which should be removed in
            # add_meta_graph_and_variables() in favor of the locally added
            # one. Creation order matters: it fixes the op name prefixes
            # asserted below.
            for _ in range(2):
                extra_saver = tf_saver.Saver()
                graph.add_to_collection(ops.GraphKeys.SAVERS, extra_saver)

            # Confirm there are two SaverDefs.
            registered = graph.get_collection(ops.GraphKeys.SAVERS)
            self.assertEqual(2, len(registered))

            # Confirm there are two Save and two Restore ops.
            self.assertSetEqual(
                {"save/SaveV2", "save_1/SaveV2"},
                names_of_type(graph, "SaveV2"))
            self.assertSetEqual(
                {"save/RestoreV2", "save_1/RestoreV2"},
                names_of_type(graph, "RestoreV2"))

            # The SavedModel builder adds its own Saver, for a total of
            # three.
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.TRAINING], clear_devices=True)

    # Write the SavedModel to disk.
    builder.save()

    # Restore the graph.
    with ops.Graph().as_default() as graph:
        with self.test_session(graph=graph) as sess:
            loader.load(sess, [tag_constants.TRAINING], export_dir)
            global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            self.assertEqual(42, global_vars[0].eval())

            # Confirm that the reloaded graph has only one SaverDef.
            remaining = ops.get_collection(ops.GraphKeys.SAVERS)
            self.assertEqual(1, len(remaining))

            # The reloaded graph should have exactly one Save and one
            # Restore op.
            self.assertSetEqual(
                {"save_2/SaveV2"}, names_of_type(graph, "SaveV2"))
            self.assertSetEqual(
                {"save_2/RestoreV2"}, names_of_type(graph, "RestoreV2"))
示例12: testSignatureDefs
def testSignatureDefs(self):
    """Check that each tagged meta graph keeps its own SignatureDef map.

    Tag "foo" is exported with one signature under "foo_key". Tag "bar" is
    exported with two signatures, one of which reuses "foo_key" with a
    different method name. Each reloaded meta graph must expose exactly the
    map it was saved with.
    """
    export_dir = self._get_export_dir("test_signature_defs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    def empty_signature(method_name):
        # A SignatureDef with no inputs or outputs, only a method name.
        return signature_def_utils.build_signature_def(
            dict(), dict(), method_name)

    # Graph with a single variable and a single signature-map entry; this
    # one is exported together with its weights.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        foo_signature = empty_signature("foo")
        builder.add_meta_graph_and_variables(
            sess, ["foo"], signature_def_map={"foo_key": foo_signature})

    # Graph with the same single variable and several signature-map entries.
    # No weights are saved for it.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 43)
        bar_signature = empty_signature("bar")
        # A different SignatureDef for the "foo_key" used by the first graph.
        foo_new_signature = empty_signature("foo_new")
        bar_signature_map = {
            "bar_key": bar_signature,
            "foo_key": foo_new_signature
        }
        builder.add_meta_graph(["bar"], signature_def_map=bar_signature_map)

    # Write the SavedModel to disk.
    builder.save()

    # Tag "foo": the single entry corresponding to "foo_key" should exist.
    with self.test_session(graph=ops.Graph()) as sess:
        foo_graph = loader.load(sess, ["foo"], export_dir)
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, global_vars[0].eval())

        loaded_foo_map = foo_graph.signature_def
        self.assertEqual(len(loaded_foo_map), 1)
        self.assertEqual("foo", loaded_foo_map["foo_key"].method_name)

    # Tag "bar": two entries, one for "bar_key" and one for the new value of
    # "foo_key".
    with self.test_session(graph=ops.Graph()) as sess:
        bar_graph = loader.load(sess, ["bar"], export_dir)
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, global_vars[0].eval())

        loaded_bar_map = bar_graph.signature_def
        self.assertEqual(len(loaded_bar_map), 2)
        self.assertEqual("bar", loaded_bar_map["bar_key"].method_name)
        self.assertEqual("foo_new", loaded_bar_map["foo_key"].method_name)
示例13: testStripDefaultAttrsInconsistentConsumerDefaults
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
    """Check load failures when the consumer's op defaults differ.

    A SavedModel is exported with `strip_default_attrs=True`, so the
    "Complex" node carries no "T"/"Tout" attr values. The registered
    "Complex" OpDef is then altered in place and loading is expected to fail:

      1. With all attr defaults cleared, the loader raises `ValueError`.
      2. With the "Tout" default changed to complex128, the loader raises
         `InvalidArgumentError`.

    The OpDef returned by the registry is modified in place, so the original
    definition is always put back in a `finally` clause. Otherwise a failing
    assertion, or simply reaching the end of the test, would leave the
    modified defaults in place for later tests. The sessions used for
    loading are closed through `with` blocks.
    """
    if ops._USE_C_API:
        return  # TODO(skyewm): get this working

    export_dir = self._get_export_dir(
        "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled. This must remove the following
    # defaults for the "Complex" Op:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)
    with session.Session(graph=ops.Graph()) as sess:
        real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
        imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(
            sess, ["foo"], strip_default_attrs=True)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Snapshot the registered "Complex" OpDef before changing it.
    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
    original_complex_op_def = op_def_pb2.OpDef()
    original_complex_op_def.CopyFrom(complex_op_def)

    try:
        # Update the Op registry to remove defaults for all attrs
        # ("T", "Tout") from the "Complex" OpDef.
        for attr_def in complex_op_def.attr:
            attr_def.ClearField("default_value")

        # Loading the SavedModel via the loader must fail because the
        # SavedModel does not have any attr values for the "Complex" node
        # and the current op registry does not have any default values for
        # the "Complex" op.
        with session.Session(graph=ops.Graph()) as sess:
            with self.assertRaisesRegexp(
                    ValueError,
                    "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
                loader.load(sess, ["foo"], export_dir)

        # Update the Op registry to change the defaults for attr "Tout"
        # (complex64 -> complex128).
        complex_op_def.CopyFrom(original_complex_op_def)
        for attr_def in complex_op_def.attr:
            if attr_def.name == "Tout":
                attr_def.default_value.type = types_pb2.DT_COMPLEX128

        # Loading the SavedModel via the loader must set "Tout" attr_value
        # for the "Complex" node according to the latest defaults
        # (complex128). This is expected to fail the model import as there
        # is no OpKernel registered to handle attrs "T" (float32) and "Tout"
        # (complex128).
        with session.Session(graph=ops.Graph()) as sess:
            with self.assertRaisesRegexp(
                    errors.InvalidArgumentError,
                    ".*No OpKernel was registered to support Op \'Complex\' with these "
                    "attrs..*"):
                loader.load(sess, ["foo"], export_dir)
    finally:
        # Put the original defaults back, whatever happened above.
        complex_op_def.CopyFrom(original_complex_op_def)
示例14: freeze_saved_model
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
    """Loads a SavedModel and returns it as a frozen graph.

    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided.
      input_shapes: Dict of strings representing input tensor names to list
        of integers representing input shapes
        (e.g., {"foo" : [1, 16, 16, 3]}). Automatically determined when input
        shapes is None (e.g., {"foo" : None}).
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided.
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel
        to analyze. All tags in the tag set must be present.
      signature_key: Key identifying SignatureDef containing inputs and
        outputs.

    Returns:
      frozen_graph_def: Frozen GraphDef.
      in_tensors: List of input tensors for the graph.
      out_tensors: List of output tensors for the graph.

    Raises:
      ValueError:
        SavedModel doesn't contain a MetaGraphDef identified by tag_set.
        signature_key is not in the MetaGraphDef.
        assets/ directory is in the MetaGraphDef.
        input_shapes does not match the length of input_arrays.
        input_arrays or output_arrays are not valid.
    """
    # Resolve the MetaGraphDef, its SignatureDef and the signature's
    # inputs/outputs.
    meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
    signature_def = _get_signature_def(meta_graph, signature_key)
    inputs, outputs = _get_inputs_outputs(signature_def)

    # SavedModels that register an assets collection are rejected.
    if constants.ASSETS_KEY in meta_graph.collection_def:
        raise ValueError("SavedModels with assets/ directory are not supported.")

    graph = ops.Graph()
    with session.Session(graph=graph) as sess:
        loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

        # Gets input and output tensors.
        # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
        in_tensors = _get_tensors(graph, inputs, input_arrays)
        out_tensors = _get_tensors(graph, outputs, output_arrays)
        set_tensor_shapes(in_tensors, input_shapes)

        # Keep only the text before the first ":" of each signature output.
        # NOTE(review): these names come from the SignatureDef `outputs`,
        # not from `out_tensors` - confirm this is intended when
        # `output_arrays` is supplied.
        output_names = [name.split(":", 1)[0] for name in outputs]
        graph_def = graph.as_graph_def()
        frozen_graph_def = tf_graph_util.convert_variables_to_constants(
            sess, graph_def, output_names)

    return frozen_graph_def, in_tensors, out_tensors
示例15: _TestStaticOp
def _TestStaticOp(self, use_function_backup):
    """Convert a SavedModel and run both the GraphDef and SavedModel outputs.

    Does nothing when TensorRT is not enabled. Otherwise an input SavedModel
    is written and converted, and both the returned GraphDef and the written
    output SavedModel are exercised with batch sizes 1 and 2.

    Args:
      use_function_backup: Forwarded to `_ConvertGraph` and `_TestRun`.
    """
    if not is_tensorrt_enabled():
        return

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
    self._WriteInputSavedModel(input_saved_model_dir)
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        maximum_cached_engines=2,  # This is noop, added just for testing.
        use_function_backup=use_function_backup)

    def run_both_batch_sizes(sess):
        # Batch size 1: the default engine embedded in the graphdef will be
        # used. Batch size 2 exceeds the max_batch_size, so it should try to
        # fall back to the TF function.
        for batch_size, engine_expected in ((1, True), (2, False)):
            self._TestRun(
                sess,
                batch_size,
                use_function_backup=use_function_backup,
                expect_engine_is_run=engine_expected)

    # Test the output GraphDef.
    with ops.Graph().as_default():
        importer.import_graph_def(output_graph_def, name="")
        with self.session(config=self._GetConfigProto()) as sess:
            run_both_batch_sizes(sess)

    # Test the output SavedModel.
    with ops.Graph().as_default():
        with self.session(config=self._GetConfigProto()) as sess:
            loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
            run_both_batch_sizes(sess)