本文整理汇总了Python中tensorflow.contrib.learn.python.learn.estimators.run_config.RunConfig类的典型用法代码示例。如果您正苦于以下问题:Python run_config.RunConfig类的具体用法?Python run_config.RunConfig怎么用?Python run_config.RunConfig使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块tensorflow.contrib.learn.python.learn.estimators.run_config的用法示例。
在下文中一共展示了run_config.RunConfig类的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _get_default_schedule
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _get_default_schedule(config):
    """Pick the schedule to run when the caller did not name one.

    Args:
        config: A RunConfig instance, or a falsy value if none is available.

    Returns:
        'train_and_evaluate' when `config` is falsy, when
        `_is_distributed(config)` is falsy, or when the task type is MASTER;
        'run_std_server' for a PS task; 'train' for a WORKER task.

    Raises:
        ValueError: If a distributed config carries no task type, or carries a
            task type that has no default schedule.
    """
    if config and _is_distributed(config):
        task_type = config.task_type
        if not task_type:
            raise ValueError('Must specify a schedule')

        task_types = run_config_lib.TaskType
        # Checked in order; the first matching task type wins.
        schedule_by_task_type = (
            # TODO(rhaertel): handle the case where there is more than one
            # master or explicitly disallow such a case.
            (task_types.MASTER, 'train_and_evaluate'),
            (task_types.PS, 'run_std_server'),
            (task_types.WORKER, 'train'),
        )
        for known_type, schedule in schedule_by_task_type:
            if task_type == known_type:
                return schedule

        raise ValueError(
            'No default schedule for task type: %s' % (task_type))

    # No config, or a non-distributed one: do everything in a single process.
    return 'train_and_evaluate'
示例2: _get_replica_device_setter
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _get_replica_device_setter(config):
    """Creates a replica device setter if required.

    Args:
        config: A RunConfig instance. Its `task_type`, `task_id`,
            `num_ps_replicas` and `cluster_spec` attributes are read.

    Returns:
        A replica device setter, or None when `config.num_ps_replicas` is not
        greater than zero.
    """
    # The worker device string is built before the replica-count check so that
    # formatting happens on every call, whatever the replica count is.
    worker_device = (
        '/job:%s/task:%d' % (config.task_type, config.task_id)
        if config.task_type else '/job:worker')

    if not config.num_ps_replicas > 0:
        return None

    # Op type names forwarded to the device setter as `ps_ops`.
    ps_op_types = [
        'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
        'MutableHashTableOfTensors', 'MutableDenseHashTable'
    ]
    return device_setter.replica_device_setter(
        ps_tasks=config.num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        ps_ops=ps_op_types,
        cluster=config.cluster_spec)
示例3: _get_replica_device_setter
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _get_replica_device_setter(config):
    """Creates a replica device setter if required.

    Args:
        config: A RunConfig instance. Its `job_name`, `task`,
            `num_ps_replicas` and `cluster_spec` attributes are read.

    Returns:
        A replica device setter, or None when `config.num_ps_replicas` is not
        greater than zero.
    """
    # Resolve the worker device first; a config without a job name falls back
    # to the generic worker job.
    job_name = config.job_name
    if job_name:
        worker_device = '/job:%s/task:%d' % (job_name, config.task)
    else:
        worker_device = '/job:worker'

    replica_count = config.num_ps_replicas
    if not replica_count > 0:
        return None

    # Op type names forwarded to the device setter as `ps_ops`.
    ps_op_types = [
        'Variable', 'AutoReloadVariable', 'MutableHashTable',
        'MutableHashTableOfTensors', 'MutableDenseHashTable'
    ]
    return device_setter.replica_device_setter(
        ps_tasks=replica_count,
        worker_device=worker_device,
        merge_devices=False,
        ps_ops=ps_op_types,
        cluster=config.cluster_spec)
示例4: config
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def config(self):
    """Return a deep copy of the stored `_config`.

    A copy is handed out so that changes made by the caller do not reach the
    configuration held on this object.
    """
    # TODO(wicke): make RunConfig immutable, and then return it without a copy.
    config_copy = copy.deepcopy(self._config)
    return config_copy
示例5: __init__
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def __init__(self, model_dir=None, config=None):
    """Initializes a BaseEstimator instance.

    Args:
        model_dir: Directory to save model parameters, graph and etc. This can
            also be used to load checkpoints from the directory into a
            estimator to continue training a previously saved model. When
            None, a fresh temporary directory is created and used.
        config: A RunConfig instance. When None, a default
            `BaseEstimator._Config()` is created.
    """
    # Model directory: fall back to a temporary folder when none was given.
    if model_dir is None:
        model_dir = tempfile.mkdtemp()
        self._model_dir = model_dir
        logging.warning('Using temporary folder as model directory: %s',
                        self._model_dir)
    else:
        self._model_dir = model_dir

    # Run configuration: use the caller's, or build the default one.
    if config is not None:
        self._config = config
    else:
        self._config = BaseEstimator._Config()
        logging.info('Using default config.')
    logging.info('Using config: %s', str(vars(self._config)))

    # Set device function depending if there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = None
    self._labels_info = None

    self._graph = None
示例6: _create_experiment_fn
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _create_experiment_fn(output_dir):  # pylint: disable=unused-argument
    """Build the Experiment for the census wide-and-deep classifier.

    Args:
        output_dir: Unused; present to satisfy the experiment-function
            signature.

    Returns:
        A `tf.contrib.learn.Experiment` wrapping a
        `DNNLinearCombinedClassifier`, with input functions taken from the
        census data source and step counts taken from FLAGS.
    """
    model_columns = census_model_config()
    (columns, label_column, wide_columns, deep_columns, categorical_columns,
     continuous_columns) = model_columns

    data_source = CensusDataSource(
        FLAGS.data_dir,
        TRAIN_DATA_URL, TEST_DATA_URL,
        columns, label_column,
        categorical_columns,
        continuous_columns)

    # Cluster settings come from command-line flags.
    cluster_config = run_config.RunConfig(
        master=FLAGS.master_grpc_url,
        num_ps_replicas=FLAGS.num_parameter_servers,
        task=FLAGS.worker_index)

    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
        model_dir=FLAGS.model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=[5],
        config=cluster_config)

    experiment = tf.contrib.learn.Experiment(
        estimator=classifier,
        train_input_fn=data_source.input_train_fn,
        eval_input_fn=data_source.input_test_fn,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps)
    return experiment
示例7: _wrapped_experiment_fn_with_uid_check
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False):
    """Wraps `experiment_fn` with a check on the `RunConfig` uid.

    An `experiment_fn` that takes `run_config` is expected to hand that same
    `run_config` to its Estimator. To enforce this, the wrapper compares the
    `uid` of the `RunConfig` it was given with the `uid` of the config exposed
    by the returned Experiment's estimator.

    Args:
        experiment_fn: The original `experiment_fn` which takes `run_config`
            and `hparams`.
        require_hparams: If True, the `hparams` passed to `experiment_fn`
            cannot be `None`.

    Returns:
        An experiment_fn with the same signature.
    """

    def _validate_arguments(run_config, hparams):
        """Raises ValueError if the arguments are unusable."""
        if not isinstance(run_config, run_config_lib.RunConfig):
            raise ValueError('`run_config` must be `RunConfig` instance')
        if not run_config.model_dir:
            raise ValueError(
                'Must specify a model directory `model_dir` in `run_config`.')
        if not (hparams is None or isinstance(hparams, hparam_lib.HParams)):
            raise ValueError('`hparams` must be `HParams` instance')
        if hparams is None and require_hparams:
            raise ValueError('`hparams` cannot be `None`.')

    def wrapped_experiment_fn(run_config, hparams):
        """Calls experiment_fn and checks the uid of `RunConfig`."""
        _validate_arguments(run_config, hparams)

        # Capture the uid before the user's function runs.
        expected_uid = run_config.uid()
        experiment = experiment_fn(run_config, hparams)

        if not isinstance(experiment, Experiment):
            raise TypeError('Experiment builder did not return an Experiment '
                            'instance, got %s instead.' % type(experiment))

        if experiment.estimator.config.uid() == expected_uid:
            return experiment

        raise RuntimeError(
            '`RunConfig` instance is expected to be used by the `Estimator` '
            'inside the `Experiment`. expected {}, but got {}'.format(
                expected_uid, experiment.estimator.config.uid()))

    return wrapped_experiment_fn
示例8: __init__
# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def __init__(self, model_dir=None, config=None):
    """Initializes a BaseEstimator instance.

    Args:
        model_dir: Directory to save model parameters, graph and etc. This can
            also be used to load checkpoints from the directory into a
            estimator to continue training a previously saved model. If
            `None`, the model_dir in `config` will be used if set. If both
            are set, they must be same.
        config: A RunConfig instance. When None, a default
            `BaseEstimator._Config()` is created.

    Raises:
        ValueError: If `model_dir` and `config.model_dir` are both set but
            differ.
    """
    # Run configuration: use the caller's, or build the default one.
    if config is not None:
        self._config = config
    else:
        self._config = BaseEstimator._Config()
        logging.info('Using default config.')

    # Session configuration: default to soft placement when none is supplied.
    self._session_config = (
        config_pb2.ConfigProto(allow_soft_placement=True)
        if self._config.session_config is None
        else self._config.session_config)

    # Model directory: the constructor argument and the config must agree
    # whenever both are provided.
    both_given = (model_dir is not None) and (
        self._config.model_dir is not None)
    if both_given and model_dir != self._config.model_dir:
        # TODO(b/9965722): remove this suppression after it is no longer
        # necessary.
        # pylint: disable=g-doc-exception
        raise ValueError(
            "model_dir are set both in constructor and RunConfig, but with "
            "different values. In constructor: '{}', in RunConfig: "
            "'{}' ".format(model_dir, self._config.model_dir))

    self._model_dir = model_dir or self._config.model_dir
    if self._model_dir is None:
        self._model_dir = tempfile.mkdtemp()
        logging.warning('Using temporary folder as model directory: %s',
                        self._model_dir)

    # Keep the config's model_dir in step with the directory actually used.
    if self._config.model_dir is None:
        self._config = self._config.replace(model_dir=self._model_dir)
    logging.info('Using config: %s', str(vars(self._config)))

    # Set device function depending if there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = None
    self._labels_info = None

    self._graph = None