本文整理汇总了Python中tensorflow.compat.v1.flags模块的典型用法代码示例。如果您正苦于以下问题：Python v1.flags模块的具体用法？Python v1.flags怎么用？Python v1.flags使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该模块所在包tensorflow.compat.v1的用法示例。
在下文中一共展示了v1.flags模块的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import flags [as 别名]
def main(_):
  """Translates FLAGS.source with every checkpoint found in FLAGS.model_dir.

  For each checkpoint yielded by `bleu_hook.stepfiles_iterator`, builds a
  decoder command from the `FLAGS.decoder_command` template and runs it
  through the shell. The translation is written to
  `<translations_dir>/<problem>-<steps>`. Checkpoints whose output file
  already exists are skipped.

  Args:
    _: Unused positional argument.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  # pylint: disable=unused-variable
  model_dir = os.path.expanduser(FLAGS.model_dir)
  translations_dir = os.path.expanduser(FLAGS.translations_dir)
  source = os.path.expanduser(FLAGS.source)
  tf.gfile.MakeDirs(translations_dir)
  translated_base_file = os.path.join(translations_dir, FLAGS.problem)

  # Copy flags.txt with the original time (shutil.copy2 keeps metadata), so
  # t2t-bleu can report correct relative time.
  flags_path = os.path.join(translations_dir, FLAGS.problem + "-flags.txt")
  if not os.path.exists(flags_path):
    shutil.copy2(os.path.join(model_dir, "flags.txt"), flags_path)

  locals_and_flags = {"FLAGS": FLAGS}
  for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes,
                                            FLAGS.min_steps):
    tf.logging.info("Translating " + model.filename)
    out_file = translated_base_file + "-" + str(model.steps)
    # Expose every local (model, out_file, model_dir, source, ...) to the
    # str.format template below.
    locals_and_flags.update(locals())
    if os.path.exists(out_file):
      tf.logging.info(out_file + " already exists, so skipping it.")
    else:
      tf.logging.info("Translating " + out_file)
      params = (
          "--t2t_usr_dir={FLAGS.t2t_usr_dir} --output_dir={model_dir} "
          "--data_dir={FLAGS.data_dir} --problem={FLAGS.problem} "
          "--decode_hparams=beam_size={FLAGS.beam_size},alpha={FLAGS.alpha} "
          "--model={FLAGS.model} --hparams_set={FLAGS.hparams_set} "
          "--checkpoint_path={model.filename} --decode_from_file={source} "
          "--decode_to_file={out_file} --keep_timestamp"
      ).format(**locals_and_flags)
      # FLAGS.decoder_command is a user-supplied template, formatted with the
      # local variables (including `params`) and executed by the shell.
      command = FLAGS.decoder_command.format(**locals())
      tf.logging.info("Running:\n" + command)
      status = os.system(command)
      if status != 0:
        # Surface failures: out_file may be missing or partially written,
        # and an existing out_file is skipped on later runs (see above).
        tf.logging.error(
            "Decoder command exited with status %d; %s may be missing or "
            "incomplete. Delete it before re-running to retry this "
            "checkpoint.", status, out_file)
  # pylint: enable=unused-variable
示例2: save_metadata
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import flags [as 别名]
def save_metadata(hparams):
  """Saves FLAGS and hparams to output_dir.

  Writes into `FLAGS.output_dir` (with `~` expanded; created if missing):
    flags.txt: all flags. With the absl-style API this is the output of
      `FLAGS.flags_into_string()`; otherwise it is `--name=value` lines.
    flags_t2t.txt: `--name=value` lines for the flags registered by the
      `tensor2tensor.utils.flags` module. It is written only with the
      absl-style API, and only when that set is non-empty.
    hparams.json: a copy of `hparams` as sorted-key JSON, without the
      "modality" hparam.

  Args:
    hparams: Hyperparameters to save. Only the copy returned by
      `hparams_lib.copy_hparams` is modified, not this object.
  """
  output_dir = os.path.expanduser(FLAGS.output_dir)
  if not tf.gfile.Exists(output_dir):
    tf.gfile.MakeDirs(output_dir)

  # Save FLAGS in txt file.
  if hasattr(FLAGS, "flags_into_string"):
    # absl-style FlagValues.
    flags_str = FLAGS.flags_into_string()
    # The t2t flags module may be absent from this mapping (no flags
    # registered under that exact module name). Skip flags_t2t.txt in that
    # case instead of failing with KeyError.
    t2t_flags = FLAGS.flags_by_module_dict().get(
        "tensor2tensor.utils.flags", [])
    t2t_flags_str = "\n".join(
        "--%s=%s" % (f.name, f.value) for f in t2t_flags)
  else:
    # Legacy FLAGS object that keeps its flags in a private "__flags" dict.
    flags_dict = FLAGS.__dict__["__flags"]
    flags_str = "\n".join(
        "--%s=%s" % (name, str(f)) for (name, f) in flags_dict.items())
    t2t_flags_str = None

  flags_txt = os.path.join(output_dir, "flags.txt")
  with tf.gfile.Open(flags_txt, "w") as f:
    f.write(flags_str)

  if t2t_flags_str:
    t2t_flags_txt = os.path.join(output_dir, "flags_t2t.txt")
    with tf.gfile.Open(t2t_flags_txt, "w") as f:
      f.write(t2t_flags_str)

  # Save hparams as hparams.json.
  new_hparams = hparams_lib.copy_hparams(hparams)
  # Modality class is not JSON serializable so remove.
  new_hparams.del_hparam("modality")

  hparams_fname = os.path.join(output_dir, "hparams.json")
  with tf.gfile.Open(hparams_fname, "w") as f:
    f.write(new_hparams.to_json(indent=0, sort_keys=True))
示例3: main
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import flags [as 别名]
def main(_):
  """Continuously evaluates FLAGS.model on each new checkpoint.

  Builds hparams, an eval input function and an Estimator from FLAGS. It then
  iterates over the checkpoints produced by `trainer_lib.next_checkpoint`
  (which waits up to FLAGS.eval_timeout_mins) and logs the result of
  `estimator.evaluate` for each one.

  Args:
    _: Unused positional argument.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=FLAGS.data_dir,
      problem_name=FLAGS.problem)

  # Use the "test" split when --eval_use_test_set is set; otherwise pass
  # None (presumably the problem's default eval split).
  if FLAGS.eval_use_test_set:
    split = "test"
  else:
    split = None
  eval_input_fn = hparams.problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL, hparams,
      dataset_kwargs={"dataset_split": split})

  run_config = t2t_trainer.create_run_config(hparams)
  # The summary hook in tf.estimator.EstimatorSpec needs hparams.model_dir.
  hparams.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model, hparams, run_config, use_tpu=FLAGS.use_tpu)

  checkpoints = trainer_lib.next_checkpoint(
      hparams.model_dir, FLAGS.eval_timeout_mins)
  for checkpoint_path in checkpoints:
    eval_results = estimator.evaluate(
        eval_input_fn,
        steps=FLAGS.eval_steps,
        checkpoint_path=checkpoint_path)
    tf.logging.info(eval_results)
示例4: __init__
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import flags [as 别名]
def __init__(self, processor_configuration):
  """Creates the Transformer estimator.

  Args:
    processor_configuration: A mapping with a "transformer" entry. That
      entry must provide the "model_dir", "data_dir", "hparams_set",
      "hparams", "problem" and "model" keys. It is indexed with `[...]`, so
      it is presumably a dict parsed from the processor configuration rather
      than a protobuf message (TODO confirm).
  """
  # Do the pre-setup tensor2tensor requires for flags and configurations.
  transformer_config = processor_configuration["transformer"]
  # Note: this mutates the process-wide FLAGS object.
  FLAGS.output_dir = transformer_config["model_dir"]
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  data_dir = os.path.expanduser(transformer_config["data_dir"])

  # Create the basic hyper parameters.
  self.hparams = trainer_lib.create_hparams(
      transformer_config["hparams_set"],
      transformer_config["hparams"],
      data_dir=data_dir,
      problem_name=transformer_config["problem"])

  # Decode everything as a single shard. add_hparam refuses a name that
  # already holds a non-None value, and some versions of
  # decoding.decode_hparams() may already define these names. Only in that
  # case, overwrite the existing value instead.
  decode_hp = decoding.decode_hparams()
  for name, value in (("shards", 1), ("shard_id", 0)):
    if getattr(decode_hp, name, None) is None:
      decode_hp.add_hparam(name, value)
    else:
      decode_hp.set_hparam(name, value)

  # Create the estimator and final hyper parameters.
  self.estimator = trainer_lib.create_estimator(
      transformer_config["model"],
      self.hparams,
      t2t_trainer.create_run_config(self.hparams),
      decode_hparams=decode_hp,
      use_tpu=False)

  # Fetch the vocabulary and other helpful variables for decoding.
  self.source_vocab = self.hparams.problem_hparams.vocabulary["inputs"]
  self.targets_vocab = self.hparams.problem_hparams.vocabulary["targets"]
  self.const_array_size = 10000

  # Prepare the Transformer's debug data directory by deleting existing
  # run_* directories under /tmp/t2t_server_dump. Non-directory matches are
  # left alone, since shutil.rmtree raises on them.
  for run_dir in glob.glob(os.path.join("/tmp/t2t_server_dump", "run_*")):
    if os.path.isdir(run_dir):
      shutil.rmtree(run_dir)
示例5: validate_flags
# 需要导入模块: from tensorflow.compat import v1 [as 别名]
# 或者: from tensorflow.compat.v1 import flags [as 别名]
def validate_flags():
  """Validates flags are set to acceptable values.

  There are two valid configurations:
    * --cloud_mlengine_model_name is set, and neither --server nor
      --servable_name is set.
    * --cloud_mlengine_model_name is unset, and both --server and
      --servable_name are set.

  Raises:
    ValueError: if neither configuration is satisfied.
  """
  # Explicit raises rather than `assert`, which is stripped under `python -O`.
  if FLAGS.cloud_mlengine_model_name:
    if FLAGS.server:
      raise ValueError(
          "--server must not be set together with "
          "--cloud_mlengine_model_name.")
    if FLAGS.servable_name:
      raise ValueError(
          "--servable_name must not be set together with "
          "--cloud_mlengine_model_name.")
  else:
    if not FLAGS.server:
      raise ValueError(
          "--server is required unless --cloud_mlengine_model_name is set.")
    if not FLAGS.servable_name:
      raise ValueError(
          "--servable_name is required unless --cloud_mlengine_model_name "
          "is set.")