本文整理汇总了Python中tensorflow.python.eager.def_function.function函数的典型用法代码示例。如果您正苦于以下问题：Python function函数的具体用法？Python function怎么用？Python function使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了function函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_nested_functions
def test_nested_functions(self):
    """Call a function that wraps another function after `self.cycle`.

    Only the outer function is attached to the tracked root; it calls the
    inner one from its body. Nothing is asserted about the returned value:
    the test passes as long as invoking the re-imported `g` does not raise.
    """
    inner = def_function.function(
        lambda x: x*2.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    outer = def_function.function(
        lambda x: inner(x) + 1.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    root = tracking.AutoCheckpointable()
    root.g = outer

    restored = self.cycle(root)
    restored.g(constant_op.constant([1.0]))
示例2: test_nested_functions
def test_nested_functions(self, cycles):
    """Call a function that wraps another function after `self.cycle`.

    Args:
      cycles: accepted from the caller but not forwarded; `self.cycle` is
        invoked with a hard-coded `1` instead (see the TODO below).
    """
    inner = def_function.function(
        lambda x: x*2.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    outer = def_function.function(
        lambda x: inner(x) + 1.0,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])

    root = tracking.AutoCheckpointable()
    root.g = outer

    # TODO(vbardiovsky): Enable this test. For this to work, we must ensure that
    # concrete_function._inference_function._graph._functions contains all
    # functions that were on the graph before saving.
    restored = self.cycle(root, 1)
    # No value is asserted; the call succeeding is the whole check.
    restored.g(constant_op.constant([1.0]))
示例3: test_callable
def test_callable(self):
    """Check which objects are callable after going through `self.cycle`.

    Covers a class-level `__call__`, an instance-level `__call__` attribute,
    and the root object, which has no `__call__` at all.
    """

    class M1(tracking.AutoCheckpointable):

        @def_function.function(
            input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
        def __call__(self, x):
            return x

    root = tracking.AutoCheckpointable()
    root.m1 = M1()
    root.m2 = tracking.AutoCheckpointable()
    # Build the decorator first, then apply it to the lambda.
    with_signature = def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    root.m2.__call__ = with_signature(lambda x: x*3.0)

    restored = self.cycle(root)
    one = constant_op.constant(1.0)

    self.assertTrue(callable(restored.m1))
    self.assertAllEqual(root.m1(one), restored.m1(one))

    # Note: `root.m2` was not callable since `__call__` attribute was set
    # into the instance and not on the class. But after a serialization cycle
    # that starts to work.
    self.assertTrue(callable(restored.m2))
    self.assertAllEqual(root.m2.__call__(one), restored.m2(one))

    # Verify that user objects without `__call__` attribute are not callable.
    self.assertFalse(callable(restored))
示例4: test_functools_partial_keywords
def test_functools_partial_keywords(self):
    """Wrap a partial whose arguments are all pre-bound by keyword.

    With both `x` and `y` bound to zero tensors of shape [1], the wrapped
    function is called with no arguments and must return [0.0].
    """

    def f(x, y):
        return x + y

    fully_bound = functools.partial(
        f, x=array_ops.zeros([1]), y=array_ops.zeros([1]))
    wrapped = def_function.function(fully_bound)

    self.assertAllEqual(wrapped(), [0.0])
示例5: test_functools_partial_new_default
def test_functools_partial_new_default(self):
    """Wrap a partial that replaces one default; callers may still override.

    `y` is re-defaulted from 7 to 6, so a bare call yields 3 + 6 and an
    explicit `y=8` yields 3 + 8.
    """

    def f(x=3, y=7):
        return x + y

    with_new_default = functools.partial(f, y=6)
    wrapped = def_function.function(with_new_default)

    default_result = wrapped().numpy()
    self.assertEqual(default_result, 9)

    overridden_result = wrapped(y=8).numpy()
    self.assertEqual(overridden_result, 11)
示例6: testConstSavedModel
def testConstSavedModel(self):
    """Freeze a variable-free SavedModel and check its functions get inlined.

    Before conversion the loaded signature's graph must carry a non-empty
    function library; after `convert_variables_to_constants_v2` the library
    must be empty and the graph must still compute the original result.
    """
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    signature_fn = root.f.get_concrete_function(input_data)

    # Save to a temporary directory, then load back to get the signature.
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, signature_fn)
    loaded = load(save_dir)
    input_func = loaded.signatures["serving_default"]

    graph_before = input_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_before))
    self.assertTrue(graph_before.library.function)

    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    graph_after = frozen_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_after))
    self.assertFalse(graph_after.library.function)

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(
        graph_after, input_func, [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
示例7: testConstructConcreteFunction
def testConstructConcreteFunction(self):
    """Freeze a function rebuilt by `_construct_concrete_function`.

    The rebuilt function's graph must report two variables via
    `_getNumVariables` before conversion and none afterwards. The frozen
    graph must also have no StatefulPartitionedCall op (per the helper) and
    must evaluate to the same value as `root.f`.
    """
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)

    concrete = root.f.get_concrete_function(input_data)
    input_func = convert_to_constants._construct_concrete_function(
        concrete, concrete.graph.as_graph_def())

    # Test if model has enough metadata to be frozen afterwards.
    graph_before = input_func.graph.as_graph_def()
    self.assertEqual(2, self._getNumVariables(graph_before))

    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    graph_after = frozen_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_after))
    self.assertFalse(self._hasStatefulPartitionedCallOp(graph_after))

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(
        graph_after, input_func, [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
示例8: testVariableSavedModel
def testVariableSavedModel(self):
    """Freeze a SavedModel that uses Variables after a save/load round trip.

    The loaded signature's graph must contain a StatefulPartitionedCall op
    (per `_hasStatefulPartitionedCallOp`). After conversion the graph must
    have no variables, no such op, and the same output as `root.f`.
    """
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    signature_fn = root.f.get_concrete_function(input_data)

    # Save to a temporary directory, then load back to get the signature.
    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, save_dir, signature_fn)
    loaded = load(save_dir)
    input_func = loaded.signatures["serving_default"]

    graph_before = input_func.graph.as_graph_def()
    self.assertTrue(self._hasStatefulPartitionedCallOp(graph_before))

    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        input_func)
    graph_after = frozen_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_after))
    self.assertFalse(self._hasStatefulPartitionedCallOp(graph_after))

    # Check value.
    expected_value = root.f(input_data)
    actual_value = self._evaluateGraphDef(
        graph_after, input_func, [input_data.numpy()])
    self.assertEqual(expected_value.numpy(), actual_value)
示例9: add_metric_step
def add_metric_step(defun):
    """Run one hand-written training step and return the metric results.

    Args:
      defun: if truthy, the step is wrapped in `def_function.function` before
        it is called; otherwise it runs as a plain Python function.

    Returns:
      The list of `m.result()` values for every entry in `model.metrics`,
      after asserting that the first two are close to 1.5.
    """
    optimizer = keras.optimizer_v2.rmsprop.RMSprop()
    layers = [
        LayerWithMetrics(),
        keras.layers.Dense(1, kernel_initializer='zeros', activation='softmax')
    ]
    model = testing_utils.get_model_from_layers(layers, input_shape=(10,))

    def train_step(x, y):
        # The model is invoked twice inside the tape, on `x` and on `2 * x`.
        with backprop.GradientTape() as tape:
            pred_plain = model(x)
            pred_doubled = model(2 * x)
            combined = pred_plain + pred_doubled
            loss = keras.losses.mean_squared_error(y, combined)
        weights = model.trainable_weights
        gradients = tape.gradient(loss, weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        assert len(model.metrics) == 2
        return [m.result() for m in model.metrics]

    step = def_function.function(train_step) if defun else train_step

    x = array_ops.ones((10, 10))
    y = array_ops.zeros((10, 1))
    metrics = step(x, y)
    assert np.allclose(metrics[0], 1.5)
    assert np.allclose(metrics[1], 1.5)
    return metrics
示例10: testDecorate
def testDecorate(self):
    """Apply `_decorate` to a wrapped function before its first call.

    The wrapped lambda returns 1 and the decorator adds 1 on top, so the
    first call is expected to produce 2.
    """
    wrapped = def_function.function(lambda: 1)

    def add_one(f):
        return lambda: 1 + f()

    wrapped._decorate(add_one)

    result = wrapped().numpy()
    self.assertEqual(result, 2)
示例11: test_functools_partial_single_positional
def test_functools_partial_single_positional(self):
    """Wrap a partial with one positional argument pre-bound.

    `x` is bound to a constant 1, so calling the wrapped function with 5
    supplies `y` and must yield 6.
    """

    def f(x, y):
        return x + y

    first_bound = functools.partial(f, constant_op.constant(1))
    wrapped = def_function.function(first_bound)

    self.assertAllEqual(wrapped(5), 6)
示例12: test_structured_output
def test_structured_output(self):
    """Check a nested return structure before and after `self.cycle`.

    The wrapped function returns a list holding a namedtuple, a tensor and a
    dict. Field values, namedtuple field order and the dict entry are
    checked on the original object and again on the re-imported one.
    """
    # Use fields with non-alphabetical order
    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])

    def func(input1, input2):
        named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
        return [named_tuple, input2, {"x": 0.5}]

    root = tracking.AutoCheckpointable()
    root.f = def_function.function(func)

    # Original object, inputs 2 and 3.
    result = root.f(constant_op.constant(2), constant_op.constant(3))
    pair = result[0]
    self.assertEqual(5, pair.a.numpy())
    self.assertEqual(6, pair.b.numpy())
    field_order = list(result[0]._asdict().keys())
    self.assertEqual(["b", "a"], field_order)
    self.assertEqual(3, result[1].numpy())
    self.assertEqual(0.5, result[2]["x"].numpy())

    # Re-imported object, inputs 2 and 5.
    imported = self.cycle(root)
    result = imported.f(constant_op.constant(2), constant_op.constant(5))
    pair = result[0]
    self.assertEqual(7, pair.a.numpy())
    self.assertEqual(10, pair.b.numpy())
    field_order = list(result[0]._asdict().keys())
    self.assertEqual(["b", "a"], field_order)
    self.assertEqual(5, result[1].numpy())
    self.assertEqual(0.5, result[2]["x"].numpy())
示例13: test_table
def test_table(self):
    """Save an object holding a text-file-backed hash table, then reload it.

    Lookups are checked three times: directly on the live object, on the
    saved model after the vocab file has been deleted, and on the saved
    model after its directory has been renamed.
    """
    initializer = lookup_ops.TextFileInitializer(
        self._vocab_path,
        key_dtype=dtypes.string,
        key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
        value_dtype=dtypes.int64,
        value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
    table = lookup_ops.HashTable(initializer, default_value=-1)
    root = util.Checkpoint(table=table)
    root.table_user = def_function.function(
        root.table.lookup,
        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])

    gamma_id = self.evaluate(root.table_user(constant_op.constant("gamma")))
    self.assertEqual(2, gamma_id)

    save_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save.save(root, save_dir)
    # The original vocab file is removed before inference -- presumably so
    # the checks below cannot depend on it (verify against the test setup).
    file_io.delete_file(self._vocab_path)
    inferred = _import_and_infer(save_dir, {"keys": ["gamma", "alpha"]})
    self.assertAllClose({"output_0": [2, 0]}, inferred)

    second_dir = os.path.join(self.get_temp_dir(), "second_dir")
    # Asset paths should track the location the SavedModel is loaded from.
    file_io.rename(save_dir, second_dir)
    inferred = _import_and_infer(second_dir, {"keys": ["gamma", "beta"]})
    self.assertAllClose({"output_0": [2, 1]}, inferred)
示例14: _apply_all_reduce
def _apply_all_reduce(reduction, tensors):
    """Helper function for all_* functions.

    Args:
      reduction: forwarded unchanged as the `reduction` argument of each
        `nccl_all_reduce` op.
      tensors: non-empty collection of tensors; one op is created per tensor,
        placed on that tensor's device.

    Returns:
      A list with one `nccl_all_reduce` result per input tensor, in order.

    Raises:
      ValueError: if `tensors` is empty.
    """
    if not tensors:
        raise ValueError('Must pass >0 tensors to all reduce operations')

    shared_name = _get_shared_name()

    def _all_reduce():
        """Call nccl allreduce."""
        outputs = []
        for tensor in tensors:
            _check_device(tensor)
            with ops.device(tensor.device):
                reduced = gen_nccl_ops.nccl_all_reduce(
                    input=tensor,
                    reduction=reduction,
                    num_devices=len(tensors),
                    shared_name=shared_name)
                outputs.append(reduced)
        return outputs

    if not context.executing_eagerly():
        return _all_reduce()

    # Nccl ops will block unless they are executed concurrently such as in a
    # graph or a defun.
    return def_function.function(_all_reduce)()
示例15: test_structured_inputs
def test_structured_inputs(self):
    """Check nested-input calls before and after `self.cycle`.

    The function is called once with `input1` before the cycle. Afterwards,
    `input1` and `input3` must still work on the imported object, while
    `input2` must be rejected with an AssertionError.
    """

    def func(x, training=True):
        # x is a nested structure, we care about one particular tensor.
        _, (a, b) = x
        if not training:
            return 7
        return 2 * a["a"] + b

    root = tracking.AutoCheckpointable()
    root.f = def_function.function(func)

    ten = constant_op.constant(10)
    eleven = constant_op.constant(11)

    input1 = [6, ({"a": ten}, eleven)]
    input2 = [7, ({"a": ten}, eleven)]  # Not compatible with input1 signature.
    input3 = [6, ({"a": eleven}, ten)]  # Compatible with input1 signature.

    # Note: by only calling f(input1) before serialization, only inputs with
    # matching signature will be valid on the loaded model.
    self.assertEqual(31, root.f(input1).numpy())

    imported = self.cycle(root)

    with self.assertRaisesRegexp(AssertionError,
                                 "Could not find matching function to call.*"):
        imported.f(input2)

    self.assertEqual(31, imported.f(input1).numpy())
    self.assertEqual(32, imported.f(input3).numpy())