本文整理汇总了Python中tensorflow.contrib.distribute.python.values.select_device函数的典型用法代码示例。如果您正苦于以下问题：Python select_device函数的具体用法？Python select_device怎么用？Python select_device使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了select_device函数的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: testWrapClass
def testWrapClass(self):
    """Checks regroup() with `values.Mirrored` as the wrapper class.

    Two per-device copies of `_nested_value(...)` are merged. Every leaf of
    the merged structure must be a per-device value wrapped as
    `values.Mirrored`. The merge is then undone with both `select_device()`
    and `select_device_mirrored()`. Each must return the original per-device
    structure.
    """
    first_dev = _device_str(0)
    second_dev = _device_str(1)
    # A real mirrored value would be identical on every device; giving each
    # device a distinct value keeps them distinguishable in this test.
    merged = values.regroup(
        {first_dev: _nested_value("1"), second_dev: _nested_value("2")},
        values.Mirrored)

    # Top level: a 3-tuple whose first and last entries are leaves.
    self.assertIsInstance(merged, tuple)
    self.assertEqual(3, len(merged))
    self._is_per_device(merged[0], ["a1", "a2"], values.Mirrored)
    self._is_per_device(merged[2], ["h1", "h2"], values.Mirrored)

    # Middle entry: a 3-element list whose first and last entries are leaves.
    middle = merged[1]
    self.assertIsInstance(middle, list)
    self.assertEqual(3, len(middle))
    self._is_per_device(middle[0], ["b1", "b2"], values.Mirrored)
    self._is_per_device(middle[2], ["g1", "g2"], values.Mirrored)

    # Innermost: a dict keyed by "c" and "e".
    inner = middle[1]
    self.assertIsInstance(inner, dict)
    self.assertEqual(set(["c", "e"]), set(inner.keys()))
    self._is_per_device(inner["c"], ["d1", "d2"], values.Mirrored)
    self._is_per_device(inner["e"], ["f1", "f2"], values.Mirrored)

    expected_by_device = ((first_dev, "1"), (second_dev, "2"))
    # select_device() undoes the merge ...
    for device, suffix in expected_by_device:
        self.assertEqual(_nested_value(suffix),
                         values.select_device(device, merged))
    # ... and because every leaf is Mirrored, select_device_mirrored() is
    # allowed and must give the same result.
    for device, suffix in expected_by_device:
        self.assertEqual(_nested_value(suffix),
                         values.select_device_mirrored(device, merged))
示例2: testNested
def testNested(self):
    """Checks regroup() on a nested structure without a wrapper class.

    Two per-device copies of `_nested_value(...)` are merged into a single
    nested structure of per-device leaves. `select_device()` must recover
    each device's original structure. `select_device_mirrored()` must raise
    TypeError, because none of the leaves are mirrored.
    """
    first_dev = _device_str(0)
    second_dev = _device_str(1)
    merged = values.regroup(
        {first_dev: _nested_value("1"), second_dev: _nested_value("2")})

    # Top level: a 3-tuple whose first and last entries are leaves.
    self.assertIsInstance(merged, tuple)
    self.assertEqual(3, len(merged))
    self._is_per_device(merged[0], ["a1", "a2"])
    self._is_per_device(merged[2], ["h1", "h2"])

    # Middle entry: a 3-element list whose first and last entries are leaves.
    middle = merged[1]
    self.assertIsInstance(middle, list)
    self.assertEqual(3, len(middle))
    self._is_per_device(middle[0], ["b1", "b2"])
    self._is_per_device(middle[2], ["g1", "g2"])

    # Innermost: a dict keyed by "c" and "e".
    inner = middle[1]
    self.assertIsInstance(inner, dict)
    self.assertEqual(set(["c", "e"]), set(inner.keys()))
    self._is_per_device(inner["c"], ["d1", "d2"])
    self._is_per_device(inner["e"], ["f1", "f2"])

    # select_device() undoes the merge.
    for device, suffix in ((first_dev, "1"), (second_dev, "2")):
        self.assertEqual(_nested_value(suffix),
                         values.select_device(device, merged))
    # select_device_mirrored() rejects non-mirrored per-device values.
    for device in (first_dev, second_dev):
        with self.assertRaises(TypeError):
            values.select_device_mirrored(device, merged)
示例3: _test_iterator
def _test_iterator(self, input_fn, worker_device_pairs, expected_values,
                   sess=None):
    """Checks an InputFunctionIterator over two full passes of its input.

    Args:
      input_fn: input function handed to `values.InputFunctionIterator`.
      worker_device_pairs: sequence of `(worker, devices)` pairs. The device
        entries are flattened into one ordered list of devices to read from.
      expected_values: for each step, the list of values expected from the
        flattened devices, in order.
      sess: optional session. When truthy, tensors are run with `sess.run`;
        otherwise they are run with `self.evaluate`.
    """
    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    iterator = values.InputFunctionIterator(input_fn, worker_device_pairs)

    def run(tensors):
        if sess:
            return sess.run(tensors)
        return self.evaluate(tensors)

    def split_per_device(element):
        return [values.select_device(d, element) for d in devices]

    def check_one_pass():
        for expected in expected_values:
            step_values = run(split_per_device(iterator.get_next()))
            self.assertEqual(expected, step_values)

    run(iterator.initialize())
    check_one_pass()
    # Reading one step past the end must raise OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
        run(split_per_device(iterator.get_next()))

    # After re-initializing the iterator, it can be iterated again.
    run(iterator.initialize())
    check_one_pass()
示例4: _test_iterator
def _test_iterator(self, iterator, devices, expected_values):
    """Checks device placement and values produced by a per-device iterator.

    Args:
      iterator: iterator whose single `get_next()` result is split per device
        with `values.select_device` and evaluated once per expected step.
      devices: device strings to read from, in the order used by
        `expected_values`.
      expected_values: for each step, the list of values expected from
        `devices`, in order.
    """
    next_element = iterator.get_next()

    for device in devices:
        per_device_value = values.select_device(device, next_element)
        # The per-device value can be a tuple, so check every flattened leaf.
        for element in nest.flatten(per_device_value):
            # assertIn reports both strings on failure; the previous
            # assertTrue(a in b) only reported "False is not true".
            self.assertIn(element.device, device)

    for expected_value in expected_values:
        actual = self.evaluate(
            [values.select_device(d, next_element) for d in devices])
        self.assertEqual(expected_value, actual)

    # Once all expected steps are consumed, the next step must be exhausted.
    with self.assertRaises(errors.OutOfRangeError):
        self.evaluate([values.select_device(d, next_element) for d in devices])
示例5: _test_iterator_no_prefetch
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
    """Checks a PerDeviceDataset built without on-device prefetching.

    Args:
      devices: device strings the dataset is split across, in the order
        used by `expected_values`.
      dataset: source dataset passed to `values.PerDeviceDataset`.
      expected_values: for each step, the list of values expected from
        `devices`, in order.
    """
    per_device_dataset = values.PerDeviceDataset(
        dataset, devices, prefetch_on_device=False)
    iterator = per_device_dataset.make_one_shot_iterator()

    def evaluate_next_step():
        element = iterator.get_next()
        return self.evaluate(
            [values.select_device(d, element) for d in devices])

    for expected in expected_values:
        self.assertEqual(expected, evaluate_next_step())

    # The iterator must be exhausted once every expected step is consumed.
    with self.assertRaises(errors.OutOfRangeError):
        evaluate_next_step()
示例6: testNamedTupleEstimatorSpec
def testNamedTupleEstimatorSpec(self):
    """Checks that regroup()/select_device() round-trip an EstimatorSpec.

    One EstimatorSpec is built for each of three devices in graph mode.
    regroup() must return a single EstimatorSpec. Its `mode` must be the
    shared TRAIN mode, and its `loss`, `train_op` and `scaffold` fields must
    expose each device's original value through `.get(device)`. Finally,
    select_device() must recover every original spec.
    """
    with context.graph_mode(), ops.Graph().as_default():
        created_estimator_specs = {}
        to_regroup = {}
        for device_id in range(3):
            spec = model_fn_lib.EstimatorSpec(
                mode=model_fn_lib.ModeKeys.TRAIN,
                loss=constant_op.constant(device_id / 2),
                train_op=array_ops.identity(constant_op.constant(device_id)))
            created_estimator_specs[device_id] = spec
            to_regroup[_device_str(device_id)] = spec

        merged_estimator_spec = values.regroup(to_regroup)

        # Use assertIsInstance/assertEqual: assertTrue(isinstance(...)) gives
        # no useful failure message, and the assertEquals alias is deprecated
        # and was removed from unittest in Python 3.12.
        self.assertIsInstance(merged_estimator_spec,
                              model_fn_lib.EstimatorSpec)
        self.assertEqual(model_fn_lib.ModeKeys.TRAIN,
                         merged_estimator_spec.mode)
        for device_id in range(3):
            d = _device_str(device_id)
            created = created_estimator_specs[device_id]
            self.assertEqual(created.loss, merged_estimator_spec.loss.get(d))
            self.assertEqual(created.train_op,
                             merged_estimator_spec.train_op.get(d))
            # Scaffold is populated by `EstimatorSpec.__new__`.
            self.assertEqual(created.scaffold,
                             merged_estimator_spec.scaffold.get(d))
            # Also test that we can undo the merge using select_device().
            self.assertEqual(created,
                             values.select_device(d, merged_estimator_spec))
示例7: _test_iterator_with_prefetch
def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
    """Checks a PerDeviceDataset that prefetches onto each device.

    This check only runs in graph mode. When executing eagerly it returns
    without doing anything.

    Args:
      devices: device strings the dataset is split across, in the order
        used by `expected_values`.
      dataset: source dataset passed to `values.PerDeviceDataset`.
      expected_values: for each step, the list of values expected from
        `devices`, in order.
    """
    if context.executing_eagerly():
        return

    per_device_dataset = values.PerDeviceDataset(
        dataset, devices, prefetch_on_device=True)
    iterator = per_device_dataset.make_initializable_iterator()
    self.evaluate([iterator.initializer])

    def evaluate_next_step():
        element = iterator.get_next()
        return self.evaluate(
            [values.select_device(d, element) for d in devices])

    for expected in expected_values:
        self.assertEqual(expected, evaluate_next_step())

    # The iterator must be exhausted once every expected step is consumed.
    with self.assertRaises(errors.OutOfRangeError):
        evaluate_next_step()
示例8: testSameId
def testSameId(self):
    """Checks regroup() on an object shared by identity across devices.

    Each device contributes a tuple `(per-device string, shared object)`.
    regroup() must merge the strings into a per-device value and keep the
    shared object itself. select_device() must then rebuild each device's
    tuple, still holding the very same shared object.
    """
    shared = object()
    merged = values.regroup({_device_str(0): ("a", shared),
                             _device_str(1): ("b", shared)})
    self.assertIsInstance(merged, tuple)
    self.assertEqual(2, len(merged))
    self._is_per_device(merged[0], ["a", "b"])
    self.assertIs(shared, merged[1])

    # select_device() should undo the merge done by regroup().
    for device_index, expected_label in enumerate(("a", "b")):
        selected = values.select_device(_device_str(device_index), merged)
        self.assertIsInstance(selected, tuple)
        self.assertEqual(2, len(selected))
        self.assertEqual(expected_label, selected[0])
        self.assertIs(shared, selected[1])
示例9: _call_and_check
def _call_and_check(self, model_fn, inputs, expected_result, defuns,
                    two_variables=False):
    """Runs `model_fn` under a CPU+GPU MirroredStrategy and checks results.

    Args:
      model_fn: function passed to `call_for_each_tower` together with a
        `MockModel` and the elements of `inputs`.
      inputs: sequence of extra positional arguments for `model_fn`.
      expected_result: expected result. It is split per device with
        `values.select_device` and compared with `assertAllClose`.
      defuns: objects with a `variables` attribute. As a set, each one's
        variables must equal the mock model's variables.
      two_variables: forwarded to `MockModel`.
    """
    # NOTE(review): the source lost its indentation; this assumes, as in
    # TensorFlow, that everything after `with dist.scope():` runs inside it.
    device_list = [device_util.canonicalize(name)
                   for name in ("CPU:0", "GPU:0")]
    strategy = mirrored_strategy.MirroredStrategy(device_list)
    with strategy.scope():
        model = MockModel(two_variables)
        self.evaluate(variables.global_variables_initializer())
        tower_result = strategy.call_for_each_tower(
            model_fn, model, *inputs, run_concurrently=False)

        for device in device_list:
            actual = values.select_device(device, tower_result)
            expected = values.select_device(device, expected_result)
            self.assertAllClose(expected, self.evaluate(actual))

        for defun in defuns:
            self.assertEqual(set(model.variables), set(defun.variables))
示例10: _test_iterator_with_prefetch
def _test_iterator_with_prefetch(self, devices, dataset, expected_values):
    """Checks a prefetching PerDeviceDataset independently of placement.

    With on-device prefetching there is no guarantee which input ends up on
    which device, so values cannot be compared step by step. Instead, all
    values read across all steps and devices are pooled and compared, as a
    multiset, against the pooled expected values.

    Each `get_next()` step reads exactly one value per device, so every
    device receives the same number of inputs by construction. This check
    only runs in graph mode; when executing eagerly it does nothing.

    Args:
      devices: device strings the dataset is split across.
      dataset: source dataset passed to `values.PerDeviceDataset`.
      expected_values: for each step, the values expected across `devices`
        (in any order across devices and steps).
    """
    if context.executing_eagerly():
        return

    per_device_dataset = values.PerDeviceDataset(
        dataset, devices, prefetch_on_device=True)
    iterator = per_device_dataset.make_one_shot_iterator()

    combined_actual = []
    combined_expected = []
    for expected_value in expected_values:
        next_element = iterator.get_next()
        combined_actual.extend(self.evaluate(
            [values.select_device(d, next_element) for d in devices]))
        combined_expected.extend(expected_value)
    # Compare as multisets (sorted lists). A set comparison would let a
    # duplicated value hide a dropped one.
    self.assertEqual(sorted(combined_expected), sorted(combined_actual))

    # The iterator must be exhausted once every expected step is consumed.
    with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        self.evaluate(
            [values.select_device(d, next_element) for d in devices])
示例11: testOneDevice
def testOneDevice(self):
    """Checks regroup()/select_device() when there is only one device.

    For a plain nested value both calls behave essentially as identity.
    The exception is a resource variable already wrapped by a
    `values.MirroredVariable`: regrouping its single-device index must
    return that very MirroredVariable.
    """
    single = values.regroup({_device_str(0): _nested_value("1")})
    # On one device regroup() and select_device() are basically identity.
    self.assertEqual(_nested_value("1"), single)
    self.assertEqual(_nested_value("1"),
                     values.select_device(_device_str(0), single))

    # The one exception has to do with MirroredVariables.
    cpu_device = "/device:CPU:0"
    with ops.device(cpu_device):
        var = variable_scope.get_variable(
            name="v", initializer=1., use_resource=True)
        index = {cpu_device: var}
    # The MirroredVariable is created (and kept bound) before regroup() is
    # called, then compared with regroup()'s result by identity.
    mirrored = values.MirroredVariable(index, var)
    self.assertIs(mirrored, values.regroup(index))
示例12: _call_for_each_tower
#.........这里部分代码省略.........
`True`.
Returns:
Merged return value of `fn` across all towers.
Raises:
RuntimeError: If fn() calls get_tower_context().merge_call() a different
number of times from the available devices.
"""
run_concurrently = kwargs.pop("run_concurrently", True)
if not context.executing_eagerly():
# Lots of TF library code isn't thread-safe in graph mode, and
# there is little to be gained by turning on multithreading when
# constructing a graph.
run_concurrently = False
# Needed for per-thread device, etc. contexts in graph mode.
ops.get_default_graph().switch_to_thread_local()
elif run_concurrently is None:
run_concurrently = True
coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
shared_variable_store = {}
# TODO(isaprykin): Create these threads once instead of during every run()
# call.
threads = []
for index, d in enumerate(distribution.worker_devices):
variable_creator_fn = shared_variable_creator.make_fn(
shared_variable_store, index)
t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access
distribution, coord, d, variable_creator_fn, fn,
*values.select_device(d, args), **values.select_device(d, kwargs))
threads.append(t)
for t in threads:
t.start()
# When `fn` starts `should_run` event is set on _MirroredTowerThread
# (`MTT`) threads. The execution waits until
# `MTT.has_paused` is set, which indicates that either `fn` is
# complete or a `get_tower_context().merge_call()` is called. If `fn` is
# complete, then `MTT.done` is set to True. Otherwise, arguments
# of `get_tower_context().merge_call` from all paused threads are grouped
# and the `merge_fn` is performed. Results of the
# `get_tower_context().merge_call` are then set to `MTT.merge_result`.
# Each such `get_tower_context().merge_call` call returns the
# `MTT.merge_result` for that thread when `MTT.should_run` event
# is reset again. Execution of `fn` resumes.
try:
with coord.stop_on_exception():
all_done = False
while not all_done and not coord.should_stop():
done = []
if run_concurrently:
for t in threads:
t.should_run.set()
for t in threads:
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
else:
for t in threads:
t.should_run.set()
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
if coord.should_stop():
return None
all_done = all(done)
if not all_done:
if any(done):
raise RuntimeError("Some towers made a different number of "
"tower_context().merge_call() calls.")
# get_tower_context().merge_call() case
merge_args = values.regroup({t.device: t.merge_args for t in threads})
merge_kwargs = values.regroup(
{t.device: t.merge_kwargs for t in threads})
# We capture the name_scope of the MTT when we call merge_fn
# to ensure that if we have opened a name scope in the MTT,
# it will be respected when executing the merge function. We only
# capture the name_scope from the first MTT and assume it is
# the same for all other MTTs.
mtt_captured_name_scope = threads[0].captured_name_scope
with ops.name_scope(mtt_captured_name_scope):
merge_result = threads[0].merge_fn(distribution, *merge_args,
**merge_kwargs)
for t in threads:
t.merge_result = values.select_device(t.device, merge_result)
finally:
for t in threads:
t.should_run.set()
coord.join(threads)
return values.regroup({t.device: t.main_result for t in threads})