当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.HParams(tf.contrib.training.HParams 类)代码示例

本文整理汇总了 Python 中 tensorflow.HParams(在 TF 1.x 中实际为 tf.contrib.training.HParams 类)的典型用法代码示例。如果您想了解 tensorflow.HParams 的具体用法、调用方式或使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块 tensorflow 的其他用法示例。


在下文中一共展示了 tensorflow.HParams(tf.contrib.training.HParams)的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: create_hparams

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def create_hparams(hparams_overrides=None):
  """Build the hyperparameter set, applying any user-supplied overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      `hparam_name=value` pairs that replace the defaults defined here.

  Returns:
    A tf.contrib.training.HParams object.
  """
  # `load_pretrained`: whether the fine-tuning checkpoint named in the
  # pipeline config should be loaded for training.
  defaults = tf.contrib.training.HParams(load_pretrained=True)
  if not hparams_overrides:
    return defaults
  # The object returned by parse() is what callers receive.
  return defaults.parse(hparams_overrides)
开发者ID:ahmetozlu,项目名称:vehicle_counting_tensorflow,代码行数:20,代码来源:model_hparams.py

示例2: main

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Set up the session and run configs, then hand off to learn_runner.

  Args:
    job_dir: Used as the RunConfig's model_dir.
    data_dir: Forwarded to get_experiment_fn.
    num_gpus: Forwarded to get_experiment_fn.
    variable_strategy: Forwarded to get_experiment_fn.
    use_distortion_for_training: Forwarded to get_experiment_fn.
    log_device_placement: Passed through to tf.ConfigProto.
    num_intra_threads: Used as intra_op_parallelism_threads.
    **hparams: Remaining hyperparameters. Together with the run config's
      `is_chief`, they are wrapped into a tf.contrib.training.HParams.
  """
  # The original note says TF_SYNC_ON_FINISH is on a deprecation path; it is
  # explicitly switched off ('0') here.
  env_overrides = (('TF_SYNC_ON_FINISH', '0'),
                   ('TF_ENABLE_WINOGRAD_NONFUSED', '1'))
  for env_name, env_value in env_overrides:
    os.environ[env_name] = env_value

  gpu_options = tf.GPUOptions(force_gpu_compatible=True)
  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_options)
  run_config = cifar10_utils.RunConfig(
      session_config=session_config, model_dir=job_dir)

  experiment_fn = get_experiment_fn(data_dir, num_gpus, variable_strategy,
                                    use_distortion_for_training)
  run_hparams = tf.contrib.training.HParams(
      is_chief=run_config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=run_hparams)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:25,代码来源:cifar10_main.py

示例3: main

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def main(output_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Set up the session and run configs, then hand off to learn_runner.

  Args:
    output_dir: Used as the RunConfig's model_dir (the UAI SDK passes it as
      --output_dir).
    data_dir: Forwarded to get_experiment_fn (the UAI SDK passes it as
      --data_dir).
    num_gpus: Forwarded to get_experiment_fn.
    variable_strategy: Forwarded to get_experiment_fn.
    use_distortion_for_training: Forwarded to get_experiment_fn.
    log_device_placement: Passed through to tf.ConfigProto.
    num_intra_threads: Used as intra_op_parallelism_threads.
    **hparams: Remaining hyperparameters. Together with the run config's
      `is_chief`, they are wrapped into a tf.contrib.training.HParams.
  """
  # The original note says TF_SYNC_ON_FINISH is on a deprecation path; it is
  # explicitly switched off ('0') here.
  env_overrides = (('TF_SYNC_ON_FINISH', '0'),
                   ('TF_ENABLE_WINOGRAD_NONFUSED', '1'))
  for env_name, env_value in env_overrides:
    os.environ[env_name] = env_value

  gpu_options = tf.GPUOptions(force_gpu_compatible=True)
  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_options)
  run_config = cifar10_utils.RunConfig(
      session_config=session_config, model_dir=output_dir)

  experiment_fn = get_experiment_fn(data_dir, num_gpus, variable_strategy,
                                    use_distortion_for_training)
  run_hparams = tf.contrib.training.HParams(
      is_chief=run_config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=run_hparams)
开发者ID:ucloud,项目名称:uai-sdk,代码行数:27,代码来源:cifar10_main.py

示例4: merge_hparams

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def merge_hparams(hparams_1, hparams_2):
  """Combine two HParams objects into a new one.

  Keys present in both inputs take their value from `hparams_2`.

  Args:
    hparams_1: First HParams object; its values act as the base.
    hparams_2: Second HParams object; its values win on key collisions.

  Returns:
    A new tf.contrib.training.HParams holding the union of both inputs.
  """
  merged = {}
  # Later sources overwrite earlier ones, so hparams_2 takes precedence.
  for source in (hparams_1, hparams_2):
    merged.update(source.values())
  return tf.contrib.training.HParams(**merged)
开发者ID:personads,项目名称:synvae,代码行数:19,代码来源:tf_utils.py

示例5: main

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Set up the session and run configs, then hand off to learn_runner.

  Args:
    job_dir: Used as the RunConfig's model_dir.
    data_dir: Forwarded to get_experiment_fn.
    num_gpus: Forwarded to get_experiment_fn.
    variable_strategy: Forwarded to get_experiment_fn.
    use_distortion_for_training: Forwarded to get_experiment_fn.
    log_device_placement: Passed through to tf.ConfigProto.
    num_intra_threads: Used as intra_op_parallelism_threads.
    **hparams: Remaining hyperparameters, wrapped as-is into a
      tf.contrib.training.HParams.
  """
  # The original note says TF_SYNC_ON_FINISH is on a deprecation path; it is
  # explicitly switched off ('0') here.
  env_overrides = (('TF_SYNC_ON_FINISH', '0'),
                   ('TF_ENABLE_WINOGRAD_NONFUSED', '1'))
  for env_name, env_value in env_overrides:
    os.environ[env_name] = env_value

  gpu_options = tf.GPUOptions(force_gpu_compatible=True)
  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_options)
  run_config = cifar10_utils.RunConfig(
      session_config=session_config, model_dir=job_dir)

  experiment_fn = get_experiment_fn(data_dir, num_gpus, variable_strategy,
                                    use_distortion_for_training)
  run_hparams = tf.contrib.training.HParams(**hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=run_hparams)
开发者ID:sshleifer,项目名称:object_detection_kitti,代码行数:23,代码来源:cifar10_main.py

示例6: create_hparams

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def create_hparams(hparams_overrides=None):
    """Build the hyperparameter set, applying any user-supplied overrides.

    Args:
      hparams_overrides: Optional string of comma-separated
        `hparam_name=value` pairs that replace the defaults defined here.

    Returns:
      A tf.contrib.training.HParams object.
    """
    # `load_pretrained`: whether the fine-tuning checkpoint named in the
    # pipeline config should be loaded for training.
    defaults = tf.contrib.training.HParams(load_pretrained=True)
    if not hparams_overrides:
        return defaults
    # The object returned by parse() is what callers receive.
    return defaults.parse(hparams_overrides)
开发者ID:scorelab,项目名称:Elphas,代码行数:20,代码来源:model_hparams.py

示例7: __init__

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def __init__(self,
               task_config,
               model_hparams=None,
               embedder_hparams=None,
               train_hparams=None):
    """Construct a policy that can work with tasks (see tasks.py).

    The policy reads task history and goals, and produces outputs, in a way
    that is consistent with the given task config.

    Args:
      task_config: an object of type tasks.TaskIOConfig (see tasks.py).
      model_hparams: a tf.HParams object with implementation-specific
        parameters for the model.
      embedder_hparams: a tf.HParams object with implementation-specific
        parameters for the history and goal embedders.
      train_hparams: a tf.HParams object with implementation-specific
        parameters for training.
    """
    # The base class is constructed with two Nones; what those arguments mean
    # is defined by the base class, which is not visible here.
    super(TaskPolicy, self).__init__(None, None)
    self._task_config = task_config
    self._model_hparams = model_hparams
    self._embedder_hparams = embedder_hparams
    self._train_hparams = train_hparams
    # Starts empty; presumably filled in later with extra training ops.
    self._extra_train_ops = []
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:27,代码来源:policies.py

示例8: dataset

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def dataset(self, mode, hparams=None, global_step=None, **kwargs):
    """Returns a dataset containing examples from multiple problems.

    Args:
      mode: A member of problem.DatasetSplit.
      hparams: A tf.HParams object, the model hparams. It is passed to
        normalize_example for every example.
      global_step: A scalar tensor used to compute the sampling distribution.
        If global_step is None, we call tf.train.get_or_create_global_step by
        default.
      **kwargs: Keywords for problem.Problem.Dataset.

    Returns:
      A dataset containing examples from multiple problems. In TRAIN mode the
      problems are sampled according to the schedule distribution. Otherwise
      either only the first problem's dataset is returned (when
      self.only_eval_first_problem is set), or the repeated per-problem
      datasets are interleaved one element at a time.
    """
    datasets = [p.dataset(mode, **kwargs) for p in self.problems]
    # Tag each example with the index of the problem it came from. `i=j`
    # binds the index when the lambda is defined, which avoids late binding.
    datasets = [
        d.map(lambda x, i=j: self.normalize_example(  # pylint: disable=g-long-lambda
            dict(x, problem_id=tf.constant([i])), hparams))
        for j, d in enumerate(datasets)
    ]
    # Compare split constants by value. They are string mode keys, and
    # identity (`is`) can fail for an equal string that is not the same
    # object (e.g. one built from flags or config).
    if mode == problem.DatasetSplit.TRAIN:
      if global_step is None:
        global_step = tf.train.get_or_create_global_step()
      pmf = get_schedule_distribution(self.schedule, global_step)
      return get_multi_dataset(datasets, pmf)
    elif self.only_eval_first_problem:
      return datasets[0]
    else:
      # Repeat every dataset, zip them, and flatten each zipped tuple into
      # consecutive elements, so one example is taken from each problem in turn.
      datasets = [d.repeat() for d in datasets]
      return tf.data.Dataset.zip(tuple(datasets)).flat_map(
          lambda *x: functools.reduce(  # pylint: disable=g-long-lambda
              tf.data.Dataset.concatenate,
              map(tf.data.Dataset.from_tensors, x)))
开发者ID:yyht,项目名称:BERT,代码行数:35,代码来源:multi_problem_v2.py

示例9: __init__

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def __init__(self, hparams):
    """Store the hyperparameters on the instance.

    Args:
      hparams: a tf.HParams-style object, kept unchanged as `self.hparams`.
    """
    self.hparams = hparams
开发者ID:brain-research,项目名称:mpnn,代码行数:9,代码来源:mpnn.py

示例10: default_hparams

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def default_hparams():
    """Return the default hyperparameters as a tf.contrib.training.HParams.

    The defaults are collected in a plain dict first and then expanded into
    HParams as keyword arguments.
    """
    defaults = dict(
        set2set_comps=12,
        non_edge=0,
        node_dim=50,
        num_propagation_steps=6,
        num_output_hidden_layers=1,
        max_grad_norm=4.0,
        batch_size=20,
        optimizer="adam",
        # Per the original note, only used when `optimizer` is momentum.
        momentum=.9,
        init_learning_rate=.00013,
        # NOTE(review): the original comment claimed the final learning rate
        # is initial*.1; that cannot be confirmed from here -- check the
        # training loop.
        decay_factor=.5,
        # How often (in batches) to decay the learning rate.
        decay_every=500000,
        # Reuse the same message and update weights at every time step.
        reuse=True,
        message_function="matrix_multiply",
        update_function="GRU",
        output_function="graph_level",
        hidden_dim=200,
        # Per the original note, dropout did not help in the authors' experiments.
        keep_prob=1.0,
        edge_num_layers=4,
        edge_hidden_dim=50,
        propagation_type="normal",
        activation="relu",
        normalizer="none",
        # Inner-product similarity to use for set2vec.
        inner_prod="default",
    )
    return tf.contrib.training.HParams(**defaults)
开发者ID:brain-research,项目名称:mpnn,代码行数:29,代码来源:mpnn.py

示例11: build_hparams

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def build_hparams(cell_name='amoeba_net_d'):
  """Build tf.contrib.training.HParams for training Amoeba Net.

  Starts from imagenet_hparams(), registers the operations, hidden-state
  indices and used hidden states of both the normal and the reduction cell
  called `cell_name`, and then sets the data format to NHWC.

  Args:
    cell_name: Which of the cells in model_specs.py to use to build the
      amoebanet neural network; the cell names defined in that module
      correspond to architectures discovered by an evolutionary search
      described in https://arxiv.org/abs/1802.01548.

  Returns:
    A set of tf.contrib.training.HParams suitable for Amoeba Net training.
  """
  hparams = imagenet_hparams()

  def add_cell(cell_spec, operations_key, indices_key, used_key):
    # cell_spec must unpack into exactly three items.
    operations, hiddenstate_indices, used_hiddenstates = cell_spec
    hparams.add_hparam(operations_key, operations)
    hparams.add_hparam(indices_key, hiddenstate_indices)
    hparams.add_hparam(used_key, used_hiddenstates)

  add_cell(model_specs.get_normal_cell(cell_name),
           'normal_cell_operations',
           'normal_cell_hiddenstate_indices',
           'normal_cell_used_hiddenstates')
  add_cell(model_specs.get_reduction_cell(cell_name),
           'reduction_cell_operations',
           'reduction_cell_hiddenstate_indices',
           'reduction_cell_used_hiddenstates')
  hparams.set_hparam('data_format', 'NHWC')
  return hparams
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:33,代码来源:amoeba_net_model.py

示例12: formatted_hparams

# 需要导入模块: import tensorflow as tf  (TF 1.x)
# 注: HParams 位于 tf.contrib.training 中,tensorflow 顶层并没有 HParams
def formatted_hparams(hparams):
  """Formats the hparams into a readable string.

  Also looks for public attributes that were set directly on the object
  instead of being registered as hparams, and lists their names as "bad
  keys". Such keys are absent from hparams.values(), so they may be left out
  of iteration and circumvent type checking.

  Args:
    hparams: an HParams instance (anything with a `values()` method that
      returns a dict, plus an instance `__dict__`).

  Returns:
    A string with one "name: value" line per hparam, sorted by name, followed
    by a final "Bad keys: ..." line listing the bad keys comma-separated and
    sorted.
  """
  # Call values() once; it feeds both the good-key set and the output lines.
  values = hparams.values()
  good_keys = set(values)

  # Look for bad keys (see docstring). Private attributes ('_' prefix) are
  # internal bookkeeping and are ignored.
  bad_keys = sorted(
      key for key in hparams.__dict__
      if key not in good_keys and not key.startswith('_'))

  # dict.items() works on both Python 2 and 3; iteritems() is Python 2 only
  # and raises AttributeError on Python 3.
  readable_items = ['%s: %s' % (k, v) for k, v in sorted(values.items())]
  readable_items.append('Bad keys: %s' % ','.join(bad_keys))
  return '\n'.join(readable_items)
开发者ID:mlperf,项目名称:training_results_v0.5,代码行数:29,代码来源:amoeba_net_model.py


注:本文中的tensorflow.HParams方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。