This article collects typical usage examples of the Python function tensorflow.python.distribute.distribute_coordinator_context.get_current_worker_context. If you are unsure what get_current_worker_context does, how to call it, or want to see it used in context, the curated code samples below may help.
The 15 code examples of get_current_worker_context shown below are sorted by popularity by default.
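Before diving into the examples, a quick orientation: get_current_worker_context() only returns a worker context object when called inside a worker_fn that the distribute coordinator is driving; otherwise it returns None, so callers typically guard on the return value. A minimal sketch (the helper name maybe_get_master_target is illustrative, not part of TensorFlow):

from tensorflow.python.distribute import distribute_coordinator_context as dc_context


def maybe_get_master_target():
  # Illustrative helper: return the session master target from the current
  # worker context, or None when not running inside a coordinator worker_fn.
  worker_context = dc_context.get_current_worker_context()
  if worker_context is None:
    return None
  return worker_context.master_target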
Example 1: _worker_fn
def _worker_fn(strategy):
  """Function for worker task."""
  local_estimator = copy.deepcopy(estimator)
  # pylint: disable=protected-access
  local_estimator._config._train_distribute = strategy
  context = dc_context.get_current_worker_context()
  _init_run_config_from_worker_context(local_estimator._config, context)
  logging.info('Updated config: %s', str(vars(local_estimator._config)))
  local_estimator._train_distribution = strategy
  # pylint: enable=protected-access

  # In the standalone client, we don't need to run hooks on all threads
  # because logging hooks on all threads may be too much on the screen; also
  # tensor passed to one hook can only be fetched with the graph where the
  # tensor is defined. Other hooks such as checkpointing hooks will be added
  # by MonitoredTrainingSession.
  # TODO(yuefengz): Is there a hook that does need to run on all threads in
  # standalone client mode?
  if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
      dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
    hooks = list(train_spec.hooks)
  else:
    hooks = []

  # Prevent estimator.train from calling distribute coordinator again. This
  # function calls estimator.train which will use distribute coordinator path
  # again if `_distribute_coordinator_mode` is set.
  local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

  local_estimator.train(
      input_fn=train_spec.input_fn,
      max_steps=train_spec.max_steps,
      hooks=hooks)
Example 2: configure_and_create_session
def configure_and_create_session(distribution_strategy):
  """Configure session config and create a session with it."""
  # TODO(priyag): Throw error if a session already exists.
  session_config = K.get_default_session_config()

  if is_tpu_strategy(distribution_strategy):
    # TODO(priyag, yuefengz): Remove this workaround when Distribute
    # Coordinator is integrated with keras and we can create a session from
    # there.
    distribution_strategy.configure(session_config)
    master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
    session = session_module.Session(config=session_config, target=master)
  else:
    worker_context = dc_context.get_current_worker_context()
    if worker_context:
      dc_session_config = worker_context.session_config
      # Merge the default session config to the one from distribute
      # coordinator, which is fine for now since they don't have conflicting
      # configurations.
      dc_session_config.MergeFrom(session_config)
      session = session_module.Session(
          config=dc_session_config, target=worker_context.master_target)
    else:
      session = session_module.Session(config=session_config)

  K.set_session(session)
Example 3: filter_distributed_callbacks
def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """
  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  worker_context = dc_context.get_current_worker_context()
  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to.
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')
  # TODO(rchao): Add similar warning for restoring callback (to be designed).

  if callbacks_list is None or worker_context.is_chief:
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access
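A hedged usage sketch for the filter above (the call site is an assumption, not taken from this page): the filter is only meaningful once Keras is running in multi-worker mode, and non-chief workers end up with the chief-only callbacks removed.

# Hypothetical call site (assumption): apply the filter only when Keras is
# in multi-worker mode, right before the training loop starts.
if K.in_multi_worker_mode():
  callbacks_list = filter_distributed_callbacks(callbacks_list)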
Example 4: _between_graph_with_monitored_session
def _between_graph_with_monitored_session(self, strategy):
  context = distribute_coordinator_context.get_current_worker_context()
  self.assertTrue(context is not None)
  with ops.device("/job:ps/task:0"):
    # TODO(yuefengz): investigate why not using resource variable will make
    # the test flaky.
    x = variable_scope.get_variable("xx", initializer=10.0, use_resource=True)
  with ops.device("/job:ps/task:1"):
    y = variable_scope.get_variable("yy", initializer=20.0, use_resource=True)
  x_add = x.assign_add(2.0)
  y_sub = y.assign_sub(2.0)
  train_op = control_flow_ops.group([x_add, y_sub])

  # The monitored session will run init or ready ops.
  with monitored_session.MonitoredSession() as sess:
    sess.run(train_op)

    # Synchronize workers after one step to make sure they all have finished
    # training.
    if context.has_barrier:
      context.wait_for_other_workers()
    else:
      self._barrier.wait()

    x_val, y_val = sess.run([x, y])
    self.assertEqual(x_val, 16.0)
    self.assertEqual(y_val, 14.0)
    if x_val == 16.0 and y_val == 14.0:
      with self._lock:
        self._result_correct += 1
Example 5: __enter__
def __enter__(self):
  old_context = distribute_coordinator_context.get_current_worker_context()
  if old_context:
    raise ValueError(
        "You cannot run distribute coordinator in a `worker_fn`.\t" +
        self._debug_message())
  # pylint: disable=protected-access
  distribute_coordinator_context._worker_context.current = self
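Example 5 only shows __enter__. As a rough sketch of the matching __exit__ (an assumption, not shown on this page), the context manager would clear the same thread-local slot so that get_current_worker_context() returns None again outside the worker_fn:

def __exit__(self, unused_exception_type, unused_exception_value,
             unused_traceback):
  # Assumed counterpart to __enter__ above: reset the thread-local slot so
  # get_current_worker_context() no longer returns this context.
  # pylint: disable=protected-access
  distribute_coordinator_context._worker_context.current = None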
Example 6: _worker_fn
def _worker_fn(self, strategy):
  worker_context = distribute_coordinator_context.get_current_worker_context()
  session_config = worker_context._session_config
  self._device_filters.extend(session_config.device_filters)
  self._intra_op_parallelism_threads = (
      session_config.intra_op_parallelism_threads)
  self._inter_op_parallelism_threads = (
      session_config.inter_op_parallelism_threads)
  return MockServer()
Example 7: init_restore_or_wait_for_variables
def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  worker_context = dc_context.get_current_worker_context()
  if not worker_context or worker_context.experimental_should_init:
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)
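Example 7 calls _wait_for_variable_initialization, which is not shown on this page. A rough sketch of what such a helper could look like, mirroring the report_uninitialized_variables() polling loop used in Example 11 (the name _poll_until_initialized and the poll interval are assumptions):

import time

from tensorflow.python.ops import variables


def _poll_until_initialized(session, poll_seconds=0.5):
  # Hypothetical helper: block until the chief worker has initialized all
  # variables, following the same polling pattern as Example 11.
  while True:
    uninit_vars = session.run(variables.report_uninitialized_variables())
    if len(uninit_vars) == 0:  # pylint: disable=g-explicit-length-test
      return
    time.sleep(poll_seconds)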
Example 8: _eval_fn
def _eval_fn(strategy):
  """Function for evaluator task."""
  local_estimator = copy.deepcopy(estimator)
  # pylint: disable=protected-access
  local_estimator._config._eval_distribute = strategy
  _init_run_config_from_worker_context(
      local_estimator._config, dc_context.get_current_worker_context())
  local_estimator._eval_distribution = strategy

  executor = executor_cls(local_estimator, train_spec, eval_spec)
  executor._start_continuous_evaluation()
Example 9: _worker_fn
def _worker_fn(strategy):
  """Function for worker task."""
  local_estimator = copy.deepcopy(estimator)
  # pylint: disable=protected-access
  local_estimator._config._train_distribute = strategy
  _init_run_config_from_worker_context(
      local_estimator._config, dc_context.get_current_worker_context())
  local_estimator._train_distribution = strategy
  # pylint: enable=protected-access

  local_estimator.train(
      input_fn=train_spec.input_fn,
      max_steps=train_spec.max_steps,
      hooks=list(train_spec.hooks))
Example 10: worker_fn
def worker_fn(strategy):
  with ops.Graph().as_default():
    batch_size = 64
    steps = 2

    with strategy.scope():
      train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
      model = _clone_and_build_model(orig_model, strategy)
      orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

      # Workaround for the metrics issue (b/122928955) in async training.
      # This can only be used in standalone client mode.
      dc_context.get_current_worker_context().wait_for_other_workers()

      model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

      dc_context.get_current_worker_context().wait_for_other_workers()

      trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)

    test_obj.assertLessEqual(trained_loss, orig_loss)
    test_obj.assertGreaterEqual(trained_acc, orig_acc)
Example 11: _between_graph_worker_fn
def _between_graph_worker_fn(self, strategy):
  context = distribute_coordinator_context.get_current_worker_context()
  self.assertTrue(context is not None)
  with self._test_session(target=context.master_target) as sess:
    with ops.device("/job:ps/task:0"):
      # TODO(yuefengz): investigate why not using resource variable will make
      # the test flaky.
      x = variable_scope.get_variable(
          "x", initializer=10.0, use_resource=True)

    with ops.device("/job:ps/task:1"):
      y = variable_scope.get_variable(
          "y", initializer=20.0, use_resource=True)

    x_add = x.assign_add(2.0)
    y_sub = y.assign_sub(2.0)
    train_op = control_flow_ops.group([x_add, y_sub])

    if context.is_chief:
      self.evaluate(variables.global_variables_initializer())

    # Synchronize workers after initialization.
    if context.has_barrier:
      context.wait_for_other_workers()
    else:
      while True:
        uninit_vars = sess.run(variables.report_uninitialized_variables())
        # pylint: disable=g-explicit-length-test
        if len(uninit_vars) == 0:
          break

    sess.run(train_op)

    # Synchronize workers after one step to make sure they all have finished
    # training.
    if context.has_barrier:
      context.wait_for_other_workers()
    else:
      self._barrier.wait()

    x_val, y_val = sess.run([x, y])
    self.assertEqual(x_val, 16.0)
    self.assertEqual(y_val, 14.0)
    if x_val == 16.0 and y_val == 14.0:
      with self._lock:
        self._result_correct += 1
Example 12: _eval_fn
def _eval_fn(strategy):
  """Function for evaluator task."""
  local_estimator = copy.deepcopy(estimator)
  # pylint: disable=protected-access
  local_estimator._config._eval_distribute = strategy
  _init_run_config_from_worker_context(
      local_estimator._config, dc_context.get_current_worker_context())
  logging.info('Updated config: %s', str(vars(local_estimator._config)))
  local_estimator._eval_distribution = strategy

  # Prevent estimator.evaluate from calling distribute coordinator again. This
  # function calls estimator.evaluate which will use distribute coordinator
  # path again if `_distribute_coordinator_mode` is set.
  local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

  executor = executor_cls(local_estimator, train_spec, eval_spec)
  executor._start_continuous_evaluation()
Example 13: _dump_strategy_property
def _dump_strategy_property(self, strategy):
  context = distribute_coordinator_context.get_current_worker_context()
  self.assertTrue(context is not None)

  self.assertEqual(context._strategy.should_init, strategy.should_init)
  self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
  self.assertEqual(context.should_save_summary, strategy.should_save_summary)

  task_type = str(context.task_type)
  task_id = context.task_id or 0
  with self._lock:
    if task_type not in self._strategy_property:
      self._strategy_property[task_type] = []
    while len(self._strategy_property[task_type]) <= task_id:
      self._strategy_property[task_type].append(None)
    self._strategy_property[task_type][task_id] = (
        context._strategy.should_init, context.should_checkpoint,
        context.should_save_summary)
Example 14: _in_graph_worker_fn
def _in_graph_worker_fn(self, strategy):
  context = distribute_coordinator_context.get_current_worker_context()
  self.assertTrue(context is not None)
  with self._test_session(target=context.master_target) as sess:
    xs = []
    expected = 0.0
    for i in range(context.num_workers):
      with ops.device("/job:worker/task:%d" % i):
        x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
        x_add = x.assign_add(float(i))
        xs.append(x_add)
        expected += i + 10.0

    with ops.device("/job:worker/task:0"):
      result = math_ops.add_n(xs)

    self.evaluate(variables.global_variables_initializer())
    result_value = sess.run(result)

  self.assertEqual(result_value, expected)
  if result_value == expected:
    self._result_correct += 1
Example 15: _dump_worker_context
def _dump_worker_context(self, strategy):
  """Dumps the properties of each worker context.

  It dumps the context properties to a dict mapping from task_type to a list
  of tuples of master_target, num_workers, is_chief and distributed_mode,
  where the list is indexed by the task_id.

  Args:
    strategy: a `DistributionStrategy` object.
  """
  context = distribute_coordinator_context.get_current_worker_context()
  self.assertTrue(context is not None)

  task_type = str(context.task_type)
  task_id = context.task_id or 0
  with self._lock:
    if task_type not in self._worker_context:
      self._worker_context[task_type] = []
    while len(self._worker_context[task_type]) <= task_id:
      self._worker_context[task_type].append(None)
    self._worker_context[task_type][task_id] = (context.master_target,
                                                context.num_workers,
                                                context.is_chief,
                                                context.distributed_mode)