当前位置: 首页>>代码示例>>Python>>正文


Python distribution_strategy_context.get_strategy函数代码示例

本文整理汇总了Python中tensorflow.python.distribute.distribution_strategy_context.get_strategy函数的典型用法代码示例。如果您正苦于以下问题：Python get_strategy函数的具体用法？Python get_strategy怎么用？Python get_strategy使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。


在下文中一共展示了get_strategy函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: call_replica_local_fn

def call_replica_local_fn(fn, *args, **kwargs):
  """Invoke `fn`, which is expected to use replica-local variables.

  When a non-TPU strategy is available and the caller is in a
  cross-replica context, `fn` is run through
  `strategy.extended.call_for_each_replica` inside the strategy's scope.
  In every other case `fn` is simply called directly.

  Arguments:
    fn: The function to call.
    *args: Positional arguments forwarded to `fn`.
    **kwargs: Keyword arguments forwarded to `fn`. A `strategy` entry, if
      present, is removed and used instead of the ambient strategy.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/120571621): We want to avoid reductions here since
  # TPUStrategy does not implement replica local variables.
  # Remove this hack once we support TPUReplicaLocalVariables.
  if 'strategy' in kwargs:
    # An explicitly passed strategy wins, even when its value is falsy.
    strategy = kwargs.pop('strategy')
  else:
    # Fall back to the ambient strategy; a falsy result is treated as None.
    strategy = ds_context.get_strategy() or None

  use_replica_call = (
      not is_tpu_strategy(strategy)
      and strategy
      and ds_context.in_cross_replica_context())
  if not use_replica_call:
    return fn(*args, **kwargs)

  with strategy.scope():
    return strategy.extended.call_for_each_replica(fn, args, kwargs)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:distributed_training_utils.py

示例2: add_slot

 def add_slot(self, var, slot_name, initializer="zeros"):
   """Return the `slot_name` slot variable for `var`, creating it if absent.

   The slot name is recorded in `self._slot_names`. A newly created slot is
   stored in `self._slots` under the key `_var_key(var)` and appended to
   `self._weights`.
   """
   if slot_name not in self._slot_names:
     self._slot_names.append(slot_name)
   per_var_slots = self._slots.setdefault(_var_key(var), {})
   existing = per_var_slots.get(slot_name, None)
   if existing is not None:
     return existing

   # Strings and callables are resolved through `initializers.get` and bound
   # to the primary variable's shape and dtype; any other value is used as
   # the initial value directly.
   if isinstance(initializer, six.string_types) or callable(initializer):
     initializer = initializers.get(initializer)
     initial_value = functools.partial(
         initializer, shape=var.shape, dtype=var.dtype)
   else:
     initial_value = initializer

   with distribute_ctx.get_strategy().extended.colocate_vars_with(var):
     slot = tf_variables.Variable(
         name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
         dtype=var.dtype,
         trainable=False,
         initial_value=initial_value)
   backend.track_variable(slot)
   per_var_slots[slot_name] = slot
   self._restore_slot_variable(
       slot_name=slot_name, variable=var, slot_variable=slot)
   self._weights.append(slot)
   return slot
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:optimizer_v2.py

示例3: create_slot

def create_slot(primary, val, name, colocate_with_primary=True):
  """Build a slot variable whose initial value is `val`.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is created inside the
      current strategy's `colocate_vars_with(primary)` context.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  validate_shape = val.get_shape().is_fully_defined()
  # pylint: disable=protected-access
  prefix = (primary._shared_name if context.executing_eagerly()
            else primary.op.name)
  # pylint: enable=protected-access
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if not colocate_with_primary:
      return _create_slot_var(primary, val, "", validate_shape, None, None)
    strategy = distribution_strategy_context.get_strategy()
    with strategy.colocate_vars_with(primary):
      return _create_slot_var(primary, val, "", validate_shape, None, None)
开发者ID:pyjennings,项目名称:tensorflow,代码行数:32,代码来源:slot_creator.py

示例4: _fused_batch_norm

  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm.

    Besides computing the output, this registers (through `self.add_update`)
    assignments to `self.moving_mean` and `self.moving_variance`, unless
    `tf_utils.constant_value(training)` is a known falsy value.

    Args:
      inputs: Input passed to `nn.fused_batch_norm`.
      training: Selects between the training and inference branches; it is
        resolved with `tf_utils.smart_cond` and `tf_utils.constant_value`.

    Returns:
      The first element of the `nn.fused_batch_norm` result.
    """
    # Fall back to the constant beta/gamma when centering/scaling is disabled.
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      # Training branch: no mean/variance are supplied to the op.
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      # Inference branch: the stored moving statistics are supplied.
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    # `training_value` is None when `training` has no statically known value.
    training_value = tf_utils.constant_value(training)
    if training_value is None:
      # Momentum is chosen alongside the branch: `self.momentum` when
      # training, 1.0 otherwise.
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    # Updates are built when training is known-true or not statically known.
    if training_value or training_value is None:
      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()
        # NOTE(review): this branch passes `self.momentum`, while the
        # replica-context branch below passes the `momentum` computed above
        # (which may be 1.0 when not training) — confirm this is intended.
        mean_update = strategy.extended.update(
            self.moving_mean, self._assign_moving_average,
            (mean, self.momentum))
        variance_update = strategy.extended.update(
            self.moving_variance, self._assign_moving_average,
            (variance, self.momentum))
      else:
        mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                  momentum)
        variance_update = self._assign_moving_average(self.moving_variance,
                                                      variance, momentum)
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    return output
开发者ID:gautam1858,项目名称:tensorflow,代码行数:60,代码来源:normalization.py

示例5: scale_loss_for_distribution

def scale_loss_for_distribution(loss_value):
  """Divides `loss_value` by the current strategy's replica count.

  The value is returned unchanged when `num_replicas_in_sync` is 1 or less.

  Args:
    loss_value: The loss to scale.

  Returns:
    The (possibly scaled) loss value.
  """
  strategy = distribution_strategy_context.get_strategy()
  replica_count = strategy.num_replicas_in_sync
  if replica_count > 1:
    loss_value *= (1. / replica_count)
  return loss_value
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:losses_utils.py

示例6: _create_non_slot_variable

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Return the non-slot variable `name`, creating it on first request.

    Variables are cached in `self._non_slot_dict` under `(name, graph)`,
    where `graph` is `colocate_with.graph`, or None when executing eagerly.
    """
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    key = (name, None if eager else colocate_with.graph)
    cached = self._non_slot_dict.get(key, None)
    if cached is not None:
      return cached

    self._maybe_initialize_trackable()
    strategy = distribute_ctx.get_strategy()
    with strategy.extended.colocate_vars_with(colocate_with):
      if eager:
        # Prefer a value from `_preload_simple_restoration` when it has one.
        restored = self._preload_simple_restoration(name=name, shape=None)
        if restored is not None:
          initial_value = restored
      new_var = variable_scope.variable(
          initial_value, name=name, trainable=False,
          use_resource=resource_variable_ops.is_resource_variable(
              colocate_with))
    # Restore this variable by name if necessary, but don't add a
    # Trackable dependency. Optimizers return the current graph's
    # non-slot variables from _checkpoint_dependencies explicitly rather
    # than unconditionally adding dependencies (since there may be multiple
    # non-slot variables with the same name in different graphs, trying to
    # save all of them would result in errors).
    self._handle_deferred_dependencies(name=name, trackable=new_var)
    self._non_slot_dict[key] = new_var
    return new_var
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:31,代码来源:optimizer.py

示例7: _assert_in_default_state

def _assert_in_default_state(t):
  """Asserts, through test case `t`, that `ds_context` is in its default state.

  The default state here means: the default replica context and the default
  strategy are current, there is no cross-replica context, and
  `has_strategy()` is false.
  """
  # pylint: disable=protected-access
  expected_replica_ctx = ds_context._get_default_replica_context()
  t.assertIs(expected_replica_ctx, ds_context.get_replica_context())

  cross_replica_ctx = ds_context.get_cross_replica_context()
  t.assertIs(None, cross_replica_ctx)
  t.assertFalse(ds_context.in_cross_replica_context())

  expected_strategy = ds_context._get_default_strategy()
  t.assertIs(expected_strategy, ds_context.get_strategy())
  # pylint: enable=protected-access
  t.assertFalse(ds_context.has_strategy())
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:distribute_lib_test.py

示例8: merge_fn

 def merge_fn(dist, s):
   # Checks that this function runs in a cross-replica context owned by the
   # default strategy, then returns `s` prefixed with "foo_".
   default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
   self.assertIs(default_strategy, dist)

   replica_ctx = ds_context.get_replica_context()
   self.assertIs(None, replica_ctx)

   cross_replica_ctx = ds_context.get_cross_replica_context()
   self.assertIs(dist, cross_replica_ctx)
   self.assertTrue(ds_context.in_cross_replica_context())

   current_strategy = ds_context.get_strategy()
   self.assertIs(dist, current_strategy)
   self.assertFalse(ds_context.has_strategy())
   return "foo_" + s
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:distribute_lib_test.py

示例9: _scale_loss

 def _scale_loss(loss_value):
   """Scales `loss_value` by 1/num_replicas when the loss reduction is MEAN.

   Records on the default graph, in `_is_loss_scaled_by_optimizer`, whether
   scaling was actually applied.
   """
   # pylint: disable=protected-access
   ops.get_default_graph()._is_loss_scaled_by_optimizer = False
   if distribute_lib.get_loss_reduction() != ds_reduce_util.ReduceOp.MEAN:
     return loss_value

   replica_count = distribute_ctx.get_strategy().num_replicas_in_sync
   if replica_count > 1:
     loss_value *= (1. / replica_count)
     ops.get_default_graph()._is_loss_scaled_by_optimizer = True
   # pylint: enable=protected-access
   return loss_value
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:optimizer.py

示例10: testSetStrategy

 def testSetStrategy(self):
   # Checks that ds_context.experimental_set_strategy() installs, replaces,
   # and clears the strategy reported by ds_context.
   _assert_in_default_state(self)
   dist = _TestStrategy()
   dist2 = _TestStrategy()
   ds_context.experimental_set_strategy(dist)
   # With `dist` set, a cross-replica context owned by `dist` is expected and
   # no replica context.
   self.assertIs(None, ds_context.get_replica_context())
   self.assertIs(dist, ds_context.get_cross_replica_context())
   self.assertTrue(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())
   # Creating a variable is expected to yield the same dict that
   # _get_test_variable builds (presumably _TestStrategy intercepts variable
   # creation — confirm against its definition).
   expected_value = _get_test_variable(
       "baz", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="baz"))
   # Setting a second strategy replaces the first one.
   ds_context.experimental_set_strategy(dist2)
   self.assertIs(dist2, ds_context.get_strategy())
   # Setting None is expected to bring back the default state.
   ds_context.experimental_set_strategy(None)
   _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:19,代码来源:distribute_lib_test.py

示例11: _reduce_weighted_loss

def _reduce_weighted_loss(
    weighted_losses, reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements.

  Args:
    weighted_losses: The per-element weighted losses.
    reduction: A `losses_impl.ReductionV2` value. `NONE` returns the input
      unchanged; `SUM_OVER_BATCH_SIZE` divides the sum by the number of
      elements times the replica count; anything else returns the plain sum.

  Returns:
    The reduced loss.
  """
  if reduction == losses_impl.ReductionV2.NONE:
    return weighted_losses

  total = math_ops.reduce_sum(weighted_losses)
  if reduction == losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE:
    # Used to convert from local to global batch size.
    strategy = distribution_strategy_context.get_strategy()
    replica_count = strategy.num_replicas_in_sync
    total = _safe_mean(total, replica_count * _num_elements(weighted_losses))
  return total
开发者ID:ziky90,项目名称:tensorflow,代码行数:12,代码来源:losses_utils.py

示例12: _get_tensor

  def _get_tensor(self, is_finite):
    """Builds a tensor that is 1. when `is_finite` holds and NaN otherwise.

    When a strategy is active, the tensor is produced per replica: replica 0
    yields the tensor described above and every other replica yields 1.
    """
    tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))
    if not distribution_strategy_context.has_strategy():
      return tensor

    def _per_replica_value():
      replica_ctx = distribution_strategy_context.get_replica_context()
      is_first_replica = math_ops.equal(
          replica_ctx.replica_id_in_sync_group, 0)
      return control_flow_ops.cond(is_first_replica, lambda: tensor,
                                   lambda: 1.)

    strategy = distribution_strategy_context.get_strategy()
    return strategy.extended.call_for_each_replica(_per_replica_value)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:12,代码来源:loss_scale_test.py

示例13: run_fn

 def run_fn():
   # Checks that this function runs in a replica context under strategy
   # `dist`, then exercises merge_call and variable creation.
   replica_ctx = ds_context.get_replica_context()
   self.assertTrue(replica_ctx is not None)

   cross_replica_ctx = ds_context.get_cross_replica_context()
   self.assertIs(None, cross_replica_ctx)
   self.assertFalse(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())

   merged = replica_ctx.merge_call(None, test_arg="foo")
   self.assertEqual("foo", merged)

   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   created = variable_scope.variable(1.0, name="bar")
   self.assertDictEqual(expected_value, created)
开发者ID:aritratony,项目名称:tensorflow,代码行数:13,代码来源:distribute_lib_test.py

示例14: testScopeDeviceNestingError

 def testScopeDeviceNestingError(self):
   # Checks that leaving a strategy scope from inside a different device
   # scope raises RuntimeError("Device scope nesting error").
   _assert_in_default_state(self)
   dist = _TestStrategy()
   # Open a device scope with dist.scope().
   dist.extended._default_device = "/device:GPU:0"
   scope = dist.scope()
   # The scope is entered and exited by hand so that __exit__ can be called
   # at a deliberately wrong nesting level.
   scope.__enter__()
   self.assertIs(dist, ds_context.get_strategy())
   with ops.device("/device:CPU:0"):
     with self.assertRaisesRegexp(RuntimeError, "Device scope nesting error"):
       scope.__exit__(None, None, None)
   # Outside the extra device scope the exit is expected to succeed and
   # return to the default state.
   scope.__exit__(None, None, None)
   _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:13,代码来源:distribute_lib_test.py

示例15: testScopeVarScopeNestingError

 def testScopeVarScopeNestingError(self):
   # Checks that leaving a strategy scope from inside an extra variable scope
   # raises RuntimeError("Variable scope nesting error").
   # We create a new graph here to simplify clean-up, since the error
   # we are triggering happens in the middle of scope.__exit__() and
   # leaves us in a weird state.
   with ops.Graph().as_default():
     _assert_in_default_state(self)
     dist = _TestStrategy()
     scope = dist.scope()
     # Entered by hand so that __exit__ can be called at the wrong nesting
     # level below.
     scope.__enter__()
     self.assertIs(dist, ds_context.get_strategy())
     with variable_scope.variable_scope("AA"):
       with self.assertRaisesRegexp(RuntimeError,
                                    "Variable scope nesting error"):
         scope.__exit__(None, None, None)
   # After the temporary graph is left, the default state is expected again.
   _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:15,代码来源:distribute_lib_test.py


注:本文中的tensorflow.python.distribute.distribution_strategy_context.get_strategy函数示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。