This article collects typical usage examples of the stop_recording function from the Python module tensorflow.python.eager.tape. If you are unsure what exactly stop_recording does, how to use it, or where to find real-world uses of it, the curated code samples below should help.
A total of 9 code examples of stop_recording are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
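Before the individual examples, here is a minimal sketch of what stop_recording does (my own illustration rather than one of the collected examples; it assumes eager execution and uses the internal module tensorflow.python.eager.tape, whose exact location and contents vary between TensorFlow releases). Operations executed inside the context manager are not recorded on any active gradient tape, so no gradient flows through them.

import tensorflow as tf
from tensorflow.python.eager import tape  # internal API; subject to change

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x                    # recorded on the tape
  with tape.stop_recording():  # nothing in this block is recorded
    z = x * x * x              # treated as a constant by the tape
  loss = y + z
print(g.gradient(loss, x))     # 2 * 3.0 = 6.0; z contributes nothing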
Example 1: _create_variable
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
# Figure out what collections this variable should be added to.
# We'll add the MirroredVariable to those collections instead.
collections = kwargs.pop("collections", None)
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = []
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
tower_local = kwargs.pop("tower_local_reduce_method", None)
if tower_local is not None:
kwargs["trainable"] = False
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
index = {}
for i, d in enumerate(devices):
with ops.device(d):
if i > 0:
# Give replicas meaningful distinct names:
var0name = index[devices[0]].name.split(":")[0]
kwargs["name"] = "%s/replica_%d" % (var0name, i)
# Initialize replicas with the same value:
if context.executing_eagerly():
initial_value = index[devices[0]].value()
else:
initial_value = index[devices[0]].initial_value
kwargs["initial_value"] = array_ops.identity(initial_value)
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
if tower_local is None:
result = values.MirroredVariable(index, index[devices[0]])
else:
result = values.TowerLocalVariable(
index, index[devices[0]], tower_local)
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
# to the TRAINABLE_VARIABLES collection, so we manually remove
# them and replace with the MirroredVariable. We can't set
# "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables.
if kwargs.get("trainable", True):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
return result
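The stop_recording block in Example 1 keeps variable creation, including the reads of index[devices[0]] used to copy the initial value to the other replicas, off any gradient tape that happens to be active. A small standalone illustration of the same idea (my own sketch under assumed eager-mode behavior, not part of the example above):

import tensorflow as tf
from tensorflow.python.eager import tape  # internal API

with tf.GradientTape() as g:
  with tape.stop_recording():
    # Variable creation and any reads it performs leave no trace on the tape.
    v = tf.Variable(tf.ones([2, 2]))
  y = tf.reduce_sum(v * 3.0)  # this later read *is* recorded
print(g.gradient(y, v))       # 3.0 everywhere: gradients still flow through later uses of v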
Example 2: _create_tpu_mirrored_variable
def _create_tpu_mirrored_variable( # pylint: disable=missing-docstring
strategy, device_map, logical_device, real_mirrored_creator,
*args, **kwargs):
# Figure out what collections this variable should be added to.
# We'll add the TPUMirroredVariable to those collections instead.
var_collections = kwargs.pop("collections", None)
if var_collections is None:
var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = []
# TODO(jhseu): Should we have different behavior for different
# synchronization settings?
# Get aggregation value
# TODO(jhseu): Support aggregation in a replica context.
aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
if aggregation not in [
vs.VariableAggregation.NONE,
vs.VariableAggregation.SUM,
vs.VariableAggregation.MEAN,
vs.VariableAggregation.ONLY_FIRST_REPLICA,
]:
raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
.format(aggregation, kwargs["name"]))
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
devices = device_map.logical_to_actual_devices(logical_device)
value_list = real_mirrored_creator(devices, *args, **kwargs)
result = values.TPUMirroredVariable(
strategy, device_map, value_list, aggregation,
logical_device=logical_device)
if not (context.executing_eagerly() or ops.inside_function()):
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
# to the TRAINABLE_VARIABLES collection, so we manually remove
# them and replace with the MirroredVariable. We can't set
# "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables.
if kwargs.get("trainable", True):
var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in value_list:
l.remove(v)
g.add_to_collections(var_collections, result)
return result
Example 3: compute_gradients
def compute_gradients(model, images, labels, num_replicas=1):
with tf.GradientTape() as grad_tape:
logits = model(images, training=True)
loss = tf.losses.softmax_cross_entropy(
logits=logits, onehot_labels=labels)
tf.contrib.summary.scalar(name='loss', tensor=loss)
if num_replicas != 1:
loss /= num_replicas
# TODO(b/110991947): We can mistakenly trace the gradient call in
# multi-threaded environment. Explicitly disable recording until
# this is fixed.
with tape.stop_recording():
grads = grad_tape.gradient(loss, model.variables)
return grads
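Here stop_recording wraps the grad_tape.gradient call so that the backward pass itself is not traced onto any other tape that might be active (the TODO above refers to the gradient call being traced by mistake in a multi-threaded setting). A standalone sketch of the effect, assuming eager execution; the nested-tape setup is only illustrative:

import tensorflow as tf
from tensorflow.python.eager import tape  # internal API

x = tf.constant(2.0)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = x * x * x
  with tape.stop_recording():
    dy_dx = inner.gradient(y, x)  # 3 * x**2 = 12.0
# The backward pass ran while recording was disabled, so the outer tape has
# no path from x to dy_dx and returns None instead of the second derivative.
print(outer.gradient(dy_dx, x))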
Example 4: decorated
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
if context.in_graph_mode():
if kwargs:
raise ValueError(
"custom_gradient in graph mode doesn't support keyword arguments.")
name = "CustomGradient-%s" % tf_ops.uid()
args = [tf_ops.convert_to_tensor(x) for x in args]
result, grad_fn = f(*args)
flat_result = nest.flatten(result)
all_tensors = flat_result + args
@tf_ops.RegisterGradient(name)
def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable
gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
# Need to return one value per input to the IdentityN, so pad the
# gradients of the inputs of the custom_gradient function with the
# gradients of the outputs as well.
return ([None] * len(flat_result)) + gradients
with tf_ops.get_default_graph().gradient_override_map(
{"IdentityN": name}):
all_tensors = array_ops.identity_n(all_tensors)
return nest.pack_sequence_as(
structure=result, flat_sequence=all_tensors[:len(flat_result)])
input_tensors = [x for x in args
if isinstance(x, tf_ops.Tensor)]
with tape.stop_recording():
result, grad_fn = f(*args, **kwargs)
# TODO(apassos): naive uses of custom_gradient will not get the correct
# second derivative this way if they capture any output tensors. Change the
# signature of custom_gradient.
def actual_grad_fn(*outputs):
return grad_fn(*outputs)
flat_result = nest.flatten(result)
tape.record_operation(
f.__name__,
flat_result,
input_tensors,
[],
actual_grad_fn)
flat_result = list(flat_result)
return result
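The decorated function above appears to be an implementation of custom_gradient: f runs with recording suspended, and the user-supplied grad_fn is registered afterwards via tape.record_operation. For comparison, a minimal sketch of how the public tf.custom_gradient API is typically used (the standard numerically stable log1pexp example; assumes a release where tf.custom_gradient and tf.math.log are available):

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    # Analytically simplified gradient instead of differentiating exp/log directly.
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

x = tf.constant(100.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = log1pexp(x)
print(g.gradient(y, x))  # 1.0, where the naive formulation would yield nan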
Example 5: decorated
def decorated(*args, **kwargs):
"""Decorated function with custom gradient."""
input_tensors = [x for x in args
if isinstance(x, tf_ops.Tensor)]
with tape.stop_recording():
result, grad_fn = f(*args, **kwargs)
# TODO(apassos): naive uses of custom_gradient will not get the correct
# second derivative this way if they capture any output tensors. Change the
# signature of custom_gradient.
def actual_grad_fn(*outputs):
return grad_fn(*outputs)
flat_result = nest.flatten(result)
tape.record_operation(
flat_result,
input_tensors,
[],
actual_grad_fn)
flat_result = list(flat_result)
return result
Example 6: _create_mirrored_variable
def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring
# Figure out what collections this variable should be added to.
# We'll add the MirroredVariable to those collections instead.
collections = kwargs.pop("collections", None)
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = []
# Get synchronization value
synchronization = kwargs.get("synchronization",
variable_scope.VariableSynchronization.ON_WRITE)
if synchronization == variable_scope.VariableSynchronization.NONE:
raise ValueError("`NONE` variable synchronization mode is not "
"supported with `Mirrored` distribution strategy. Please"
" change the `synchronization` for variable: " +
kwargs["name"])
elif synchronization == variable_scope.VariableSynchronization.ON_READ:
# Variables that are to be synced on read are tower local.
is_tower_local = True
kwargs["trainable"] = False
elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
synchronization == variable_scope.VariableSynchronization.AUTO):
# `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
is_tower_local = False
else:
raise ValueError("Invalid variable synchronization mode: " +
synchronization + " for variable: " + kwargs["name"])
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
if aggregation not in (
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
variable_scope.VariableAggregation.MEAN,
variable_scope.VariableAggregation.ONLY_FIRST_TOWER
):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
index = real_mirrored_creator(devices, *args, **kwargs)
if is_tower_local:
result = values.TowerLocalVariable(index, index[devices[0]], aggregation)
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
# Add the wrapped variable to the requested collections.
# The handling of eager mode and the global step matches
# ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
# to the TRAINABLE_VARIABLES collection, so we manually remove
# them and replace with the MirroredVariable. We can't set
# "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables.
if kwargs.get("trainable", True):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
return result
Example 7: __call__
def __call__(self, *args, **kwds):
"""Calls the graph function."""
if self._created_variables:
# In this case we have created variables on the first call, so we run the
# defunned version which is guaranteed to never create variables.
return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
elif self._stateful_fn is not None:
# In this case we have not created variables on the first call. So we can
# run the first trace but we should fail if variables are created.
results = self._stateful_fn(*args, **kwds)
if self._created_variables:
raise ValueError("Creating variables on a non-first call to a function"
" decorated with tf.function.")
return results
# This is the first call of __call__, so we have to initialize.
self._initialize(args, kwds)
if self._lifted_all_initializers and self._lifted_placeholders:
with ops.init_scope():
handles, placeholders = zip(*self._lifted_placeholders)
if context.executing_eagerly():
lifted_fn = function_lib._EagerDefinedFunction( # pylint: disable=protected-access
"initializer" + str(ops.uid()),
self._lifted_initializer_graph,
placeholders, [], {})
with tape.stop_recording():
lifted_fn.call(context.context(), list(handles))
return self._stateless_fn(*args, **kwds)
canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)
if not self._created_variables:
# If we did not create any variables the trace we have is good enough.
return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access
def fn_with_cond(*inner_args, **inner_kwds):
"""Conditionally runs initialization if it's needed."""
condition = True
for wr in self._created_variables:
variable = wr()
if variable is None:
raise ValueError(
"A tf.Variable created inside your tf.function has been"
" garbage-collected. Your code needs to keep Python references"
" to variables created inside `tf.function`s.\n"
"\n"
"A common way to raise this error is to create and return a"
" variable only referenced inside your function:\n"
"\n"
"@tf.function\n"
"def f():\n"
" v = tf.Variable(1.0)\n"
" return v\n"
"\n"
"v = f() # Crashes with this error message!\n"
"\n"
"The reason this crashes is that @tf.function annotated"
" function returns a **`tf.Tensor`** with the **value** of the"
" variable when the function is called rather than the"
" variable instance itself. As such there is no code holding a"
" reference to the `v` created inside the function and Python"
" garbage collects it.\n"
"\n"
"The simplest way to fix this issue is to create variables"
" outside the function and capture them:\n"
"\n"
"v = tf.Variable(1.0)\n"
"\n"
"@tf.function\n"
"def f():\n"
" return v\n"
"\n"
"f() # <tf.Tensor: ... numpy=1.>\n"
"v.assign_add(1.)\n"
"f() # <tf.Tensor: ... numpy=2.>")
condition = math_ops.logical_and(
condition, resource_variable_ops.var_is_initialized_op(
variable.handle))
# We want to call stateless_fn if possible because it avoids recomputing
# potentially expensive initializers.
return control_flow_ops.cond(
condition,
lambda: self._stateless_fn(*inner_args, **inner_kwds),
functools.partial(self._concrete_stateful_fn._filtered_call, # pylint: disable=protected-access
inner_args, inner_kwds))
return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
Example 8: _real_mirrored_creator
def _real_mirrored_creator(devices, *args, **kwargs):
"""Creates one MirroredVariable on the current worker."""
unique_var_name = ops.get_default_graph().unique_name(
kwargs["name"], mark_as_used=False).rstrip("/")
# pylint: disable=protected-access
collective_instance_key = self._collective_keys.get_instance_key(
key_id=unique_var_name)
# Only the first device participates in the broadcast of initial values.
group_key = self._collective_keys.get_group_key([devices[0]])
group_size = self._num_workers
if "initial_value" not in kwargs:
raise ValueError("Initial value must be specified.")
initial_value = kwargs["initial_value"]
if callable(initial_value):
initial_value_fn = initial_value
else:
initial_value_fn = lambda: initial_value
value_list = []
for i, d in enumerate(devices):
with ops.init_scope(), ops.device(d):
if i == 0:
# The initial value fn makes sure all replicas are initialized to the
# same values. The first device of the chief worker sends its variable
# value to the other workers.
def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring
with ops.device(device):
initial_value = initial_value_fn()
assert not callable(initial_value)
initial_value = ops.convert_to_tensor(initial_value)
assert index == 0, index
if self._num_workers > 1:
if self._is_chief:
bcast_send = collective_ops.broadcast_send(
initial_value, initial_value.shape, initial_value.dtype,
group_size, group_key, collective_instance_key)
with ops.control_dependencies([bcast_send]):
return array_ops.identity(initial_value)
else:
return collective_ops.broadcast_recv(
initial_value.shape, initial_value.dtype, group_size,
group_key, collective_instance_key)
return initial_value
else:
# Give replicas meaningful distinct names:
var0name = value_list[0].name.split(":")[0]
# We append a / to variable names created on replicas with id > 0 to
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Variables on non-first replica get initial values from the
# variables created on the first device of each worker.
def _overridden_initial_value_fn(device=d, index=i):
assert index > 0
with ops.device(device):
if context.executing_eagerly():
return array_ops.identity(value_list[0].value())
else:
return array_ops.identity(value_list[0].initial_value)
kwargs["initial_value"] = _overridden_initial_value_fn
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
# Don't record operations (e.g. other variable reads) during
# variable creation.
with tape.stop_recording():
v = next_creator(*args, **kwargs)
if i == 0:
actual_var_name = v.name.split(":")[0]
assert unique_var_name == actual_var_name, "%r vs %r" % (
unique_var_name, actual_var_name)
assert not isinstance(v, values.DistributedVariable)
value_list.append(v)
return value_list
Example 9: _create_variable
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
# Figure out what collections this variable should be added to.
# We'll add the MirroredVariable to those collections instead.
collections = kwargs.pop("collections", None)
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
kwargs["collections"] = []
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
# Get synchronization value
synchronization = kwargs.get(
"synchronization", variable_scope.VariableSynchronization.ON_WRITE)
if synchronization == variable_scope.VariableSynchronization.NONE:
raise ValueError("`NONE` variable synchronization mode is not "
"supported with `Mirrored` distribution strategy. Please"
" change the `synchronization` for variable: " +
kwargs["name"])
elif synchronization == variable_scope.VariableSynchronization.ON_READ:
# Variables that are to be synced on read are tower local.
is_tower_local = True
kwargs["trainable"] = False
elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
synchronization == variable_scope.VariableSynchronization.AUTO):
# `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
is_tower_local = False
else:
raise ValueError("Invalid variable synchronization mode: " +
synchronization + " for variable: " + kwargs["name"])
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
if aggregation not in [
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
variable_scope.VariableAggregation.MEAN
]:
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
# Ignore user-specified caching device, not needed for mirrored variables.
kwargs.pop("caching_device", None)
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
index = {}
for i, d in enumerate(devices):
with ops.device(d):
if i > 0:
# Give replicas meaningful distinct names:
var0name = index[devices[0]].name.split(":")[0]
# We append a / to variable names created on towers with id > 0 to
# ensure that we ignore the name scope and instead use the given
# name as the absolute name of the variable.
kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
if context.executing_eagerly():
kwargs["initial_value"] = array_ops.identity(
index[devices[0]].value())
else:
def initial_value_fn(device=d):
with ops.device(device):
return array_ops.identity(index[devices[0]].initial_value)
kwargs["initial_value"] = initial_value_fn
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
if is_tower_local:
result = values.TowerLocalVariable(index, index[devices[0]],
aggregation)
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
# to the TRAINABLE_VARIABLES collection, so we manually remove
# them and replace with the MirroredVariable. We can't set
# "trainable" to False for next_creator() since that causes functions
# like implicit_gradients to skip those variables.
if kwargs.get("trainable", True):
collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
return result