当前位置: 首页>>代码示例>>Python>>正文


Python run_config.RunConfig方法代码示例

本文整理汇总了Python中tensorflow.contrib.learn.python.learn.estimators.run_config.RunConfig方法的典型用法代码示例。如果您正苦于以下问题:Python run_config.RunConfig方法的具体用法?Python run_config.RunConfig怎么用?Python run_config.RunConfig使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow.contrib.learn.python.learn.estimators.run_config的用法示例。


在下文中一共展示了run_config.RunConfig方法的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _get_default_schedule

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _get_default_schedule(config):
  """Picks the schedule a task should run when none was given explicitly.

  Args:
    config: A `RunConfig` instance, or a falsy value such as `None`.

  Returns:
    'train_and_evaluate' when there is no config, when the config is not
    distributed, or when the task is the master; 'run_std_server' for a
    parameter-server task; 'train' for a worker task.

  Raises:
    ValueError: If the config is distributed but has no task type, or if the
      task type has no default schedule.
  """
  # Missing or non-distributed configs always train and evaluate in-process.
  if not (config and _is_distributed(config)):
    return 'train_and_evaluate'

  task_type = config.task_type
  if not task_type:
    raise ValueError('Must specify a schedule')

  # TODO(rhaertel): handle the case where there is more than one master
  # or explicitly disallow such a case.
  schedule_by_type = (
      (run_config_lib.TaskType.MASTER, 'train_and_evaluate'),
      (run_config_lib.TaskType.PS, 'run_std_server'),
      (run_config_lib.TaskType.WORKER, 'train'),
  )
  for known_type, schedule in schedule_by_type:
    if task_type == known_type:
      return schedule

  raise ValueError('No default schedule for task type: %s' % task_type)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:20,代码来源:learn_runner.py

示例2: _get_replica_device_setter

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _get_replica_device_setter(config):
  """Builds a device function that places state ops on parameter servers.

  Args:
    config: A RunConfig instance; its `task_type`, `task_id`,
      `num_ps_replicas` and `cluster_spec` attributes are read.

  Returns:
    The device function returned by `device_setter.replica_device_setter`
    when `config.num_ps_replicas` is positive, otherwise None.
  """
  # Op types that are routed to the parameter-server tasks.
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableOfTensors', 'MutableDenseHashTable'
  ]

  worker_device = (
      '/job:%s/task:%d' % (config.task_type, config.task_id)
      if config.task_type else '/job:worker')

  # No parameter servers configured: nothing to place remotely.
  if not config.num_ps_replicas > 0:
    return None

  return device_setter.replica_device_setter(
      ps_tasks=config.num_ps_replicas,
      worker_device=worker_device,
      merge_devices=True,
      ps_ops=ps_ops,
      cluster=config.cluster_spec)
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:27,代码来源:estimator.py

示例3: _get_replica_device_setter

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _get_replica_device_setter(config):
  """Creates a replica device setter if required.

  Args:
    config: A RunConfig instance; its `job_name`, `task`, `num_ps_replicas`
      and `cluster_spec` attributes are read.

  Returns:
    The device function returned by `device_setter.replica_device_setter`,
    which places the op types in `ps_ops` on parameter-server tasks, or None
    when `config.num_ps_replicas` is not positive.
  """
  # Op types that must be placed on the parameter servers. 'VariableV2' is
  # the variable op type used by newer TensorFlow releases; without it such
  # variables would not be assigned to the parameter servers, and each worker
  # would silently keep its own copy. Listing it is harmless on releases that
  # do not have that op.
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableOfTensors', 'MutableDenseHashTable'
  ]

  if config.job_name:
    worker_device = '/job:%s/task:%d' % (config.job_name, config.task)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return device_setter.replica_device_setter(
        ps_tasks=config.num_ps_replicas, worker_device=worker_device,
        merge_devices=False, ps_ops=ps_ops, cluster=config.cluster_spec)
  return None
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:27,代码来源:estimator.py

示例4: config

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def config(self):
    """Returns a deep copy of the stored run configuration.

    Because the result is a `copy.deepcopy` of `self._config`, mutating it
    does not change the object held by this instance.
    """
    # TODO(wicke): make RunConfig immutable, and then return it without a copy.
    snapshot = copy.deepcopy(self._config)
    return snapshot
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:5,代码来源:estimator.py

示例5: __init__

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def __init__(self, model_dir=None, config=None):
    """Sets up the model directory, run configuration and device function.

    Args:
      model_dir: Directory where model parameters, the graph, etc. are saved.
        Pointing it at a directory that already holds checkpoints lets an
        estimator continue training that saved model. When None, a fresh
        temporary directory is created.
      config: A RunConfig instance. When None, `BaseEstimator._Config()` is
        used instead.
    """
    # Resolve the model directory, falling back to a temporary one.
    if model_dir is not None:
      self._model_dir = model_dir
    else:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)

    # Resolve the run configuration, falling back to the default one.
    if config is not None:
      self._config = config
    else:
      self._config = BaseEstimator._Config()
      logging.info('Using default config.')
    logging.info('Using config: %s', str(vars(self._config)))

    # Device function; depends on whether parameter-server replicas exist.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects; both start as None.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = self._labels_info = None

    self._graph = None
开发者ID:abhisuri97,项目名称:auto-alt-text-lambda-api,代码行数:35,代码来源:estimator.py

示例6: _create_experiment_fn

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _create_experiment_fn(output_dir):  # pylint: disable=unused-argument
  """Builds the census wide-and-deep `Experiment`.

  Args:
    output_dir: Unused; the model directory is taken from `FLAGS.model_dir`.

  Returns:
    A `tf.contrib.learn.Experiment` that trains and evaluates a
    `DNNLinearCombinedClassifier` on the census data source.
  """
  (all_columns, label_col, wide_cols, deep_cols, categorical_cols,
   continuous_cols) = census_model_config()

  data_source = CensusDataSource(
      FLAGS.data_dir, TRAIN_DATA_URL, TEST_DATA_URL, all_columns, label_col,
      categorical_cols, continuous_cols)

  # Distributed settings (master address, PS count, task index) come from flags.
  run_cfg = run_config.RunConfig(
      master=FLAGS.master_grpc_url,
      num_ps_replicas=FLAGS.num_parameter_servers,
      task=FLAGS.worker_index)

  classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
      model_dir=FLAGS.model_dir,
      linear_feature_columns=wide_cols,
      dnn_feature_columns=deep_cols,
      dnn_hidden_units=[5],
      config=run_cfg)

  experiment = tf.contrib.learn.Experiment(
      estimator=classifier,
      train_input_fn=data_source.input_train_fn,
      eval_input_fn=data_source.input_test_fn,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps)
  return experiment
开发者ID:tobegit3hub,项目名称:deep_image_model,代码行数:31,代码来源:census_widendeep.py

示例7: _wrapped_experiment_fn_with_uid_check

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False):
  """Wraps `experiment_fn` so that the `RunConfig` it receives is verified.

  The returned callable validates its `run_config` and `hparams` arguments,
  calls `experiment_fn`, and then checks that the `Estimator` inside the
  resulting `Experiment` has a config whose `uid()` matches the `uid()` of the
  `run_config` that was passed in.

  Args:
    experiment_fn: The original `experiment_fn`, called as
      `experiment_fn(run_config, hparams)`.
    require_hparams: If True, the wrapper rejects `hparams=None`.

  Returns:
    An experiment_fn with the same `(run_config, hparams)` signature.
  """
  def wrapped_experiment_fn(run_config, hparams):
    """Validates the inputs, builds the Experiment and checks the config uid."""
    if not isinstance(run_config, run_config_lib.RunConfig):
      raise ValueError('`run_config` must be `RunConfig` instance')
    if not run_config.model_dir:
      raise ValueError(
          'Must specify a model directory `model_dir` in `run_config`.')
    if not (hparams is None or isinstance(hparams, hparam_lib.HParams)):
      raise ValueError('`hparams` must be `HParams` instance')
    if require_hparams and hparams is None:
      raise ValueError('`hparams` cannot be `None`.')

    # Capture the uid before the builder runs so it can be compared afterwards.
    uid_before = run_config.uid()
    built = experiment_fn(run_config, hparams)

    if not isinstance(built, Experiment):
      raise TypeError('Experiment builder did not return an Experiment '
                      'instance, got %s instead.' % type(built))

    uid_after = built.estimator.config.uid()
    if uid_after != uid_before:
      raise RuntimeError(
          '`RunConfig` instance is expected to be used by the `Estimator` '
          'inside the `Experiment`. expected {}, but got {}'.format(
              uid_before, uid_after))
    return built

  return wrapped_experiment_fn
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:44,代码来源:learn_runner.py

示例8: __init__

# 需要导入模块: from tensorflow.contrib.learn.python.learn.estimators import run_config [as 别名]
# 或者: from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig [as 别名]
def __init__(self, model_dir=None, config=None):
    """Sets up the run config, session config, model directory and device fn.

    Args:
      model_dir: Directory where model parameters, the graph, etc. are saved.
        Pointing it at a directory that already holds checkpoints lets an
        estimator continue training that saved model. If `None`,
        `config.model_dir` is used when set; if both are set they must be
        equal. When neither is set, a temporary directory is created.
      config: A RunConfig instance. When None, `BaseEstimator._Config()` is
        used instead.

    Raises:
      ValueError: If `model_dir` and `config.model_dir` are both set but
        differ.
    """
    # Resolve the run configuration, falling back to the default one.
    if config is not None:
      self._config = config
    else:
      self._config = BaseEstimator._Config()
      logging.info('Using default config.')

    # Session options: the config's own, or soft placement by default.
    session_cfg = self._config.session_config
    if session_cfg is None:
      session_cfg = config_pb2.ConfigProto(allow_soft_placement=True)
    self._session_config = session_cfg

    # The constructor argument and the RunConfig may both name a model
    # directory; that is only allowed when they agree.
    configured_dir = self._config.model_dir
    if (model_dir is not None and configured_dir is not None and
        model_dir != configured_dir):
      # TODO(b/9965722): remove this suppression after it is no longer
      #                  necessary.
      # pylint: disable=g-doc-exception
      raise ValueError(
          "model_dir are set both in constructor and RunConfig, but with "
          "different values. In constructor: '{}', in RunConfig: "
          "'{}' ".format(model_dir, configured_dir))

    self._model_dir = model_dir or configured_dir
    if self._model_dir is None:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)
    if configured_dir is None:
      # Write the resolved directory back so the config reflects it.
      self._config = self._config.replace(model_dir=self._model_dir)
    logging.info('Using config: %s', str(vars(self._config)))

    # Device function; depends on whether parameter-server replicas exist.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects; both start as None.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = self._labels_info = None

    self._graph = None
开发者ID:ryfeus,项目名称:lambda-packs,代码行数:53,代码来源:estimator.py


注:本文中的tensorflow.contrib.learn.python.learn.estimators.run_config.RunConfig方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。