This article collects typical usage examples of the Python function tensorflow.python.eager.context.in_graph_mode. If you have been wondering what in_graph_mode does, how to call it, or what real uses of it look like, the curated code examples below should help.
The following presents 15 code examples of the in_graph_mode function, sorted by popularity.
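Before the examples, here is a minimal sketch of the pattern almost all of them share: branch on the execution mode, and touch graph-only features (op names, collections, Session.run) only while graph building. This sketch is not taken from the examples below; it assumes a TensorFlow 1.x build from the eager-execution preview era (roughly 1.4 through 1.7), where the internal module tensorflow.python.eager.context still exposes in_graph_mode.

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op

def describe(t):
  # Branch on the execution mode, as the examples below do.
  if context.in_graph_mode():
    # Graph building: `t` is symbolic. It has a producing op and a name,
    # and yields a concrete value only when run in a session.
    return 'graph tensor from op %s' % t.op.name
  else:
    # Eager execution: `t` already holds a concrete value.
    return 'eager tensor with value %s' % t.numpy()

x = constant_op.constant(3.0)
print(describe(x))  # Takes the graph-mode branch unless eager was enabled.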
Example 1: testAddWeight
def testAddWeight(self):
  layer = base_layers.Layer(name='my_layer')

  # Test basic variable creation.
  variable = layer.add_variable(
      'my_var', [2, 2], initializer=init_ops.zeros_initializer())
  self.assertEqual(variable.name, 'my_layer/my_var:0')
  self.assertListEqual(layer.variables, [variable])
  self.assertListEqual(layer.trainable_variables, [variable])
  self.assertListEqual(layer.non_trainable_variables, [])
  if context.in_graph_mode():
    self.assertListEqual(
        layer.variables,
        ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

  # Test non-trainable variable creation.
  # layer.add_variable should work even outside `build` and `call`.
  variable_2 = layer.add_variable(
      'non_trainable_var', [2, 2],
      initializer=init_ops.zeros_initializer(),
      trainable=False)
  self.assertListEqual(layer.variables, [variable, variable_2])
  self.assertListEqual(layer.trainable_variables, [variable])
  self.assertListEqual(layer.non_trainable_variables, [variable_2])
  if context.in_graph_mode():
    self.assertEqual(
        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

  if context.in_graph_mode():
    # regularizers only supported in GRAPH mode.
    regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
    variable = layer.add_variable(
        'reg_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        regularizer=regularizer)
    self.assertEqual(len(layer.losses), 1)
Example 2: testAddVariable
def testAddVariable(self):
  obj = NonLayerCheckpointable()
  with self.assertRaisesRegexp(ValueError, "do not specify shape"):
    checkpointable_utils.add_variable(
        obj, name="shape_specified_twice", shape=[], initializer=1)
  constant_initializer = checkpointable_utils.add_variable(
      obj, name="constant_initializer", initializer=1)
  with variable_scope.variable_scope("some_variable_scope"):
    ones_initializer = checkpointable_utils.add_variable(
        obj,
        name="ones_initializer",
        shape=[2],
        initializer=init_ops.ones_initializer(dtype=dtypes.float32))
  bare_initializer = checkpointable_utils.add_variable(
      obj,
      name="bare_initializer",
      shape=[2, 2],
      dtype=dtypes.float64,
      initializer=init_ops.zeros_initializer)

  # Even in graph mode, there are no naming conflicts between objects, only
  # naming conflicts within an object.
  other_duplicate = resource_variable_ops.ResourceVariable(
      name="duplicate", initial_value=1.)
  duplicate = checkpointable_utils.add_variable(
      obj, name="duplicate", shape=[])
  with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
    checkpointable_utils.add_variable(obj, name="duplicate", shape=[])

  if context.in_graph_mode():
    self.evaluate(variables.global_variables_initializer())
  self.assertEqual("constant_initializer:0", constant_initializer.name)
  self.assertEqual(1, self.evaluate(constant_initializer))
  self.assertEqual("some_variable_scope/ones_initializer:0",
                   ones_initializer.name)
  self.assertAllEqual([1, 1], self.evaluate(ones_initializer))
  self.assertAllEqual([[0., 0.],
                       [0., 0.]], self.evaluate(bare_initializer))
  self.assertEqual("a_variable:0", obj.a_variable.name)
  self.assertEqual("duplicate:0", other_duplicate.name)
  if context.in_graph_mode():
    # The .name attribute may be globally influenced, but the checkpoint name
    # won't be (tested below).
    self.assertEqual("duplicate_1:0", duplicate.name)
  else:
    # When executing eagerly, there's no uniquification of variable names. The
    # checkpoint name will be the same.
    self.assertEqual("duplicate:0", duplicate.name)

  named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
  expected_checkpoint_names = (
      "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
      "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
      "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE",
      "duplicate/.ATTRIBUTES/VARIABLE_VALUE",
      "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
  )
  six.assertCountEqual(
      self, expected_checkpoint_names, named_variables.keys())
Example 3: testDeferredSlotRestoration
def testDeferredSlotRestoration(self):
  checkpoint_directory = self.get_temp_dir()

  root = checkpointable.Checkpointable()
  root.var = checkpointable_utils.add_variable(
      root, name="var", initializer=0.)
  optimizer = CheckpointableAdam(0.1)
  if context.in_graph_mode():
    train_op = optimizer.minimize(root.var)
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(train_op)
  else:
    optimizer.minimize(root.var.read_value)
  self.evaluate(state_ops.assign(root.var, 12.))
  no_slots_path = checkpointable_utils.Saver(root).save(
      os.path.join(checkpoint_directory, "no_slots"))
  root.optimizer = optimizer
  self.evaluate(state_ops.assign(root.var, 13.))
  self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                 14.))
  slots_path = checkpointable_utils.Saver(root).save(
      os.path.join(checkpoint_directory, "with_slots"))
  new_root = checkpointable.Checkpointable()
  # Load the slot-containing checkpoint (deferred), then immediately overwrite
  # the non-slot variable (also deferred).
  slot_status = checkpointable_utils.Saver(new_root).restore(slots_path)
  no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path)
  with self.assertRaises(AssertionError):
    no_slot_status.assert_consumed()
  new_root.var = checkpointable_utils.add_variable(
      new_root, name="var", shape=[])
  no_slot_status.assert_consumed()
  no_slot_status.run_restore_ops()
  self.assertEqual(12., self.evaluate(new_root.var))
  new_root.optimizer = CheckpointableAdam(0.1)
  with self.assertRaisesRegexp(AssertionError, "beta1_power"):
    slot_status.assert_consumed()
  self.assertEqual(12., self.evaluate(new_root.var))
  if context.in_eager_mode():
    # Slot variables are only created with restoring initializers when
    # executing eagerly.
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
  else:
    self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                  None)
  if context.in_graph_mode():
    train_op = new_root.optimizer.minimize(new_root.var)
    # The slot variable now exists; restore() didn't create it, but we should
    # now have a restore op for it.
    slot_status.run_restore_ops()
    self.assertEqual(14., self.evaluate(
        new_root.optimizer.get_slot(name="m", var=new_root.var)))
    self.evaluate(train_op)
  else:
    new_root.optimizer.minimize(new_root.var.read_value)
  slot_status.assert_consumed()
Example 4: testActivation
def testActivation(self):
  dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
  inputs = random_ops.random_uniform((5, 3), seed=1)
  outputs = dense(inputs)
  if context.in_graph_mode():
    self.assertEqual(outputs.op.name, 'dense1/Relu')

  dense = core_layers.Dense(2, name='dense2')
  inputs = random_ops.random_uniform((5, 3), seed=1)
  outputs = dense(inputs)
  if context.in_graph_mode():
    self.assertEqual(outputs.op.name, 'dense2/BiasAdd')
Example 5: test_variable_reuse_exception_nested
def test_variable_reuse_exception_nested(self):
  with test_util.IsolateTest(), session.Session():
    first_container_variable = resource_variable_ops.ResourceVariable(
        name="first_container_variable",
        initial_value=1)
    if context.in_graph_mode():
      self.evaluate([variables.global_variables_initializer()])
    with test_util.IsolateTest(), session.Session():
      if context.in_graph_mode():
        with self.assertRaises(RuntimeError):
          self.evaluate(first_container_variable.read_value())
      else:
        with self.assertRaises(ValueError):
          first_container_variable.read_value()
Example 6: test_name_scopes_for_variable_scopes
def test_name_scopes_for_variable_scopes(self):
  # Test that name scopes are not unnecessarily uniquified (but are
  # still uniquified when necessary).
  def linear_module(x, output_size):
    w = variable_scope.get_variable(
        "w", shape=[x.get_shape()[1], output_size],
        initializer=init_ops.zeros_initializer())
    b = variable_scope.get_variable(
        "b", shape=[output_size],
        initializer=init_ops.zeros_initializer())
    return (math_ops.matmul(x, w) + b), w

  def make_linear_module(output_size, name):
    return template.make_template(
        name,
        linear_module,
        output_size=output_size,
        create_scope_now_=True)

  inputs = array_ops.ones((3, 4))

  linear1 = make_linear_module(output_size=2, name="foo")
  outputs_a, w1 = linear1(inputs)
  outputs_b, _ = linear1(inputs)
  self.assertEquals("foo", linear1.variable_scope.name)
  self.assertEquals("foo/w:0", w1.name)
  if context.in_graph_mode():
    self.assertEquals("foo/add:0", outputs_a.name,
                      "First application of template should get "
                      "same name scope as variables.")
    self.assertEquals("foo_1/add:0", outputs_b.name,
                      "Second application of template should get "
                      "a freshly uniquified name scope.")

  linear2 = make_linear_module(output_size=2, name="foo")
  outputs_c, w2 = linear2(inputs)
  outputs_d, _ = linear2(inputs)
  self.assertEquals("foo_1", linear2.variable_scope.name,
                    "New template gets a freshly uniquified variable scope "
                    "because 'foo' is already taken.")
  self.assertEquals("foo_1/w:0", w2.name)
  if context.in_graph_mode():
    self.assertEquals("foo_1_1/add:0", outputs_c.name,
                      "First application of template would get "
                      "same name scope as variables, but 'foo_1' is already "
                      "a name scope.")
    self.assertEquals("foo_1_2/add:0", outputs_d.name,
                      "Second application of template should also get "
                      "a freshly uniquified name scope.")
Example 7: __call__
def __call__(self, *args):
  """Executes the passed function in eager mode."""
  tensor_inputs = [
      x for x in nest.flatten(args)
      if isinstance(x, ops.Tensor)
  ]
  if tape.should_record(tensor_inputs) or tape.should_record(
      self._extra_inputs):
    if not self._has_backprop:
      self._compute_backprop()
    return self._backprop_call(tensor_inputs)

  if context.in_graph_mode():
    g = ops.get_default_graph()
    if self._fdef.name not in g._functions:  # pylint: disable=protected-access
      g._add_function(self._fdef)  # pylint: disable=protected-access
    signature = self._fdef.definition.signature
    args = list(tensor_inputs) + self._extra_inputs
    op = g.create_op(
        signature.name, [ops.convert_to_tensor(x) for x in args],
        [dtypes.DType(x.type) for x in signature.output_arg],
        op_def=signature,
        name="FunctionCall",
        compute_shapes=False)
    result = op.outputs
    for i, s in enumerate(self._output_shapes):
      result[i].set_shape(s)
  else:
    result = execute.execute(
        str(self._func_name),
        num_outputs=self._num_outputs,
        inputs=tensor_inputs + self._extra_inputs)

  return self._build_call_outputs(self._returns, result)
Example 8: test_no_sharing
def test_no_sharing(self):
  with test_util.IsolateTest(), session.Session():
    first_container_variable = resource_variable_ops.ResourceVariable(
        name="same_name",
        initial_value=1)
    if context.in_graph_mode():
      self.evaluate([variables.global_variables_initializer()])
    with test_util.IsolateTest(), session.Session():
      second_container_variable = resource_variable_ops.ResourceVariable(
          name="same_name",
          initial_value=2)
      if context.in_graph_mode():
        self.evaluate([variables.global_variables_initializer()])
      self.assertEqual(
          2, self.evaluate(second_container_variable.read_value()))
    self.assertEqual(1, self.evaluate(first_container_variable.read_value()))
Example 9: _init_from_proto
def _init_from_proto(self, variable_def, import_scope=None):
  """Initializes from `VariableDef` proto."""
  # Note that init_from_proto is currently not supported in Eager mode.
  assert context.in_graph_mode()
  self._in_graph_mode = True
  assert isinstance(variable_def, variable_pb2.VariableDef)
  if not variable_def.is_resource:
    raise ValueError("Trying to restore Variable as ResourceVariable.")

  # Create from variable_def.
  g = ops.get_default_graph()
  self._handle = g.as_graph_element(
      ops.prepend_name_scope(
          variable_def.variable_name, import_scope=import_scope))
  self._handle_device = self._handle.device
  self._handle_name = self._handle.name
  self._initializer_op = g.as_graph_element(
      ops.prepend_name_scope(
          variable_def.initializer_name, import_scope=import_scope))
  if variable_def.snapshot_name:
    self._cached_value = g.as_graph_element(
        ops.prepend_name_scope(
            variable_def.snapshot_name, import_scope=import_scope))
  else:
    self._cached_value = None
  if variable_def.HasField("save_slice_info_def"):
    self._save_slice_info = variables.Variable.SaveSliceInfo(
        save_slice_info_def=variable_def.save_slice_info_def)
  else:
    self._save_slice_info = None
  self._caching_device = None
  self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
  self._graph_element = self.value()
  self._constraint = None
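The initializer above is only reached when a ResourceVariable is rebuilt from a serialized VariableDef. A hedged round-trip sketch (the variable name 'v' is an assumption; graph mode only, per the assert above):

# Serialize a ResourceVariable to a VariableDef proto and rebuild it;
# passing variable_def routes construction through _init_from_proto.
v = resource_variable_ops.ResourceVariable(1.0, name='v')
proto = v.to_proto()  # VariableDef with is_resource=True
v_restored = resource_variable_ops.ResourceVariable(variable_def=proto)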
Example 10: new_func
def new_func(*args, **kwargs):
  """Deprecation wrapper."""
  # TODO(apassos) figure out a way to have reasonable performance with
  # deprecation warnings and eager mode.
  if context.in_graph_mode() and _PRINT_DEPRECATION_WARNINGS:
    invalid_args = []
    named_args = tf_inspect.getcallargs(func, *args, **kwargs)
    for arg_name, spec in iter(deprecated_positions.items()):
      if (spec.position < len(args) and
          not (spec.has_ok_value and
               _same_value(named_args[arg_name], spec.ok_value))):
        invalid_args.append(arg_name)
    if is_varargs_deprecated and len(args) > len(arg_spec.args):
      invalid_args.append(arg_spec.varargs)
    if is_kwargs_deprecated and kwargs:
      invalid_args.append(arg_spec.keywords)
    for arg_name in deprecated_arg_names:
      if (arg_name in kwargs and
          not (deprecated_positions[arg_name].has_ok_value and
               _same_value(named_args[arg_name],
                           deprecated_positions[arg_name].ok_value))):
        invalid_args.append(arg_name)
    for arg_name in invalid_args:
      if (func, arg_name) not in _PRINTED_WARNING:
        if warn_once:
          _PRINTED_WARNING[(func, arg_name)] = True
        logging.warning(
            'From %s: calling %s (from %s) with %s is deprecated and will '
            'be removed %s.\nInstructions for updating:\n%s',
            _call_location(), decorator_utils.get_qualified_name(func),
            func.__module__, arg_name,
            'in a future version' if date is None else ('after %s' % date),
            instructions)
  return func(*args, **kwargs)
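For context, new_func above is the wrapper returned by TensorFlow's deprecation.deprecated_args decorator. A minimal usage sketch (reduce_demo and its arguments are hypothetical, not from the source):

from tensorflow.python.util import deprecation

@deprecation.deprecated_args(None, 'Use the `axis` argument instead.', 'dim')
def reduce_demo(x, axis=None, dim=None):
  # Calling reduce_demo(x, dim=0) while graph building logs a deprecation
  # warning through the wrapper above; eager calls skip the check, since
  # the wrapper only inspects arguments when context.in_graph_mode() holds.
  return axis if axis is not None else dim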
Example 11: testRandomSeed
def testRandomSeed(self):
  test_cases = [
      # Each test case is a tuple with input to get_seed:
      # (input_graph_seed, input_op_seed)
      # and output from get_seed:
      # (output_graph_seed, output_op_seed)
      ((None, None), (None, None)),
      ((None, 1), (random_seed.DEFAULT_GRAPH_SEED, 1)),
      ((1, 1), (1, 1)),
      ((0, 0), (0, 2**31 - 1)),  # Avoid nondeterministic (0, 0) output
      ((2**31 - 1, 0), (0, 2**31 - 1)),  # Don't wrap to (0, 0) either
      ((0, 2**31 - 1), (0, 2**31 - 1)),  # Wrapping for the other argument
  ]
  if context.in_graph_mode():
    # 0 will be the default_graph._lastid.
    test_cases.append(((1, None), (1, 0)))
  else:
    # operation seed is random number generated based on global seed.
    # it's not tested due to possibility of platform or version difference.
    pass
  for tc in test_cases:
    tinput, toutput = tc[0], tc[1]
    random_seed.set_random_seed(tinput[0])
    g_seed, op_seed = random_seed.get_seed(tinput[1])
    msg = 'test_case = {0}, got {1}, want {2}'.format(tinput,
                                                      (g_seed, op_seed),
                                                      toutput)
    self.assertEqual((g_seed, op_seed), toutput, msg=msg)
    random_seed.set_random_seed(None)
Example 12: __init__
def __init__(self, handle, dtype, handle_device,  # pylint: disable=super-init-not-called
             shape, in_graph_mode, deleter, parent_op):
  # We do not call super init on purpose.
  self._trainable = False
  self._save_slice_info = None
  self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  self._in_graph_mode = in_graph_mode
  self._handle = handle
  self._handle_device = handle_device
  self._shape = shape
  self._initial_value = None
  if isinstance(self._handle, ops.EagerTensor):
    self._handle_name = ""
  else:
    self._handle_name = self._handle.name
  self._dtype = dtype
  self._constraint = None
  self._cached_value = None
  self._is_initialized_op = None
  self._initializer_op = None
  self._parent_op = parent_op
  if context.in_graph_mode():
    self._graph_element = self.read_value()
  else:
    self._graph_element = None
  self._handle_deleter = deleter
Example 13: testMaskingSingleInput
def testMaskingSingleInput(self):

  class MaskedLayer(base_layers.Layer):

    def call(self, inputs, mask=None):
      if mask is not None:
        return inputs * mask
      return inputs

    def compute_mask(self, inputs, mask=None):
      return array_ops.ones_like(inputs)

  if context.in_graph_mode():
    x = base_layers.Input(shape=(32,))
    y = MaskedLayer()(x)  # pylint: disable=not-callable
    network = base_layers.Network(x, y)

    # test callability on Input
    x_2 = base_layers.Input(shape=(32,))
    y_2 = network(x_2)
    self.assertEqual(y_2.get_shape().as_list(), [None, 32])

    # test callability on regular tensor
    x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
    y_2 = network(x_2)
    self.assertEqual(y_2.get_shape().as_list(), [None, 32])
  else:
    a = constant_op.constant([2] * 32)
    mask = constant_op.constant([0, 1] * 16)
    a._keras_mask = mask
    b = MaskedLayer().apply(a)
    self.assertTrue(hasattr(b, '_keras_mask'))
    self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)),
                        self.evaluate(getattr(b, '_keras_mask')))
    self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
Example 14: _delay_checks
@contextlib.contextmanager
def _delay_checks(self):
  """Context manager for combining checks depending on tensor evaluations.

  Each call to Session.run has some overhead, and this overhead can easily
  account for the majority of the time spent in tests that call Session.run
  (or Tensor.eval) many times.

  This context manager provides a mechanism for registering callback functions
  and associated tensors. When the context is exited, all of the tensors
  associated with all of the registrations are evaluated with a single call to
  Session.run, and then each registered callback function is called with the
  values of its associated tensors.

  Yields:
    A function `add_check(check, *args, **kwargs)` where `check` is the
    callback function to be invoked, and `*args` and `**kwargs` specify the
    associated Tensors. When in EAGER mode, check is executed in add_check,
    otherwise, it's delayed after the context.
  """
  checks = []

  def add_check(check, *args, **kwargs):
    if context.in_eager_mode():
      args_val, kwargs_val = self.evaluate([args, kwargs])
      check(*args_val, **kwargs_val)
    else:
      checks.append((check, args, kwargs))

  yield add_check
  if context.in_graph_mode():
    all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks])
    for (check, _, _), (args, kwargs) in zip(checks, all_values):
      check(*args, **kwargs)
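A hypothetical usage sketch of the context manager above (the tensors and the expected shape are assumptions; self is the test case that defines _delay_checks):

with self._delay_checks() as add_check:
  for t in (tensor_a, tensor_b):
    # In eager mode each check runs immediately inside add_check; in graph
    # mode both tensors are fetched with a single Session.run on exit.
    add_check(lambda value: self.assertEqual((2, 2), value.shape), t)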
Example 15: __init__
def __init__(self, root_checkpointable):
  """Configure saving.

  Args:
    root_checkpointable: The root of the object graph to save/restore. This
      object and all of its dependencies are saved in the checkpoint. When
      restoring, objects are matched and restored starting from this root.
  """
  # Allow passing in a weak reference to avoid reference cycles when
  # `Checkpointable` objects save themselves.
  self._root_checkpointable_ref = root_checkpointable
  if context.in_graph_mode():
    self._file_prefix_placeholder = constant_op.constant("model")
  else:
    self._file_prefix_placeholder = None

  # Op caching for save
  self._object_graph_feed_tensor = None
  self._last_save_object_graph = None
  self._last_save_saver = None

  # Op caching for restore
  self._object_graph_restore_tensor = None
  self._last_restore_object_graph = None
  self._last_restore_checkpoint = None
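A closing note: in_graph_mode was always an internal API, and it was removed in later TensorFlow releases. A roughly equivalent check in newer versions (a sketch, assuming a release that provides the public tf.executing_eagerly, available from roughly TF 1.7 onward) simply inverts the eager test:

import tensorflow as tf

def in_graph_mode():
  # Graph mode is simply "not executing eagerly" in newer TensorFlow.
  return not tf.executing_eagerly()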