

Python distribute.get_tower_context Function Code Examples

This article compiles typical usage examples of the Python function tensorflow.python.training.distribute.get_tower_context. If you are unsure what get_tower_context does, how to call it, or where it is used, the curated examples below should help.


Fifteen code examples of the get_tower_context function are shown below, ordered by popularity by default.
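
Before the examples, here is a minimal sketch of where get_tower_context fits: a function is run once per tower (device) by a DistributionStrategy, and merge_call temporarily hands control back to the cross-tower context. This sketch is not taken from the examples below; it assumes a TensorFlow 1.x build of the same era, where tf.contrib.distribute.MirroredStrategy and call_for_each_tower are available (in later releases this API was renamed to get_replica_context). The names tower_fn and merge_fn are illustrative.

# A minimal sketch, not from the examples below. Assumes TensorFlow 1.x with
# tf.contrib.distribute available; in later releases get_tower_context was
# renamed get_replica_context. tower_fn and merge_fn are illustrative names.
import tensorflow as tf
from tensorflow.python.training import distribute as distribute_lib

def tower_fn():
  # Inside a tower context: get_tower_context() returns a TowerContext
  # (with e.g. tower_id), while get_cross_tower_context() returns None.
  tower_context = distribute_lib.get_tower_context()

  def merge_fn(distribution, value):
    # merge_call runs this once in cross-tower context; `distribution` is the
    # enclosing DistributionStrategy and `value` bundles the per-tower args.
    return distribution.unwrap(value)

  return tower_context.merge_call(merge_fn, tower_context.tower_id)

strategy = tf.contrib.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  # call_for_each_tower runs tower_fn once per device listed above.
  result = strategy.call_for_each_tower(tower_fn)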

Example 1: model_fn

 def model_fn():
   vs = []
   vs.append(variable_scope.variable(1.0, name="foo/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
   vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
   # merge_call with an identity lambda just synchronizes the tower threads.
   distribute_lib.get_tower_context().merge_call(lambda _: _)
   return vs
Developer: Jackiefan, Project: tensorflow, Lines: 8, Source: mirrored_strategy_multigpu_test.py

Example 2: set_non_tensor_output

 def set_non_tensor_output(self, name, output):
   """Set `output` with `name` to be captured as a non tensor output."""
   if distribute_lib.get_cross_tower_context():
     self._non_tensor_outputs[name] = output
   else:
     def merge_fn(distribution, value):
       # NOTE(priyag): For non tensor outputs, we simply return all the values
       # in a list as aggregation doesn't make sense on non tensors.
       self._non_tensor_outputs[name] = distribution.unwrap(value)
     distribute_lib.get_tower_context().merge_call(merge_fn, output)
Developer: sonnyhu, Project: tensorflow, Lines: 10, Source: values.py

Example 3: model_fn

    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return v
Developer: Jackiefan, Project: tensorflow, Lines: 12, Source: mirrored_strategy_test.py

Example 4: model_fn

    def model_fn():
      v0 = variable_scope.get_variable("var0", [1])
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1],
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1],
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3
Developer: ChristinaEricka, Project: tensorflow, Lines: 16, Source: mirrored_strategy_multigpu_test.py

Example 5: _assert_in_default_state

# Asserts that no distribution strategy is active: the default tower context
# is current, there is no cross-tower context, and the default strategy is
# returned by get_distribution_strategy().
def _assert_in_default_state(t):
  t.assertIs(distribute._default_tower_context,
             distribute.get_tower_context())
  t.assertIs(None, distribute.get_cross_tower_context())
  t.assertIs(distribute._default_distribution_strategy,
             distribute.get_distribution_strategy())
  t.assertFalse(distribute.has_distribution_strategy())
Developer: LongJun123456, Project: tensorflow, Lines: 7, Source: distribute_test.py

Example 6: merge_fn

 # Invoked through merge_call (compare Example 14): runs in cross-tower
 # context while the default strategy is in effect.
 def merge_fn(dist, s):
   self.assertIs(distribute._default_distribution_strategy, dist)
   self.assertIs(None, distribute.get_tower_context())
   self.assertIs(dist, distribute.get_cross_tower_context())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertFalse(distribute.has_distribution_strategy())
   return "foo_" + s
Developer: LongJun123456, Project: tensorflow, Lines: 7, Source: distribute_test.py

Example 7: _assign_func

  def _assign_func(self, *args, **kwargs):
    f = kwargs.pop("f")
    if distribute_lib.get_cross_tower_context():
      update_device = distribute_lib.get_update_device()
      # We are calling update on the mirrored variable in cross tower context.
      if update_device is not None:
        # We are calling an assign function on the mirrored variable in cross
        # tower context.
        v = self.get(device=update_device)
        return f(v, *args, **kwargs)

      return distribute_lib.get_distribution_strategy().update(
          self, f, *args, **kwargs)
    else:
      _assert_tower_context()
      # We are calling an assign function on the mirrored variable in tower
      # context.
      # We reduce the value we want to assign/add/sub. More details about how we
      # handle the different use cases can be found in the _reduce method.
      # We call the function on each of the mirrored variables with the reduced
      # value.
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "MirroredVariable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        return strategy.update(
            self, f,
            strategy.reduce(
                aggregation=self._aggregation, value=value, destinations=self),
            *other_args, **other_kwargs)

      return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
                                                           **kwargs)
Developer: sonnyhu, Project: tensorflow, Lines: 34, Source: values.py
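
The snippet above is the library side; the user-facing consequence is that assigning to a mirrored variable from tower code only works when the variable was created with an aggregation method. A hedged sketch of that pattern, assuming the same TensorFlow 1.x era API as above (the strategy construction, variable name, and tower_update helper are illustrative, not from values.py):

# Hedged sketch (assumed same-era API, not from the snippet above): the
# variable must be created with an aggregation method, otherwise the
# ValueError raised in _assign_func is triggered in tower context.
import tensorflow as tf
from tensorflow.python.ops import variable_scope

def tower_update():
  v = variable_scope.get_variable(
      "counter", [], initializer=tf.zeros_initializer(),
      aggregation=variable_scope.VariableAggregation.SUM)
  # assign_add in tower context goes through merge_call: the per-tower
  # updates are reduced (SUM here) and applied to every mirrored copy.
  return v.assign_add(tf.constant(1.0))

strategy = tf.contrib.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  update_op = strategy.call_for_each_tower(tower_update)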

Example 8: skip_summary

def skip_summary():
  # If using multiple towers in distributed strategy, skip summaries on all
  # towers except the first one (tower_id=0).
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last tower,
  # compute sum or mean across towers).
  tower_context = distribute.get_tower_context()
  return tower_context and tower_context.tower_id > 0
Developer: Huoxubeiyin, Project: tensorflow, Lines: 8, Source: summary_op_util.py
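
At the time of these snippets the standard summary ops consulted skip_summary so that, under a multi-tower distribution strategy, only tower 0 actually recorded summaries. The same guard can wrap any per-tower side effect; a hedged illustration (log_metric is an assumed helper name, not TensorFlow API, and it reuses skip_summary from the snippet above):

# Hedged illustration: guard a per-tower side effect so that only tower 0
# performs it. `log_metric` is an assumed helper, not part of TensorFlow.
import tensorflow as tf

def log_metric(name, value):
  if skip_summary():   # skip_summary() as defined in Example 8 above
    return             # towers 1..N-1 skip the side effect entirely
  tf.logging.info("%s = %s", name, value)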

Example 9: run_fn

 def run_fn():
   tower_context = distribute.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
Developer: Huoxubeiyin, Project: tensorflow, Lines: 8, Source: distribute_test.py

Example 10: testScope

 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribute.get_tower_context())
     self.assertIs(dist, distribute.get_cross_tower_context())
     self.assertTrue(distribute.has_distribution_strategy())
     self.assertIs(dist, distribute.get_distribution_strategy())
     self.assertEqual("baz", variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Developer: Huoxubeiyin, Project: tensorflow, Lines: 10, Source: distribute_test.py

Example 11: run_fn

 def run_fn():
   tower_context = distribute.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Developer: LongJun123456, Project: tensorflow, Lines: 12, Source: distribute_test.py

Example 12: set_last_step_output

  def set_last_step_output(self, name, output,
                           aggregation=variables_lib.VariableAggregation.NONE):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      aggregation: Aggregation method to use to aggregate outputs from multiple
        towers. Required if `set_last_step_output` is called in a tower context.
        Optional in cross_tower_context.
        When present, the outputs from all the towers are aggregated using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For e.g. if using MirroredStrategy and aggregation is set, output
        must be a `PerDevice` value.
        The aggregation method is also recorded in a dictionary
        `_last_step_outputs_aggregations` for later interpreting of the
        outputs as already reduced or not.

    """
    if distribute_lib.get_cross_tower_context():
      self._last_step_outputs_aggregations[name] = aggregation
      if aggregation is variables_lib.VariableAggregation.NONE:
        self._last_step_outputs[name] = output
      else:
        distribution = distribute_lib.get_distribution_strategy()
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, output, destinations="/device:CPU:0")
    else:
      assert aggregation is not variables_lib.VariableAggregation.NONE
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, value, destinations="/device:CPU:0")
        # Setting this inside the `merge_fn` because all towers share the same
        # context object, so it's more robust to set it only once (even if all
        # the towers are trying to set the same value).
        self._last_step_outputs_aggregations[name] = aggregation
      distribute_lib.get_tower_context().merge_call(merge_fn, output)
Developer: sonnyhu, Project: tensorflow, Lines: 40, Source: values.py
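
The docstring above is easiest to read from the perspective of the step function a DistributionStrategy runs repeatedly over a dataset. A hedged sketch of that calling pattern, assuming the run_steps_on_dataset driver API of the same era; the driver method name, build_step_ops helper, `distribution`, and `iterator` are assumptions for illustration, while `ctx` is the MultiStepContext and variables_lib matches the import used in values.py above:

# Hedged sketch, not from values.py. Names marked "assumed" below are
# illustrative, not taken from the source above.
def step_fn(ctx, inputs):                       # assumed step-function shape
  loss, train_op = build_step_ops(inputs)       # assumed user helper
  # Called in tower context, so an aggregation method is required; the
  # per-tower losses are reduced with MEAN via merge_call, as shown above.
  ctx.set_last_step_output(
      "loss", loss, aggregation=variables_lib.VariableAggregation.MEAN)
  return train_op

# Assumed driver call of the same era: run step_fn `iterations` times, then
# read the aggregated value recorded for the final step.
ctx = distribution.run_steps_on_dataset(step_fn, iterator, iterations=100)
mean_loss = ctx.last_step_outputs["loss"]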

Example 13: testScope

 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribute.get_tower_context())
     self.assertIs(dist, distribute.get_cross_tower_context())
     self.assertTrue(distribute.has_distribution_strategy())
     self.assertIs(dist, distribute.get_distribution_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Developer: LongJun123456, Project: tensorflow, Lines: 14, Source: distribute_test.py

Example 14: testMergeCall

  def testMergeCall(self):
    _assert_in_default_state(self)

    def merge_fn(dist, s):
      self.assertIs(distribute._default_distribution_strategy, dist)
      self.assertIs(None, distribute.get_tower_context())
      self.assertIs(dist, distribute.get_cross_tower_context())
      self.assertIs(dist, distribute.get_distribution_strategy())
      self.assertFalse(distribute.has_distribution_strategy())
      return "foo_" + s

    tower_ctx = distribute.get_tower_context()
    self.assertIs(distribute._default_tower_context, tower_ctx)
    self.assertEqual("foo_bar", tower_ctx.merge_call(merge_fn, "bar"))
    _assert_in_default_state(self)
Developer: LongJun123456, Project: tensorflow, Lines: 15, Source: distribute_test.py

Example 15: get

 def get(self, device=None):
   """Returns the value for the current device or raises a ValueError."""
   if device is None:
     tower_context = distribute_lib.get_tower_context()
     if tower_context:
       device = tower_context.device
     else:
       device = distribute_lib.get_update_device()
       if device is None:
         device = device_util.current()
   device = device_util.canonicalize(device)
   try:
     return self._index[device]
   except KeyError:
     raise ValueError("Device %s not found in %s (current device %s)" %
                      (device, self._index.keys(), device_util.current()))
Developer: bikong2, Project: tensorflow, Lines: 16, Source: values.py


Note: The tensorflow.python.training.distribute.get_tower_context examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by many developers; copyright in the source code remains with the original authors, and any distribution or use should follow the corresponding project's license. Please do not reproduce this article without permission.