

Python distribution_strategy_context.get_replica_context Function Code Examples

This article collects typical usage examples of the Python function tensorflow.python.distribute.distribution_strategy_context.get_replica_context. If you have been wondering what get_replica_context does, how it is called, and what real-world usage looks like, the curated code examples below should help.


The sections below present 15 code examples of get_replica_context, drawn from open-source projects and ordered by popularity.
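Before the examples, here is a minimal sketch of the core contract (my own illustration, not taken from the projects below, assuming a TF 2.x runtime where tf.distribute.get_replica_context is the public alias of the internal function covered in this article): inside a replica function the call returns a ReplicaContext, while in cross-replica context, e.g. under strategy.scope(), it returns None.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn():
  # Inside strategy.run we are in replica context, so this is non-None.
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group

per_replica_ids = strategy.run(step_fn)

with strategy.scope():
  # Under the strategy's scope we are in cross-replica context.
  assert tf.distribute.get_replica_context() is None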

Example 1: set_non_tensor_output

 def set_non_tensor_output(self, name, output):
   """Set `output` with `name` to be captured as a non tensor output."""
   if distribution_strategy_context.in_cross_replica_context():
     self._non_tensor_outputs[name] = output
   else:
     def merge_fn(distribution, value):
       # NOTE(priyag): For non tensor outputs, we simply return all the values
       # in a list as reduction doesn't make sense on non tensors.
       self._non_tensor_outputs[name] = distribution.unwrap(value)
     distribution_strategy_context.get_replica_context().merge_call(
         merge_fn, args=(output,))
Contributor: kylin9872, Project: tensorflow, Lines: 11, Source: input_lib.py
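For orientation, a hedged sketch of the merge_call pattern this example relies on, written against the public TF 2.x API (the strategy and tensor value here are illustrative; the original uses the now-deprecated unwrap, for which experimental_local_results is the public replacement):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn():
  def merge_fn(strategy, value):
    # Runs once, in cross-replica context; `value` bundles the
    # per-replica arguments, and experimental_local_results turns
    # it into a plain per-replica tuple.
    return strategy.experimental_local_results(value)

  ctx = tf.distribute.get_replica_context()
  return ctx.merge_call(merge_fn, args=(tf.constant(1.0),))

strategy.run(step_fn)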

Example 2: _assert_in_default_state

def _assert_in_default_state(t):
  t.assertIs(ds_context._get_default_replica_context(),
             ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  t.assertIs(ds_context._get_default_strategy(), ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
Contributor: aritratony, Project: tensorflow, Lines: 7, Source: distribute_lib_test.py

Example 3: decorated

  def decorated(_, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:  # if in cross replica context already
      result_t = array_ops.identity(result_fn(*args))
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerReplica` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        merged_result_fn = (
            distribution.experimental_local_results(merge_fn)[0](*args))

        # Wrapping result in identity so that control dependency between
        # update_op from `update_state` and result works in case result returns
        # a tensor.
        return array_ops.identity(merged_result_fn)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)
    return result_t
Contributor: aritratony, Project: tensorflow, Lines: 28, Source: metrics_utils.py
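From the caller's side, this wrapper is what makes a Keras metric's result() work under a distribution strategy. A hedged usage sketch (my own, assuming the public TF 2.x Keras API; the metric and value are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Metric variables are created under the strategy's scope.
  metric = tf.keras.metrics.Mean()

def step_fn():
  metric.update_state(1.0)

strategy.run(step_fn)
# result() goes through a wrapper like `decorated` above, which uses
# merge_call to compute the value once in cross-replica context.
print(metric.result())  # 1.0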

Example 4: apply_gradients

  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation.  Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    self._create_hypers()
    with ops.init_scope():
      self._create_slots(var_list)

    self._prepare(var_list)

    return distribute_ctx.get_replica_context().merge_call(
        self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})
Contributor: terrytangyuan, Project: tensorflow, Lines: 30, Source: optimizer_v2.py
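A short usage sketch of this method from the caller's side (my own example against the public Keras OptimizerV2 API; the variable and learning rate are illustrative). The merge_call in the return statement above is what moves execution into cross-replica mode before the per-variable updates run:

import tensorflow as tf

var = tf.Variable(2.0)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

with tf.GradientTape() as tape:
  loss = var ** 2
grads = tape.gradient(loss, [var])

# Internally enters cross-replica context via merge_call, then applies
# the update; `var` becomes 2.0 - 0.1 * 4.0 = 1.6.
opt.apply_gradients(zip(grads, [var]))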

Example 5: merge_fn

 def merge_fn(dist, s):
   self.assertIs(ds_context._get_default_strategy(), dist)
   self.assertIs(None, ds_context.get_replica_context())
   self.assertIs(dist, ds_context.get_cross_replica_context())
   self.assertTrue(ds_context.in_cross_replica_context())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertFalse(ds_context.has_strategy())
   return "foo_" + s
Contributor: aritratony, Project: tensorflow, Lines: 8, Source: distribute_lib_test.py

Example 6: merge_grads

def merge_grads(grads_and_vars):
  """Merge gradients from different replicas."""

  def merge_grad_fn(strategy, grads_and_vars):
    reduced_grads = strategy.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)
    return reduced_grads

  return distribute_ctx.get_replica_context().merge_call(
      merge_grad_fn, args=(grads_and_vars,))
Contributor: Wajih-O, Project: tensorflow, Lines: 10, Source: optimizer_v2.py
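The same sum-across-replicas idea can be sketched with the public API (my illustration, assuming TF 2.x where ReplicaContext.all_reduce covers this use case; the constant gradient is a placeholder):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def step_fn():
  ctx = tf.distribute.get_replica_context()
  local_grad = tf.constant(1.0)  # stands in for a per-replica gradient
  # Sums the value across replicas, analogous to batch_reduce_to
  # with ReduceOp.SUM in merge_grad_fn above.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_grad)

summed = strategy.run(step_fn)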

Example 7: set_last_step_output

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        E.g., if using MirroredStrategy and a reduction is set, `output`
        must be a `PerReplica` value.
        The reduce method is also recorded in the dictionary
        `_last_step_outputs_reduce_ops` so that the outputs can later be
        interpreted as already reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
Contributor: adit-chandra, Project: tensorflow, Lines: 40, Source: input_lib.py
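A hedged sketch of the reduction this method delegates to (my illustration with the public TF 2.x API; the per-replica value is a stand-in for a real last-step output):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

per_replica = strategy.run(
    lambda: tf.distribute.get_replica_context().replica_id_in_sync_group)

# In cross-replica context, reduce collapses the per-replica values,
# just as this method does when `reduce_op` is provided.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)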

Example 8: run_fn

 def run_fn():
   replica_context = ds_context.get_replica_context()
   self.assertTrue(replica_context is not None)
   self.assertIs(None, ds_context.get_cross_replica_context())
   self.assertFalse(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Contributor: aritratony, Project: tensorflow, Lines: 13, Source: distribute_lib_test.py

Example 9: _test_run

  def _test_run(self, strategy):
    out1 = strategy.experimental_run_v2(
        lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1)
    self.assertAllEqual([1, 2], self.evaluate(strategy.unwrap(out1)))

    out2 = strategy.experimental_run_v2(
        lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
    out2_vals = self.evaluate(nest.map_structure(strategy.unwrap, out2))
    self.assertAllEqual([2, 4], out2_vals["a"])
    self.assertAllEqual([1, 4], out2_vals["b"])

    out3 = strategy.experimental_run_v2(lambda b, a: a + 2 * b + 2, kwargs=out2)
    self.assertAllEqual([6, 14], self.evaluate(strategy.unwrap(out3)))
Contributor: adit-chandra, Project: tensorflow, Lines: 13, Source: strategy_test_lib.py

Example 10: merge_update_step

def merge_update_step(update_ops, local_step):
  """Merge local step counter update from different replicas."""

  def merge_update_step_fn(strategy, update_ops, local_step):
    merged_ops = []
    for update_op in update_ops:
      merged_ops.append(strategy.group(update_op))
    with ops.control_dependencies(merged_ops):
      incre_op = local_step.assign_add(1).op
    return incre_op

  return distribute_ctx.get_replica_context().merge_call(
      merge_update_step_fn, args=(update_ops, local_step))
Contributor: Wajih-O, Project: tensorflow, Lines: 13, Source: optimizer_v2.py

Example 11: testScope

 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, ds_context.get_replica_context())
     self.assertIs(dist, ds_context.get_cross_replica_context())
     self.assertTrue(ds_context.in_cross_replica_context())
     self.assertTrue(ds_context.has_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Contributor: aritratony, Project: tensorflow, Lines: 15, Source: distribute_lib_test.py

Example 12: testMergeCall

  def testMergeCall(self):
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      self.assertIs(ds_context._get_default_strategy(), dist)
      self.assertIs(None, ds_context.get_replica_context())
      self.assertIs(dist, ds_context.get_cross_replica_context())
      self.assertTrue(ds_context.in_cross_replica_context())
      self.assertIs(dist, ds_context.get_strategy())
      self.assertFalse(ds_context.has_strategy())
      return "foo_" + s

    replica_ctx = ds_context.get_replica_context()
    self.assertIs(ds_context._get_default_replica_context(), replica_ctx)
    self.assertEqual("foo_bar", replica_ctx.merge_call(merge_fn, args=("bar",)))
    _assert_in_default_state(self)
Contributor: aritratony, Project: tensorflow, Lines: 16, Source: distribute_lib_test.py

Example 13: _test_step_fn

  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = (
        _per_replica_execution_function(
            distributed_training_utils.get_distributed_model(model, mode),
            mode))
    with ops.control_dependencies([updates]):
      return outputs
Contributor: aritratony, Project: tensorflow, Lines: 16, Source: training_distributed.py

Example 14: skip_summary

def skip_summary():
  """Determines if summary should be skipped.

  If using multiple replicas in distributed strategy, skip summaries on all
  replicas except the first one (replica_id=0).

  Returns:
    True if the summary is skipped; False otherwise.
  """

  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last replica,
  # compute sum or mean across replicas).
  replica_context = distribution_strategy_context.get_replica_context()
  if not replica_context:
    return False
  # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
  # initialized, remember to change here as well.
  replica_id = replica_context.replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id and replica_id > 0
Contributor: adit-chandra, Project: tensorflow, Lines: 22, Source: summary_op_util.py
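A hedged sketch of the guard this helper enables, i.e. writing summaries only on replica 0 (my own illustration; `writer`, `step`, and `value` are placeholder names, and converting the replica-id tensor with int() assumes eager execution; inside a graph this comparison would need tf.cond instead):

import tensorflow as tf

def maybe_write_scalar(writer, step, value):
  ctx = tf.distribute.get_replica_context()
  # None means we are already in cross-replica context.
  if ctx is None or int(ctx.replica_id_in_sync_group) == 0:
    with writer.as_default():
      tf.summary.scalar("loss", value, step=step)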

Example 15: decorated

  def decorated(_, *args):
    """Decorated function with merge_call."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context is None:  # if in cross replica context already
      result_t = result_fn(*args)
    else:
      # TODO(psv): Test distribution of metrics using different distribution
      # strategies.

      # Creating a wrapper for merge_fn. merge_call invokes the given merge_fn
      # with distribution object as the first parameter. We create a wrapper
      # here so that the result function need not have that parameter.
      def merge_fn_wrapper(distribution, merge_fn, *args):
        # We will get `PerDevice` merge function. Taking the first one as all
        # are identical copies of the function that we had passed below.
        return distribution.unwrap(merge_fn)[0](*args)

      # Wrapping result in merge_call. merge_call is used when we want to leave
      # replica mode and compute a value in cross replica mode.
      result_t = replica_context.merge_call(
          merge_fn_wrapper, args=(result_fn,) + args)
    return result_t
Contributor: rmlarsen, Project: tensorflow, Lines: 22, Source: metrics_utils.py


Note: The tensorflow.python.distribute.distribution_strategy_context.get_replica_context examples above were compiled from open-source code hosted on platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects and their copyright remains with the original authors; consult each project's license before use or redistribution, and do not reproduce without permission.