本文整理汇总了Python中tensorflow.HParams方法的典型用法代码示例。如果您正苦于以下问题：Python tensorflow.HParams方法的具体用法？Python tensorflow.HParams怎么用？Python tensorflow.HParams使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensorflow的用法示例。注意：在TensorFlow 1.x中，HParams类实际位于tf.contrib.training.HParams，下文中直接构造HParams对象的示例均使用该路径。
在下文中一共展示了tensorflow.HParams方法的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: create_hparams
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def create_hparams(hparams_overrides=None):
  """Builds the hyperparameters and applies any flag value overrides.

  Args:
    hparams_overrides: Optional string of comma-separated
      `hparam_name=value` pairs that override the defaults defined here.

  Returns:
    A tf.contrib.training.HParams object.
  """
  defaults = tf.contrib.training.HParams(
      # Whether the fine-tuning checkpoint named in the pipeline config
      # should be loaded for training.
      load_pretrained=True)
  # A None or empty override string leaves the defaults as they are.
  if not hparams_overrides:
    return defaults
  return defaults.parse(hparams_overrides)
示例2: main
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Configures the TF session and runs the experiment via learn_runner.

  Args:
    job_dir: Passed as `model_dir` to `cifar10_utils.RunConfig`.
    data_dir: Forwarded to `get_experiment_fn`.
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Sets `log_device_placement` on the ConfigProto.
    num_intra_threads: Sets `intra_op_parallelism_threads` on the ConfigProto.
    **hparams: Extra hyperparameters, combined with `is_chief` from the run
      config into a tf.contrib.training.HParams.
  """
  env_overrides = {
      # The env variable is on deprecation path, default is set to off.
      'TF_SYNC_ON_FINISH': '0',
      'TF_ENABLE_WINOGRAD_NONFUSED': '1',
  }
  for env_name, env_value in env_overrides.items():
    os.environ[env_name] = env_value

  gpu_options = tf.GPUOptions(force_gpu_compatible=True)
  sess_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_options)
  run_config = cifar10_utils.RunConfig(
      session_config=sess_config, model_dir=job_dir)

  experiment_fn = get_experiment_fn(data_dir, num_gpus, variable_strategy,
                                    use_distortion_for_training)
  run_hparams = tf.contrib.training.HParams(
      is_chief=run_config.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_config, hparams=run_hparams)
示例3: main
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def main(output_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Configures the TF session and runs the experiment via learn_runner.

  Args:
    output_dir: Passed as `model_dir` to `cifar10_utils.RunConfig`.
    data_dir: Forwarded to `get_experiment_fn`.
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Sets `log_device_placement` on the ConfigProto.
    num_intra_threads: Sets `intra_op_parallelism_threads` on the ConfigProto.
    **hparams: Extra hyperparameters, combined with `is_chief` from the run
      config into a tf.contrib.training.HParams.
  """
  # The env variable is on deprecation path, default is set to off.
  os.environ.update({
      'TF_SYNC_ON_FINISH': '0',
      'TF_ENABLE_WINOGRAD_NONFUSED': '1',
  })

  session_cfg = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=tf.GPUOptions(force_gpu_compatible=True))
  # UAI SDK use --output_dir as model_dir
  # UAI SDK use --data_dir as data_dir
  run_cfg = cifar10_utils.RunConfig(
      session_config=session_cfg, model_dir=output_dir)

  experiment_fn = get_experiment_fn(data_dir, num_gpus, variable_strategy,
                                    use_distortion_for_training)
  experiment_hparams = tf.contrib.training.HParams(
      is_chief=run_cfg.is_chief, **hparams)
  tf.contrib.learn.learn_runner.run(
      experiment_fn, run_config=run_cfg, hparams=experiment_hparams)
示例4: merge_hparams
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def merge_hparams(hparams_1, hparams_2):
  """Combines two tf.HParams objects into a new one.

  When a key exists in both inputs, the value from `hparams_2` is kept.

  Args:
    hparams_1: The base tf.HParams object.
    hparams_2: The tf.HParams object whose values win on key conflicts.

  Returns:
    A new tf.contrib.training.HParams containing the hyperparameters of both
    `hparams_1` and `hparams_2`.
  """
  combined = dict(hparams_1.values())
  # Later update wins, so hparams_2 overrides hparams_1 on shared keys.
  combined.update(hparams_2.values())
  return tf.contrib.training.HParams(**combined)
示例5: main
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def main(job_dir, data_dir, num_gpus, variable_strategy,
         use_distortion_for_training, log_device_placement, num_intra_threads,
         **hparams):
  """Configures the TF session and runs the experiment via learn_runner.

  Args:
    job_dir: Passed as `model_dir` to `cifar10_utils.RunConfig`.
    data_dir: Forwarded to `get_experiment_fn`.
    num_gpus: Forwarded to `get_experiment_fn`.
    variable_strategy: Forwarded to `get_experiment_fn`.
    use_distortion_for_training: Forwarded to `get_experiment_fn`.
    log_device_placement: Sets `log_device_placement` on the ConfigProto.
    num_intra_threads: Sets `intra_op_parallelism_threads` on the ConfigProto.
    **hparams: Hyperparameters wrapped unchanged in a
      tf.contrib.training.HParams.
  """
  # The env variable is on deprecation path, default is set to off.
  os.environ.update({
      'TF_SYNC_ON_FINISH': '0',
      'TF_ENABLE_WINOGRAD_NONFUSED': '1',
  })

  gpu_opts = tf.GPUOptions(force_gpu_compatible=True)
  session_cfg = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=gpu_opts)
  run_cfg = cifar10_utils.RunConfig(
      session_config=session_cfg, model_dir=job_dir)

  experiment_fn = get_experiment_fn(data_dir, num_gpus, variable_strategy,
                                    use_distortion_for_training)
  tf.contrib.learn.learn_runner.run(
      experiment_fn,
      run_config=run_cfg,
      hparams=tf.contrib.training.HParams(**hparams))
示例6: create_hparams
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def create_hparams(hparams_overrides=None):
  """Creates the default hyperparameters, then applies flag overrides.

  Args:
    hparams_overrides: Optional string such as "name1=value1,name2=value2";
      when given, it is parsed on top of the defaults.

  Returns:
    A tf.contrib.training.HParams object.
  """
  base = tf.contrib.training.HParams(
      # Whether a fine tuning checkpoint (provided in the pipeline config)
      # should be loaded for training.
      load_pretrained=True)
  return base.parse(hparams_overrides) if hparams_overrides else base
示例7: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def __init__(self,
             task_config,
             model_hparams=None,
             embedder_hparams=None,
             train_hparams=None):
  """Constructs a policy that knows how to work with tasks (see tasks.py).

  The policy reads task history, goal and outputs consistently with the
  given task config.

  Args:
    task_config: an object of type tasks.TaskIOConfig (see tasks.py).
    model_hparams: a tf.HParams object with implementation-specific
      parameters for the model.
    embedder_hparams: a tf.HParams object with implementation-specific
      parameters for the history and goal embedders.
    train_hparams: a tf.HParams object with implementation-specific
      parameters for training.
  """
  super(TaskPolicy, self).__init__(None, None)
  # Configuration is stored as given; nothing is validated here.
  self._task_config = task_config
  self._model_hparams = model_hparams
  self._embedder_hparams = embedder_hparams
  self._train_hparams = train_hparams
  # Starts empty; presumably extended with extra training ops elsewhere.
  self._extra_train_ops = []
示例8: dataset
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def dataset(self, mode, hparams=None, global_step=None, **kwargs):
  """Returns a dataset containing examples from multiple problems.

  Args:
    mode: A member of problem.DatasetSplit.
    hparams: A tf.HParams object, the model hparams. Passed to
      `self.normalize_example` for every example.
    global_step: A scalar tensor used to compute the sampling distribution
      in TRAIN mode. If global_step is None, we call
      tf.train.get_or_create_global_step by default.
    **kwargs: Keywords for problem.Problem.dataset.

  Returns:
    A dataset containing examples from multiple problems:
      * TRAIN mode: the per-problem datasets combined by `get_multi_dataset`
        using a distribution computed from `self.schedule` and `global_step`.
      * otherwise, if `self.only_eval_first_problem`: only the first
        problem's dataset.
      * otherwise: the repeated per-problem datasets interleaved one
        example at a time.
  """
  datasets = [p.dataset(mode, **kwargs) for p in self.problems]
  # Tag examples with a problem_id. The `i=j` default argument binds the
  # current index when each lambda is created (avoids late binding).
  datasets = [
      d.map(lambda x, i=j: self.normalize_example(  # pylint: disable=g-long-lambda
          dict(x, problem_id=tf.constant([i])), hparams))
      for j, d in enumerate(datasets)
  ]
  # `==` instead of `is`: identity only matches the very same object, while
  # equality also matches an equal mode value built elsewhere (DatasetSplit
  # members are presumably string constants -- confirm).
  if mode == problem.DatasetSplit.TRAIN:
    if global_step is None:
      global_step = tf.train.get_or_create_global_step()
    pmf = get_schedule_distribution(self.schedule, global_step)
    return get_multi_dataset(datasets, pmf)
  if self.only_eval_first_problem:
    return datasets[0]
  # Repeat every dataset, zip them so each element holds one example per
  # problem, then flatten each tuple into consecutive single examples.
  datasets = [d.repeat() for d in datasets]
  return tf.data.Dataset.zip(tuple(datasets)).flat_map(
      lambda *x: functools.reduce(  # pylint: disable=g-long-lambda
          tf.data.Dataset.concatenate,
          map(tf.data.Dataset.from_tensors, x)))
示例9: __init__
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def __init__(self, hparams):
  """Initializes the object with its hyperparameters.

  Args:
    hparams: tf.HParams object, stored unchanged as `self.hparams`.
  """
  self.hparams = hparams
示例10: default_hparams
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def default_hparams():
  """Returns the default hyperparameters as a tf.contrib.training.HParams."""
  defaults = dict(
      set2set_comps=12,
      non_edge=0,
      node_dim=50,
      num_propagation_steps=6,
      num_output_hidden_layers=1,
      max_grad_norm=4.0,
      batch_size=20,
      optimizer="adam",
      # Only used if optimizer is set to momentum.
      momentum=.9,
      init_learning_rate=.00013,
      # NOTE(review): the original comment says the final learning rate will
      # be initial*.1; with a factor of .5 that depends on how many decays
      # occur -- confirm.
      decay_factor=.5,
      # How often to decay the learning rate (in batches).
      decay_every=500000,
      # Use the same message and update weights at each time step.
      reuse=True,
      message_function="matrix_multiply",
      update_function="GRU",
      output_function="graph_level",
      hidden_dim=200,
      # In the authors' experiments dropout did not help.
      keep_prob=1.0,
      edge_num_layers=4,
      edge_hidden_dim=50,
      propagation_type="normal",
      activation="relu",
      normalizer="none",
      # Inner product similarity to use for set2vec.
      inner_prod="default",
  )
  return tf.contrib.training.HParams(**defaults)
示例11: build_hparams
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def build_hparams(cell_name='amoeba_net_d'):
  """Builds tf.HParams for training Amoeba Net.

  Starts from `imagenet_hparams()`, adds the operations, hidden-state
  indices and used hidden states of both the normal and the reduction cell,
  and forces the data format to NHWC.

  Args:
    cell_name: Which of the cells in model_specs.py to use to build the
      amoebanet neural network; the cell names defined in that module
      correspond to architectures discovered by an evolutionary search
      described in https://arxiv.org/abs/1802.01548.

  Returns:
    A set of tf.HParams suitable for Amoeba Net training.
  """
  hparams = imagenet_hparams()
  # (cell getter, hparam names for its three returned values), in the order
  # the hparams are added.
  cell_specs = (
      (model_specs.get_normal_cell,
       ('normal_cell_operations',
        'normal_cell_hiddenstate_indices',
        'normal_cell_used_hiddenstates')),
      (model_specs.get_reduction_cell,
       ('reduction_cell_operations',
        'reduction_cell_hiddenstate_indices',
        'reduction_cell_used_hiddenstates')),
  )
  for get_cell, hparam_names in cell_specs:
    operations, hiddenstate_indices, used_hiddenstates = get_cell(cell_name)
    ops_name, indices_name, used_name = hparam_names
    hparams.add_hparam(ops_name, operations)
    hparams.add_hparam(indices_name, hiddenstate_indices)
    hparams.add_hparam(used_name, used_hiddenstates)
  hparams.set_hparam('data_format', 'NHWC')
  return hparams
示例12: formatted_hparams
# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import HParams [as 别名]
def formatted_hparams(hparams):
  """Formats the hparams into a readable string.

  Also looks for public attributes that were set directly on the object
  instead of being registered as hyperparameters, and reports them as
  "bad keys". Such keys are missing from `hparams.values()`, so they may be
  left out of iteration over the hparams and circumvent type checking.

  Args:
    hparams: an HParams instance (anything with a `values()` method that
      returns a name -> value dict, plus an instance `__dict__`).

  Returns:
    A newline-separated string: one "name: value" line per hyperparameter,
    sorted by name, followed by a final "Bad keys: ..." line listing the
    comma-separated bad keys (empty after the colon when there are none).
  """
  hparam_values = hparams.values()
  # Public attributes that are not registered hparams (see docstring).
  bad_keys = sorted(
      key for key in hparams.__dict__
      if key not in hparam_values and not key.startswith('_'))
  # `.items()` works on Python 2 and 3; `.iteritems()` is Python 2 only.
  readable_items = [
      '%s: %s' % (k, v) for k, v in sorted(hparam_values.items())]
  readable_items.append('Bad keys: %s' % ','.join(bad_keys))
  return '\n'.join(readable_items)