本文整理汇总了Python中tensorflow.python.training.distribute.get_tower_context函数的典型用法代码示例。如果您正苦于以下问题：Python get_tower_context函数的具体用法是什么？get_tower_context怎么用？有哪些使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_tower_context函数的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: model_fn
def model_fn():
    """Create four variables whose names differ only by numeric suffixes.

    After the variables are created, an identity `merge_call` is issued on the
    current tower context; its result is discarded.

    Returns:
        A list of the four created variables, in creation order.
    """
    names = ("foo/bar", "foo_1/bar", "foo_1/bar_1", "foo/bar_1")
    created = [variable_scope.variable(1.0, name=n) for n in names]
    # The merge function is the identity and its result is unused; the call
    # is made only for its effect of entering cross-tower context.
    distribute_lib.get_tower_context().merge_call(lambda _: _)
    return created
示例2: set_non_tensor_output
def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output.

    In cross-tower context, `output` is stored as-is. In tower context, it is
    routed through `merge_call`, and the result of `distribution.unwrap` on
    it is stored instead. No aggregation is applied to non-tensor values.

    Args:
        name: Key under which the value is recorded in
            `self._non_tensor_outputs`.
        output: The non-tensor value to record.
    """
    def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply return all the values
        # in a list as aggregation doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = distribution.unwrap(value)

    if not distribute_lib.get_cross_tower_context():
        # Tower context: hop into cross-tower context to record the value.
        distribute_lib.get_tower_context().merge_call(merge_fn, output)
        return
    self._non_tensor_outputs[name] = output
示例3: model_fn
def model_fn(device_id):
    """Create one variable under a creator scope that tags it with `device_id`.

    Args:
        device_id: Integer identifier; the custom creator appends
            ":thread_<device_id>" to whatever the next creator returns.

    Returns:
        Whatever `variable_scope.variable(1.0)` yields under the custom
        creator scope.
    """
    assert isinstance(device_id, int)
    suffix = ":thread_" + str(device_id)

    def thread_creator_fn(next_creator, *args, **kwargs):
        # Delegate to the rest of the creator chain, then tag the result.
        return next_creator(*args, **kwargs) + suffix

    with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        created = variable_scope.variable(1.0)
        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
    return created
示例4: model_fn
def model_fn():
    """Create four variables, the last three inside the "common" scope.

    `var2` is created with ON_READ synchronization and SUM aggregation, and
    `var3` with ON_WRITE synchronization and MEAN aggregation. Both are
    created after an identity `merge_call`.

    Returns:
        Tuple `(v0, v1, v2, v3)` of the created variables.
    """
    sync = variable_scope.VariableSynchronization
    agg = variable_scope.VariableAggregation

    v0 = variable_scope.get_variable("var0", [1])
    with variable_scope.variable_scope("common"):
        v1 = variable_scope.get_variable("var1", [1])
        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
        v2 = variable_scope.get_variable(
            "var2", [1], synchronization=sync.ON_READ, aggregation=agg.SUM)
        v3 = variable_scope.get_variable(
            "var3", [1], synchronization=sync.ON_WRITE, aggregation=agg.MEAN)
    return v0, v1, v2, v3
示例5: _assert_in_default_state
def _assert_in_default_state(t):
    """Assert that no non-default distribution strategy is active.

    Checks that:
    - the tower context and distribution strategy are the module defaults;
    - there is no cross-tower context;
    - `has_distribution_strategy()` is False.

    Args:
        t: Test case providing `assertIs` / `assertFalse` (presumably a
            `unittest.TestCase` subclass).
    """
    # (expected object, getter whose result must be that exact object)
    identity_checks = (
        (distribute._default_tower_context, distribute.get_tower_context),
        (None, distribute.get_cross_tower_context),
        (distribute._default_distribution_strategy,
         distribute.get_distribution_strategy),
    )
    for expected, getter in identity_checks:
        t.assertIs(expected, getter())
    t.assertFalse(distribute.has_distribution_strategy())
示例6: merge_fn
def merge_fn(dist, s):
    """Check the context state seen inside a default-strategy `merge_call`.

    Args:
        dist: Strategy passed in by `merge_call`; expected to be the module's
            default distribution strategy.
        s: String argument forwarded by `merge_call`.

    Returns:
        `s` prefixed with "foo_".
    """
    self.assertIs(distribute._default_distribution_strategy, dist)
    # No tower context here; `dist` must be both the cross-tower context and
    # the current strategy.
    self.assertIs(None, distribute.get_tower_context())
    for getter in (distribute.get_cross_tower_context,
                   distribute.get_distribution_strategy):
        self.assertIs(dist, getter())
    # The default strategy does not count as having a strategy.
    self.assertFalse(distribute.has_distribution_strategy())
    return "foo_" + s
示例7: _assign_func
def _assign_func(self, *args, **kwargs):
    """Apply the assign-style function `f` to this mirrored variable.

    `f` is taken from `kwargs["f"]`. After it is removed, the remaining
    arguments are forwarded. Dispatch depends on the current context:

    * Tower context: the first positional argument is reduced with
      `self._aggregation` inside a `merge_call`. The strategy's `update` then
      applies `f` with the reduced value.
    * Cross-tower context with an update device set: `f` is called directly
      on this variable's component for that device.
    * Cross-tower context without an update device: the current strategy's
      `update` applies `f`.

    Raises:
        KeyError: If `f` is not supplied in `kwargs`.
        ValueError: In tower context, when the variable's aggregation is
            `VariableAggregation.NONE`.
    """
    f = kwargs.pop("f")

    if not distribute_lib.get_cross_tower_context():
        _assert_tower_context()
        # Tower context: reduce the value to assign/add/sub across towers,
        # then call `f` on each mirrored component with the reduced value.
        # More details about how the different use cases are handled can be
        # found in the _reduce method.
        if self._aggregation == vs.VariableAggregation.NONE:
            raise ValueError("You must specify an aggregation method to update a "
                             "MirroredVariable in Tower Context.")

        def merge_fn(strategy, value, *other_args, **other_kwargs):
            reduced = strategy.reduce(
                aggregation=self._aggregation, value=value, destinations=self)
            return strategy.update(self, f, reduced, *other_args, **other_kwargs)

        return distribute_lib.get_tower_context().merge_call(
            merge_fn, *args, **kwargs)

    update_device = distribute_lib.get_update_device()
    if update_device is None:
        # Plain cross-tower call: hand `f` to the strategy's `update`.
        return distribute_lib.get_distribution_strategy().update(
            self, f, *args, **kwargs)
    # An update device is set (presumably we are already inside a strategy
    # `update`): apply `f` only to that device's component.
    return f(self.get(device=update_device), *args, **kwargs)
示例8: skip_summary
def skip_summary():
    """Return whether summaries should be skipped in the current tower.

    If using multiple towers in distributed strategy, summaries are skipped on
    all towers except the first one (tower_id == 0).

    Returns:
        bool: True iff a tower context is available and its `tower_id` is
        greater than 0; False otherwise (never None).
    """
    # TODO(priyag): Add a new optional argument that will provide multiple
    # alternatives to override default behavior. (e.g. run on last tower,
    # compute sum or mean across towers).
    tower_context = distribute.get_tower_context()
    # Explicit None check so this predicate always returns a bool, rather than
    # passing `None` through as `tower_context and ...` would.
    return tower_context is not None and tower_context.tower_id > 0
示例9: run_fn
def run_fn():
    """Verify the tower-context state while running under `dist`.

    Expects a tower context with no cross-tower context, and `dist` as the
    active strategy. It also expects the test strategy's `merge_call` to
    return the `test_arg` keyword, and variable creation to return the
    variable name.
    """
    ctx = distribute.get_tower_context()
    self.assertIsNotNone(ctx)
    self.assertIsNone(distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(dist, distribute.get_distribution_strategy())
    merged = ctx.merge_call(None, test_arg="foo")
    self.assertEqual("foo", merged)
    created = variable_scope.variable(1.0, name="bar")
    self.assertEqual("bar", created)
示例10: testScope
def testScope(self):
    """Inside `scope()`, the strategy is active in cross-tower context.

    Leaving the scope restores the default state.
    """
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    with strategy.scope():
        self.assertIsNone(distribute.get_tower_context())
        self.assertIs(strategy, distribute.get_cross_tower_context())
        self.assertTrue(distribute.has_distribution_strategy())
        self.assertIs(strategy, distribute.get_distribution_strategy())
        created = variable_scope.variable(1.0, name="baz")
        self.assertEqual("baz", created)
    _assert_in_default_state(self)
示例11: run_fn
def run_fn():
    """Verify the tower-context state while running under `dist`.

    Also checks that variable creation through the test strategy yields the
    dict built by `_get_test_variable` for the name "bar", with AUTO
    synchronization and NONE aggregation.
    """
    ctx = distribute.get_tower_context()
    self.assertIsNotNone(ctx)
    self.assertIsNone(distribute.get_cross_tower_context())
    self.assertTrue(distribute.has_distribution_strategy())
    self.assertIs(dist, distribute.get_distribution_strategy())
    self.assertEqual("foo", ctx.merge_call(None, test_arg="foo"))
    expected = _get_test_variable(
        "bar", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    actual = variable_scope.variable(1.0, name="bar")
    self.assertDictEqual(expected, actual)
示例12: set_last_step_output
def set_last_step_output(self, name, output,
                         aggregation=variables_lib.VariableAggregation.NONE):
    """Set `output` with `name` to be outputted from the last step.

    Args:
        name: String, name to identify the output. Doesn't need to match tensor
            name.
        output: The tensors that should be outputted with `name`. See below for
            actual types supported.
        aggregation: Aggregation method to use to aggregate outputs from
            multiple towers. Required if `set_last_step_output` is called in a
            tower context. Optional in cross_tower_context.
            When present, the outputs from all the towers are aggregated using
            the current distribution strategy's `reduce` method. Hence, the
            type of `output` must be what's supported by the corresponding
            `reduce` method. For example, if using MirroredStrategy and
            aggregation is set, output must be a `PerDevice` value.
            The aggregation method is also recorded in a dictionary
            `_last_step_outputs_aggregations` for later interpreting of the
            outputs as already reduced or not.

    Raises:
        ValueError: If called in a tower context with `aggregation` left as
            `VariableAggregation.NONE`.
    """
    if distribute_lib.get_cross_tower_context():
        self._last_step_outputs_aggregations[name] = aggregation
        if aggregation is variables_lib.VariableAggregation.NONE:
            self._last_step_outputs[name] = output
        else:
            distribution = distribute_lib.get_distribution_strategy()
            self._last_step_outputs[name] = distribution.reduce(
                aggregation, output, destinations="/device:CPU:0")
        return

    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would let a NONE aggregation reach `reduce` below.
    if aggregation is variables_lib.VariableAggregation.NONE:
        raise ValueError(
            "An aggregation method other than VariableAggregation.NONE must be "
            "specified when set_last_step_output is called in a tower context "
            "(output name: %r)." % (name,))

    def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, value, destinations="/device:CPU:0")
        # Setting this inside the `merge_fn` because all towers share the same
        # context object, so it's more robust to set it only once (even if all
        # the towers are trying to set the same value).
        self._last_step_outputs_aggregations[name] = aggregation

    distribute_lib.get_tower_context().merge_call(merge_fn, output)
示例13: testScope
def testScope(self):
    """Inside `scope()`, the strategy is active in cross-tower context.

    Variables created there match the dict built by `_get_test_variable`.
    Leaving the scope restores the default state.
    """
    _assert_in_default_state(self)
    strategy = _TestStrategy()
    with strategy.scope():
        self.assertIsNone(distribute.get_tower_context())
        self.assertIs(strategy, distribute.get_cross_tower_context())
        self.assertTrue(distribute.has_distribution_strategy())
        self.assertIs(strategy, distribute.get_distribution_strategy())
        expected = _get_test_variable(
            "baz", variable_scope.VariableSynchronization.AUTO,
            variable_scope.VariableAggregation.NONE)
        actual = variable_scope.variable(1.0, name="baz")
        self.assertDictEqual(expected, actual)
    _assert_in_default_state(self)
示例14: testMergeCall
def testMergeCall(self):
    """`merge_call` on the default tower context enters cross-tower context.

    The merge function's return value is passed back to the caller.
    """
    _assert_in_default_state(self)

    def merge_fn(dist, s):
        # Inside merge_call: the default strategy is the cross-tower context,
        # and there is no tower context.
        self.assertIs(distribute._default_distribution_strategy, dist)
        self.assertIsNone(distribute.get_tower_context())
        self.assertIs(dist, distribute.get_cross_tower_context())
        self.assertIs(dist, distribute.get_distribution_strategy())
        self.assertFalse(distribute.has_distribution_strategy())
        return "foo_" + s

    default_ctx = distribute.get_tower_context()
    self.assertIs(distribute._default_tower_context, default_ctx)
    merged = default_ctx.merge_call(merge_fn, "bar")
    self.assertEqual("foo_bar", merged)
    _assert_in_default_state(self)
示例15: get
def get(self, device=None):
    """Returns the value for the current device or raises a ValueError.

    Args:
        device: Device whose value to return. If None, it is resolved as:
            - the current tower context's `device`, if there is a tower
              context;
            - otherwise the update device;
            - if that is also None, `device_util.current()`.

    Returns:
        The entry of `self._index` for the canonicalized device.

    Raises:
        ValueError: If `self._index` has no entry for the resolved device.
    """
    if device is None:
        tower_context = distribute_lib.get_tower_context()
        if not tower_context:
            update_device = distribute_lib.get_update_device()
            device = (device_util.current() if update_device is None
                      else update_device)
        else:
            device = tower_context.device
    key = device_util.canonicalize(device)
    try:
        return self._index[key]
    except KeyError:
        raise ValueError("Device %s not found in %s (current device %s)" %
                         (key, self._index.keys(), device_util.current()))