本文整理汇总了 Python 中 tensorflow.python.distribute.distribution_strategy_context.get_replica_context 函数的典型用法代码示例。如果您正苦于以下问题：Python get_replica_context 函数的具体用法？Python get_replica_context 怎么用？Python get_replica_context 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 get_replica_context 函数的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: set_non_tensor_output
def set_non_tensor_output(self, name, output):
    """Record `output` under `name` as a captured non-tensor output.

    In cross-replica context the value is stored unchanged. Otherwise a
    `merge_call` is issued from the current replica context and the
    unwrapped value is stored instead.

    Args:
        name: Key used in `self._non_tensor_outputs`.
        output: The non-tensor value to capture.
    """
    if distribution_strategy_context.in_cross_replica_context():
        self._non_tensor_outputs[name] = output
        return

    def _store_unwrapped(strategy, per_replica_value):
        # Reduction is meaningless for non-tensor values, so all of the
        # values are kept exactly as `unwrap` returns them.
        self._non_tensor_outputs[name] = strategy.unwrap(per_replica_value)

    replica_ctx = distribution_strategy_context.get_replica_context()
    replica_ctx.merge_call(_store_unwrapped, args=(output,))
示例2: _assert_in_default_state
def _assert_in_default_state(t):
    """Assert, through test case `t`, that no strategy scope is active.

    Checks that the current replica context and strategy are the default
    ones, that there is no cross-replica context, and that `has_strategy()`
    reports False.

    Args:
        t: Test-case object providing `assertIs` / `assertFalse`.
    """
    default_replica_ctx = ds_context._get_default_replica_context()
    t.assertIs(default_replica_ctx, ds_context.get_replica_context())

    t.assertIs(None, ds_context.get_cross_replica_context())
    in_cross_replica = ds_context.in_cross_replica_context()
    t.assertFalse(in_cross_replica)

    default_strategy = ds_context._get_default_strategy()
    t.assertIs(default_strategy, ds_context.get_strategy())
    strategy_active = ds_context.has_strategy()
    t.assertFalse(strategy_active)
示例3: decorated
def decorated(_, *args):
    """Evaluate `result_fn`, leaving replica mode through `merge_call` if needed.

    `result_fn` comes from the enclosing scope. The first positional
    argument is ignored; the remaining ones are forwarded to `result_fn`.

    Returns:
        The value of `result_fn(*args)` passed through `array_ops.identity`.
    """
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:
        # No replica context: we are already in cross-replica context.
        return array_ops.identity(result_fn(*args))

    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.
    def _call_in_cross_replica(distribution, merge_fn, *merge_args):
        # `merge_call` supplies the distribution object as first argument;
        # this adapter hides it so `result_fn` does not need to accept it.
        # The function arrives as a `PerReplica` value; every copy is
        # identical, so the first local result is the one that is invoked.
        local_fns = distribution.experimental_local_results(merge_fn)
        merged_result = local_fns[0](*merge_args)
        # The identity wrap keeps the control dependency between the
        # `update_state` update op and the result working when the result
        # is a tensor.
        return array_ops.identity(merged_result)

    # `merge_call` is how replica mode is left in order to compute a value
    # in cross-replica mode.
    return replica_context.merge_call(
        _call_in_cross_replica, args=(result_fn,) + args)
示例4: apply_gradients
def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. The pairs are first passed
    through `_filter_grads`; hyperparameters and slots are then created,
    `_prepare` is run, and the actual update is delegated to
    `self._distributed_apply` through the replica context's `merge_call`.

    Args:
        grads_and_vars: List of (gradient, variable) pairs.
        name: Optional name for the returned operation. Default to the name
            passed to the `Optimizer` constructor.

    Returns:
        Whatever `merge_call(self._distributed_apply, ...)` returns; per the
        original documentation, an `Operation` that applies the gradients.

    Raises:
        TypeError: If `grads_and_vars` is malformed.
        ValueError: If none of the variables have gradients.
        (Both as stated by the original docstring; the raising code is not
        part of this function body.)
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [var for _grad, var in grads_and_vars]

    self._create_hypers()
    # Slot creation happens inside `init_scope`.
    with ops.init_scope():
        self._create_slots(var_list)

    self._prepare(var_list)

    replica_ctx = distribute_ctx.get_replica_context()
    return replica_ctx.merge_call(
        self._distributed_apply,
        args=(grads_and_vars,),
        kwargs={"name": name},
    )
示例5: merge_fn
def merge_fn(dist, s):
    """Check the context state seen inside a merge call and tag `s`.

    Asserts that `dist` is the default strategy, that execution is in
    cross-replica context, and that `has_strategy()` is False.

    Returns:
        `s` prefixed with "foo_".
    """
    default_strategy = ds_context._get_default_strategy()
    self.assertIs(default_strategy, dist)

    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(dist, ds_context.get_cross_replica_context())
    in_cross_replica = ds_context.in_cross_replica_context()
    self.assertTrue(in_cross_replica)

    self.assertIs(dist, ds_context.get_strategy())
    strategy_active = ds_context.has_strategy()
    self.assertFalse(strategy_active)
    return "foo_" + s
示例6: merge_grads
def merge_grads(grads_and_vars):
    """Merge gradients from different replicas.

    Args:
        grads_and_vars: Value forwarded to `batch_reduce_to` as the
            value/destination pairs.

    Returns:
        The result of the `merge_call`, i.e. the output of
        `strategy.extended.batch_reduce_to` with `ReduceOp.SUM`.
    """
    def _sum_across_replicas(strategy, pairs):
        return strategy.extended.batch_reduce_to(
            ds_reduce_util.ReduceOp.SUM, pairs)

    replica_ctx = distribute_ctx.get_replica_context()
    return replica_ctx.merge_call(
        _sum_across_replicas, args=(grads_and_vars,))
示例7: set_last_step_output
def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
        name: String, name to identify the output. Doesn't need to match
            tensor name.
        output: The tensors that should be outputted with `name`. See below
            for actual types supported.
        reduce_op: Reduction method to use to reduce outputs from multiple
            replicas. Required if `set_last_step_output` is called in a
            replica context. Optional in cross-replica context.
            When present, the outputs are reduced with the current
            distribution strategy's `reduce` method, so the type of
            `output` must be one that `reduce` supports (e.g. with
            MirroredStrategy and a reduction set, a `PerReplica` value).
            The reduce op is also recorded in
            `_last_step_outputs_reduce_ops` so the outputs can later be
            interpreted as already reduced or not.
    """
    if not distribution_strategy_context.in_cross_replica_context():
        assert reduce_op is not None

        def _reduce_and_store(strategy, per_replica_value):
            self._last_step_outputs[name] = strategy.reduce(
                reduce_op, per_replica_value, axis=None)
            # Recorded inside the merge function because all replicas share
            # the same context object, so it is more robust to set it only
            # once (even if every replica would set the same value).
            self._last_step_outputs_reduce_ops[name] = reduce_op

        replica_ctx = distribution_strategy_context.get_replica_context()
        replica_ctx.merge_call(_reduce_and_store, args=(output,))
        return

    # Cross-replica context: record the op, then store the raw or the
    # reduced output.
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is not None:
        strategy = distribution_strategy_context.get_strategy()
        output = strategy.reduce(reduce_op, output, axis=None)
    self._last_step_outputs[name] = output
示例8: run_fn
def run_fn():
    """Check the context state seen from inside a replica function.

    Uses `self` and `dist` from the enclosing scope. Asserts that a replica
    context exists, that there is no cross-replica context, that `dist` is
    the current strategy, and checks the results of `merge_call` and of
    variable creation against the expected test values.
    """
    replica_context = ds_context.get_replica_context()
    self.assertTrue(replica_context is not None)
    self.assertIs(None, ds_context.get_cross_replica_context())
    in_cross_replica = ds_context.in_cross_replica_context()
    self.assertFalse(in_cross_replica)

    strategy_active = ds_context.has_strategy()
    self.assertTrue(strategy_active)
    self.assertIs(dist, ds_context.get_strategy())

    merged = replica_context.merge_call(None, test_arg="foo")
    self.assertEqual("foo", merged)

    expected_value = _get_test_variable(
        "bar",
        variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE,
    )
    created = variable_scope.variable(1.0, name="bar")
    self.assertDictEqual(expected_value, created)
示例9: _test_run
def _test_run(self, strategy):
    """Exercise `experimental_run_v2` with no args, `args=` and `kwargs=`.

    Args:
        strategy: The distribution strategy under test. The expected values
            below are written for two replicas.
    """
    def replica_id_plus_one():
        replica_ctx = ds_context.get_replica_context()
        return replica_ctx.replica_id_in_sync_group + 1

    def double_and_square(x):
        return {"a": x * 2, "b": x * x}

    def combine(b, a):
        return a + 2 * b + 2

    out1 = strategy.experimental_run_v2(replica_id_plus_one)
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.experimental_run_v2(double_and_square, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    # The dict returned above is passed as keyword arguments "a" and "b".
    out3 = strategy.experimental_run_v2(combine, kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))
示例10: merge_update_step
def merge_update_step(update_ops, local_step):
    """Merge local step counter update from different replicas.

    Args:
        update_ops: Iterable of update ops; each is grouped via
            `strategy.group`.
        local_step: Object exposing `assign_add`; incremented by 1 under a
            control dependency on the grouped ops.

    Returns:
        The result of the `merge_call`, i.e. the `.op` of the
        `assign_add(1)` on `local_step`.
    """
    def _group_then_increment(strategy, ops_to_merge, step_counter):
        grouped = [strategy.group(op) for op in ops_to_merge]
        # The increment is created under a dependency on every grouped op.
        with ops.control_dependencies(grouped):
            return step_counter.assign_add(1).op

    replica_ctx = distribute_ctx.get_replica_context()
    return replica_ctx.merge_call(
        _group_then_increment, args=(update_ops, local_step))
示例11: testScope
def testScope(self):
    """Check the context state inside `dist.scope()` and after leaving it."""
    _assert_in_default_state(self)
    dist = _TestStrategy()
    with dist.scope():
        # Inside the scope: cross-replica context, `dist` is the strategy.
        self.assertIs(None, ds_context.get_replica_context())
        self.assertIs(dist, ds_context.get_cross_replica_context())
        in_cross_replica = ds_context.in_cross_replica_context()
        self.assertTrue(in_cross_replica)
        strategy_active = ds_context.has_strategy()
        self.assertTrue(strategy_active)
        self.assertIs(dist, ds_context.get_strategy())

        expected_value = _get_test_variable(
            "baz",
            variable_scope.VariableSynchronization.AUTO,
            variable_scope.VariableAggregation.NONE,
        )
        created = variable_scope.variable(1.0, name="baz")
        self.assertDictEqual(expected_value, created)
    _assert_in_default_state(self)
示例12: testMergeCall
def testMergeCall(self):
    """Check `merge_call` on the default replica context."""
    _assert_in_default_state(self)

    def merge_fn(dist, s):
        # Inside the merge call: cross-replica context with the default
        # strategy, and `has_strategy()` still False.
        default_strategy = ds_context._get_default_strategy()
        self.assertIs(default_strategy, dist)
        self.assertIs(None, ds_context.get_replica_context())
        self.assertIs(dist, ds_context.get_cross_replica_context())
        in_cross_replica = ds_context.in_cross_replica_context()
        self.assertTrue(in_cross_replica)
        self.assertIs(dist, ds_context.get_strategy())
        strategy_active = ds_context.has_strategy()
        self.assertFalse(strategy_active)
        return "foo_" + s

    replica_ctx = ds_context.get_replica_context()
    default_replica_ctx = ds_context._get_default_replica_context()
    self.assertIs(default_replica_ctx, replica_ctx)

    merged = replica_ctx.merge_call(merge_fn, args=("bar",))
    self.assertEqual("foo_bar", merged)
    _assert_in_default_state(self)
示例13: _test_step_fn
def _test_step_fn(inputs):
    """A fn that returns output of single test step.

    `model` and `mode` come from the enclosing scope.

    Args:
        inputs: Either the inputs alone, or a 2-element tuple/list that is
            split into `(inputs, targets)`.

    Returns:
        The `outputs` entry of `_per_replica_execution_function(...)`,
        returned under a control dependency on its `updates` entry.
    """
    targets = None
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
        inputs, targets = inputs

    replica_ctx = distribution_strategy_context.get_replica_context()
    replica_ctx.merge_call(
        _build_model, args=(model, mode, inputs, targets))

    distributed_model = distributed_training_utils.get_distributed_model(
        model, mode)
    _, outputs, updates, _ = _per_replica_execution_function(
        distributed_model, mode)

    with ops.control_dependencies([updates]):
        return outputs
示例14: skip_summary
def skip_summary():
    """Determine whether the summary should be skipped.

    If using multiple replicas in distributed strategy, skip summaries on
    all replicas except the first one (replica_id=0).

    Returns:
        True if the summary is skipped; False otherwise. The result is
        always a real `bool`: it is False when there is no replica context
        and when the replica id cannot be resolved to a static value.
    """
    # TODO(priyag): Add a new optional argument that will provide multiple
    # alternatives to override default behavior. (e.g. run on last replica,
    # compute sum or mean across replicas).
    replica_context = distribution_strategy_context.get_replica_context()
    if not replica_context:
        return False

    # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
    # initialized, remember to change here as well.
    replica_id = replica_context.replica_id_in_sync_group
    if isinstance(replica_id, ops.Tensor):
        # Try to resolve the tensor to a static value.
        replica_id = tensor_util.constant_value(replica_id)

    # The previous `replica_id and replica_id > 0` could return `None`, `0`
    # or a NumPy boolean; normalise to the documented True/False while
    # keeping the same truthiness.
    if replica_id is None:
        return False
    return bool(replica_id > 0)
示例15: decorated
def decorated(_, *args):
    """Evaluate `result_fn`, leaving replica mode through `merge_call` if needed.

    `result_fn` comes from the enclosing scope. The first positional
    argument is ignored; the remaining ones are forwarded to `result_fn`.

    Returns:
        The value of `result_fn(*args)`, computed directly when already in
        cross-replica context and via `merge_call` otherwise.
    """
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:
        # No replica context: we are already in cross-replica context.
        return result_fn(*args)

    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.
    def _call_in_cross_replica(distribution, merge_fn, *merge_args):
        # `merge_call` supplies the distribution object as first argument;
        # this adapter hides it so `result_fn` does not need to accept it.
        # The function arrives wrapped per device; every copy is identical,
        # so the first unwrapped one is the one that is invoked.
        unwrapped_fns = distribution.unwrap(merge_fn)
        return unwrapped_fns[0](*merge_args)

    # `merge_call` is how replica mode is left in order to compute a value
    # in cross-replica mode.
    return replica_context.merge_call(
        _call_in_cross_replica, args=(result_fn,) + args)