当前位置: 首页>>代码示例>>Python>>正文


Python tensorflow.flags方法代码示例

本文整理汇总了Python中tensorflow.flags方法的典型用法代码示例。如果您正苦于以下问题:Python tensorflow.flags方法的具体用法?Python tensorflow.flags怎么用?Python tensorflow.flags使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在tensorflow的用法示例。


在下文中一共展示了tensorflow.flags方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Reads the flag values from ``flags.FLAGS`` and hands them to ``train``,
    ``prepro`` or ``test``. In ``debug`` mode four flag values are
    overridden before ``train`` is called.

    Args:
        _: Ignored. (Presumably the leftover argv passed in by the flags
            runner -- confirm against the caller.)

    Raises:
        SystemExit: With status 1 when ``mode`` is not one of the handled
            values.
    """
    # Local import: the file-level import block is not visible from here.
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Force these settings to minimal values before training
        # (presumably to keep a debug run short -- confirm in train()).
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        # The bare ``exit`` helper is injected by ``site`` for interactive
        # use and may be absent; an unrecognised mode is also an error, so
        # report a non-zero status instead of 0.
        sys.exit(1)
开发者ID:HKUST-KnowComp,项目名称:R-Net,代码行数:19,代码来源:config.py

示例2: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Reads the flag values from ``flags.FLAGS`` and hands them to ``train``,
    ``prepro`` or ``test``. In ``debug`` mode four flag values are
    overridden before ``train`` is called.

    Args:
        _: Ignored. (Presumably the leftover argv passed in by the flags
            runner -- confirm against the caller.)

    Raises:
        SystemExit: With status 1 when ``mode`` is not one of the handled
            values.
    """
    # Local import: the file-level import block is not visible from here.
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Force these settings to minimal values before training
        # (presumably to keep a debug run short -- confirm in train()).
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode, you must choose mode from [train/prepro/debug/test]")
        # The bare ``exit`` helper is injected by ``site`` for interactive
        # use and may be absent; an unrecognised mode is also an error, so
        # report a non-zero status instead of 0.
        sys.exit(1)
开发者ID:l11x0m7,项目名称:Question_Answering_Models,代码行数:19,代码来源:config.py

示例3: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Reads the flag values from ``flags.FLAGS`` and hands them to ``train``,
    ``prepro``, ``test`` or ``demo``. In ``debug`` mode four flag values are
    overridden before ``train`` is called.

    Args:
        _: Ignored. (Presumably the leftover argv passed in by the flags
            runner -- confirm against the caller.)

    Raises:
        SystemExit: With status 1 when ``mode`` is not one of the handled
            values.
    """
    # Local import: the file-level import block is not visible from here.
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Force these settings to minimal values before training
        # (presumably to keep a debug run short -- confirm in train()).
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "demo":
        demo(config)
    else:
        print("Unknown mode")
        # The bare ``exit`` helper is injected by ``site`` for interactive
        # use and may be absent; an unrecognised mode is also an error, so
        # report a non-zero status instead of 0.
        sys.exit(1)
开发者ID:NLPLearn,项目名称:QANet,代码行数:21,代码来源:config.py

示例4: get_supervisor

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def get_supervisor(model):
    """Build the ``tf.train.Supervisor`` used for this model.

    The supervisor logs to ``FLAGS.model_dir``, is configured as chief, and
    uses a fresh ``tf.train.Saver`` plus a ``tf.summary.FileWriter`` pointed
    at the same directory.

    Args:
        model: Object exposing a ``global_step`` attribute, which is passed
            to the supervisor as its global step.

    Returns:
        The constructed ``tf.train.Supervisor``.
    """
    checkpoint_saver = tf.train.Saver()
    writer = tf.summary.FileWriter(FLAGS.model_dir)

    supervisor_kwargs = dict(
        logdir=FLAGS.model_dir,
        is_chief=True,
        saver=checkpoint_saver,
        init_op=set_initial_ops(),
        summary_op=tf.summary.merge_all(),
        summary_writer=writer,
        save_summaries_secs=100,  # TODO: add as flags
        save_model_secs=1000,
        global_step=model.global_step,
    )
    return tf.train.Supervisor(**supervisor_kwargs)
开发者ID:tokestermw,项目名称:text-gan-tensorflow,代码行数:19,代码来源:train.py

示例5: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Copies the ``gpu`` flag into the ``CUDA_VISIBLE_DEVICES`` environment
    variable, then dispatches on ``mode`` to ``train``,
    ``data_process.prepro``, ``test``, ``examine_dev``, ``save_dev`` or
    ``save_test``. In ``debug`` mode four flag values are overridden before
    ``train`` is called.

    Args:
        _: Ignored. (Presumably the leftover argv passed in by the flags
            runner -- confirm against the caller.)

    Raises:
        SystemExit: With status 1 when ``mode`` is not one of the handled
            values.
    """
    # Local import: the file-level import block is not visible from here.
    import sys

    config = flags.FLAGS
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu  # select one GPU
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        data_process.prepro(config)
    elif config.mode == "debug":
        # Force these settings to minimal values before training
        # (presumably to keep a debug run short -- confirm in train()).
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "examine":
        examine_dev(config)
    elif config.mode == "save_dev":
        save_dev(config)
    elif config.mode == "save_test":
        save_test(config)
    else:
        print("Unknown mode")
        # The bare ``exit`` helper is injected by ``site`` for interactive
        # use and may be absent; an unrecognised mode is also an error, so
        # report a non-zero status instead of 0.
        sys.exit(1)
开发者ID:yuhaitao1994,项目名称:AIchallenger2018_MachineReadingComprehension,代码行数:26,代码来源:config.py

示例6: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Copies the ``gpu`` flag into the ``CUDA_VISIBLE_DEVICES`` environment
    variable, then dispatches on ``mode`` to ``train``,
    ``data_process_addAnswer.prepro``, ``test``, ``examine_dev``,
    ``save_dev`` or ``save_test``.

    Args:
        _: Ignored. (Presumably the leftover argv passed in by the flags
            runner -- confirm against the caller.)

    Raises:
        SystemExit: With status 1 when ``mode`` is not one of the handled
            values.
    """
    # Local import: the file-level import block is not visible from here.
    import sys

    config = flags.FLAGS
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu  # select one GPU
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        data_process_addAnswer.prepro(config)
    elif config.mode == "test":
        test(config)
    elif config.mode == "examine":
        examine_dev(config)
    elif config.mode == "save_dev":
        save_dev(config)
    elif config.mode == "save_test":
        save_test(config)
    else:
        print("Unknown mode")
        # The bare ``exit`` helper is injected by ``site`` for interactive
        # use and may be absent; an unrecognised mode is also an error, so
        # report a non-zero status instead of 0.
        sys.exit(1)
开发者ID:yuhaitao1994,项目名称:AIchallenger2018_MachineReadingComprehension,代码行数:20,代码来源:config.py

示例7: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Reads the flag values from ``flags.FLAGS`` and hands them to ``train``,
    ``prepro`` or ``test``. In ``debug`` mode four flag values are
    overridden before ``train`` is called. In ``test`` mode a warning is
    printed first when the ``use_cudnn`` flag is set.

    Args:
        _: Ignored. (Presumably the leftover argv passed in by the flags
            runner -- confirm against the caller.)

    Raises:
        SystemExit: With status 1 when ``mode`` is not one of the handled
            values.
    """
    # Local import: the file-level import block is not visible from here.
    import sys

    config = flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "prepro":
        prepro(config)
    elif config.mode == "debug":
        # Force these settings to minimal values before training
        # (presumably to keep a debug run short -- confirm in train()).
        config.num_steps = 2
        config.val_num_batches = 1
        config.checkpoint = 1
        config.period = 1
        train(config)
    elif config.mode == "test":
        if config.use_cudnn:
            print("Warning: Due to a known bug in Tensorlfow, the parameters of CudnnGRU may not be properly restored.")
        test(config)
    else:
        print("Unknown mode")
        # The bare ``exit`` helper is injected by ``site`` for interactive
        # use and may be absent; an unrecognised mode is also an error, so
        # report a non-zero status instead of 0.
        sys.exit(1)
开发者ID:IsaacChanghau,项目名称:AmusingPythonCodes,代码行数:21,代码来源:config.py

示例8: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main():
    """Load the pretrained generator and display a grid of random samples.

    Defines the integer flags ``latent_dim``, ``obs_dim``, ``batch_size``,
    ``epochs`` and ``updates_per_epoch``, builds a ``GENERATOR`` from the
    first two, loads weights from ``weights/vae_anime/generator``, decodes
    ``batch_size`` standard-normal latent vectors with ``e2x`` and passes
    the result to ``show_samples``.
    """
    flags = tf.flags
    flags.DEFINE_integer("latent_dim", 64, "Dimension of latent space.")
    flags.DEFINE_integer("obs_dim", 12288, "Dimension of observation space.")
    flags.DEFINE_integer("batch_size", 60, "Batch size.")
    flags.DEFINE_integer("epochs", 500, "As it said")
    flags.DEFINE_integer("updates_per_epoch", 100, "Really just can set to 1 if you don't like mini-batch.")
    FLAGS = flags.FLAGS

    kwargs = {
        'latent_dim': FLAGS.latent_dim,
        'observation_dim': FLAGS.obs_dim,
        'generator': conv_anime_decoder,
        'obs_distrib': 'Gaussian'
    }
    g = GENERATOR(**kwargs)
    g.load_pretrained("weights/vae_anime/generator")

    # One latent vector per row: (batch_size, latent_dim).
    z = np.random.normal(size=[FLAGS.batch_size, FLAGS.latent_dim])
    samples = g.e2x(z)
    # Call form of print: valid on Python 2 and 3 alike, whereas the
    # statement form is a SyntaxError on Python 3.
    print(samples.shape)
    show_samples(samples, 4, 15, [64, 64, 3], name='small_samples', shift=True)
开发者ID:wuga214,项目名称:IMPLEMENTATION_Variational-Auto-Encoder,代码行数:24,代码来源:sample.py

示例9: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
    """Run the routine selected by the ``mode`` flag.

    Reads the flag values from ``flags.FLAGS`` and hands them to the
    function matching ``mode``: ``get_vocab``, ``prepare``, ``train``,
    ``train_rl``, ``train_qpp``, ``train_qap``, ``train_qqp_qap`` or
    ``test``.

    Args:
        _: Ignored. (Presumably the leftover argv passed in by the flags
            runner -- confirm against the caller.)

    Raises:
        SystemExit: With status 1 when ``mode`` is not one of the handled
            values.
    """
    # Local import: the file-level import block is not visible from here.
    import sys

    config = flags.FLAGS
    if config.mode == "get_vocab":
        get_vocab(config)
    elif config.mode == "prepare":
        prepare(config)
    elif config.mode == "train":
        train(config)
    elif config.mode == "train_rl":
        train_rl(config)
    elif config.mode == "train_qpp":
        train_qpp(config)
    elif config.mode == "train_qap":
        train_qap(config)
    elif config.mode == "train_qqp_qap":
        train_qqp_qap(config)
    elif config.mode == "test":
        test(config)
    else:
        print("Unknown mode")
        # The bare ``exit`` helper is injected by ``site`` for interactive
        # use and may be absent; an unrecognised mode is also an error, so
        # report a non-zero status instead of 0.
        sys.exit(1)
开发者ID:ZhangShiyue,项目名称:QGforQA,代码行数:23,代码来源:config.py

示例10: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
  """Runs the decoder command once per checkpoint found for the model.

  For every item yielded by `bleu_hook.stepfiles_iterator` for
  `FLAGS.model_dir`, builds an output path
  `<translations_dir>/<problem>-<steps>` and, unless that file already
  exists, formats `FLAGS.decoder_command` and executes it with `os.system`.

  Args:
    _: Ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  # The locals below look unused to pylint but are consumed by name through
  # the `str.format(**...)` calls further down, so they must keep their names.
  # pylint: disable=unused-variable
  model_dir = os.path.expanduser(FLAGS.model_dir)
  translations_dir = os.path.expanduser(FLAGS.translations_dir)
  source = os.path.expanduser(FLAGS.source)
  tf.gfile.MakeDirs(translations_dir)
  translated_base_file = os.path.join(translations_dir, FLAGS.problem)

  # Copy flags.txt with the original time, so t2t-bleu can report correct
  # relative time.
  flags_path = os.path.join(translations_dir, FLAGS.problem + "-flags.txt")
  if not os.path.exists(flags_path):
    shutil.copy2(os.path.join(model_dir, "flags.txt"), flags_path)

  # Namespace for the `params` template: FLAGS plus, after the update inside
  # the loop, every local variable of this function.
  locals_and_flags = {"FLAGS": FLAGS}
  for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes,
                                            FLAGS.min_steps):
    tf.logging.info("Translating " + model.filename)
    out_file = translated_base_file + "-" + str(model.steps)
    # Refresh the namespace so `model` and `out_file` of this iteration are
    # visible to the template below.
    locals_and_flags.update(locals())
    if os.path.exists(out_file):
      tf.logging.info(out_file + " already exists, so skipping it.")
    else:
      tf.logging.info("Translating " + out_file)
      params = (
          "--t2t_usr_dir={FLAGS.t2t_usr_dir} --output_dir={model_dir} "
          "--data_dir={FLAGS.data_dir} --problem={FLAGS.problem} "
          "--decode_hparams=beam_size={FLAGS.beam_size},alpha={FLAGS.alpha} "
          "--model={FLAGS.model} --hparams_set={FLAGS.hparams_set} "
          "--checkpoint_path={model.filename} --decode_from_file={source} "
          "--decode_to_file={out_file} --keep_timestamp"
      ).format(**locals_and_flags)
      # The command template is formatted against this function's locals
      # (presumably it references {params} -- confirm against the flag's
      # default value).
      command = FLAGS.decoder_command.format(**locals())
      tf.logging.info("Running:\n" + command)
      # NOTE(review): the command runs through the shell and its exit status
      # is discarded, so a failed decode is not detected here.
      os.system(command)
  # pylint: enable=unused-variable
开发者ID:akzaidi,项目名称:fine-lm,代码行数:39,代码来源:t2t_translate_all.py

示例11: save_metadata

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def save_metadata(hparams):
  """Writes the current flags and `hparams` into `FLAGS.output_dir`.

  The directory is created when missing. Files written:
    * flags.txt      -- textual dump of the flags.
    * flags_t2t.txt  -- flags registered under "tensor2tensor.utils.flags";
                        only written when that listing is non-empty.
    * hparams.json   -- `hparams.to_json(indent=0, sort_keys=True)`.

  Args:
    hparams: Object providing `to_json(indent=..., sort_keys=...)`.
  """
  target_dir = os.path.expanduser(FLAGS.output_dir)
  if not tf.gfile.Exists(target_dir):
    tf.gfile.MakeDirs(target_dir)

  if not hasattr(FLAGS, "flags_into_string"):
    # FLAGS object without `flags_into_string`: fall back to its raw
    # "__flags" mapping; no per-module listing is produced in this case.
    raw_flags = FLAGS.__dict__["__flags"]
    all_flags_text = "\n".join(
        "--%s=%s" % (key, str(val)) for (key, val) in raw_flags.items())
    t2t_flags_text = None
  else:
    all_flags_text = FLAGS.flags_into_string()
    t2t_flags = FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
    t2t_flags_text = "\n".join(
        "--%s=%s" % (flag.name, flag.value) for flag in t2t_flags)

  all_flags_path = os.path.join(target_dir, "flags.txt")
  with tf.gfile.Open(all_flags_path, "w") as out:
    out.write(all_flags_text)

  if t2t_flags_text:
    t2t_flags_path = os.path.join(target_dir, "flags_t2t.txt")
    with tf.gfile.Open(t2t_flags_path, "w") as out:
      out.write(t2t_flags_text)

  hparams_path = os.path.join(target_dir, "hparams.json")
  with tf.gfile.Open(hparams_path, "w") as out:
    out.write(hparams.to_json(indent=0, sort_keys=True))
开发者ID:akzaidi,项目名称:fine-lm,代码行数:34,代码来源:t2t_trainer.py

示例12: __init__

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def __init__(self, processor_configuration):
    """Sets up hyper parameters, estimator and vocabularies.

    Also points `FLAGS.output_dir` at the configured model directory,
    imports the user directory named by `FLAGS.t2t_usr_dir`, and deletes
    every existing `run_*` entry under /tmp/t2t_server_dump.

    Args:
      processor_configuration: Mapping whose "transformer" entry supplies
        "model_dir", "data_dir", "hparams_set", "hparams", "problem" and
        "model".
    """
    cfg = processor_configuration["transformer"]

    # Pre-setup that tensor2tensor needs for flags and configurations.
    FLAGS.output_dir = cfg["model_dir"]
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    expanded_data_dir = os.path.expanduser(cfg["data_dir"])

    # Basic hyper parameters.
    self.hparams = trainer_lib.create_hparams(
        cfg["hparams_set"],
        cfg["hparams"],
        data_dir=expanded_data_dir,
        problem_name=cfg["problem"])

    # Decoding hyper parameters, extended with the two shard settings.
    decode_params = decoding.decode_hparams()
    for hparam_name, hparam_value in (("shards", 1), ("shard_id", 0)):
        decode_params.add_hparam(hparam_name, hparam_value)

    # Estimator built from the hyper parameters above.
    self.estimator = trainer_lib.create_estimator(
        cfg["model"],
        self.hparams,
        t2t_trainer.create_run_config(self.hparams),
        decode_hparams=decode_params,
        use_tpu=False)

    # Vocabularies and a constant used for decoding.
    vocabulary = self.hparams.problem_hparams.vocabulary
    self.source_vocab = vocabulary["inputs"]
    self.targets_vocab = vocabulary["targets"]
    self.const_array_size = 10000

    # Clear out the debug dump directories left by earlier runs.
    stale_pattern = os.path.join("/tmp/t2t_server_dump", "run_*")
    for stale_run in sorted(glob.glob(stale_pattern)):
        shutil.rmtree(stale_run)
开发者ID:akzaidi,项目名称:fine-lm,代码行数:42,代码来源:transformer_model.py

示例13: validate_flags

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def validate_flags():
  """Validates flags are set to acceptable values.

  Exactly one of two configurations is accepted: either
  `FLAGS.cloud_mlengine_model_name` is set and both `FLAGS.server` and
  `FLAGS.servable_name` are unset, or `FLAGS.cloud_mlengine_model_name` is
  unset and both `FLAGS.server` and `FLAGS.servable_name` are set.

  Raises:
    AssertionError: If the flags match neither configuration. The error is
      raised explicitly rather than with `assert` statements so the check
      still runs when Python is started with -O.
  """
  if FLAGS.cloud_mlengine_model_name:
    if FLAGS.server:
      raise AssertionError(
          "FLAGS.server must be unset when "
          "FLAGS.cloud_mlengine_model_name is set.")
    if FLAGS.servable_name:
      raise AssertionError(
          "FLAGS.servable_name must be unset when "
          "FLAGS.cloud_mlengine_model_name is set.")
  else:
    if not FLAGS.server:
      raise AssertionError(
          "FLAGS.server must be set when "
          "FLAGS.cloud_mlengine_model_name is unset.")
    if not FLAGS.servable_name:
      raise AssertionError(
          "FLAGS.servable_name must be set when "
          "FLAGS.cloud_mlengine_model_name is unset.")
开发者ID:akzaidi,项目名称:fine-lm,代码行数:10,代码来源:query.py

示例14: save_metadata

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def save_metadata(hparams):
  """Writes the current flags and `hparams` into `FLAGS.output_dir`.

  The directory is created when missing. Files written:
    * flags.txt      -- textual dump of the flags.
    * flags_t2t.txt  -- flags registered under "tensor2tensor.utils.flags";
                        only written when that listing is non-empty.
    * hparams.json   -- JSON of a copy of `hparams` with "modality" removed.

  Args:
    hparams: Hyper parameter object; it is copied with
      `hparams_lib.copy_hparams` before being serialized.
  """
  target_dir = os.path.expanduser(FLAGS.output_dir)
  if not tf.gfile.Exists(target_dir):
    tf.gfile.MakeDirs(target_dir)

  if not hasattr(FLAGS, "flags_into_string"):
    # FLAGS object without `flags_into_string`: fall back to its raw
    # "__flags" mapping; no per-module listing is produced in this case.
    raw_flags = FLAGS.__dict__["__flags"]
    all_flags_text = "\n".join(
        "--%s=%s" % (key, str(val)) for (key, val) in raw_flags.items())
    t2t_flags_text = None
  else:
    all_flags_text = FLAGS.flags_into_string()
    t2t_flags = FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
    t2t_flags_text = "\n".join(
        "--%s=%s" % (flag.name, flag.value) for flag in t2t_flags)

  all_flags_path = os.path.join(target_dir, "flags.txt")
  with tf.gfile.Open(all_flags_path, "w") as out:
    out.write(all_flags_text)

  if t2t_flags_text:
    t2t_flags_path = os.path.join(target_dir, "flags_t2t.txt")
    with tf.gfile.Open(t2t_flags_path, "w") as out:
      out.write(t2t_flags_text)

  # Serialize a copy so the caller's hparams keep their "modality" entry;
  # the modality class is not JSON serializable and has to be dropped.
  serializable_hparams = hparams_lib.copy_hparams(hparams)
  serializable_hparams.del_hparam("modality")

  hparams_path = os.path.join(target_dir, "hparams.json")
  with tf.gfile.Open(hparams_path, "w") as out:
    out.write(serializable_hparams.to_json(indent=0, sort_keys=True))
开发者ID:yyht,项目名称:BERT,代码行数:38,代码来源:t2t_trainer.py

示例15: main

# 需要导入模块: import tensorflow [as 别名]
# 或者: from tensorflow import flags [as 别名]
def main(_):
  """Evaluates the model on each checkpoint yielded for its model dir.

  Builds hparams and an estimator from the flags, then iterates over
  `trainer_lib.next_checkpoint(...)`, evaluating every checkpoint path and
  logging the result.

  Args:
    _: Ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  hparams = trainer_lib.create_hparams(
      FLAGS.hparams_set,
      FLAGS.hparams,
      data_dir=FLAGS.data_dir,
      problem_name=FLAGS.problem)

  # Use the "test" split only when flags.eval_use_test_set is set; otherwise
  # pass None as the split.
  if FLAGS.eval_use_test_set:
    split = "test"
  else:
    split = None
  input_fn = hparams.problem.make_estimator_input_fn(
      tf.estimator.ModeKeys.EVAL,
      hparams,
      dataset_kwargs={"dataset_split": split})
  run_config = t2t_trainer.create_run_config(hparams)

  # The summary hook in tf.estimator.EstimatorSpec needs hparams.model_dir.
  hparams.add_hparam("model_dir", run_config.model_dir)

  estimator = trainer_lib.create_estimator(
      FLAGS.model, hparams, run_config, use_tpu=FLAGS.use_tpu)
  checkpoints = trainer_lib.next_checkpoint(
      hparams.model_dir, FLAGS.eval_timeout_mins)
  for checkpoint in checkpoints:
    metrics = estimator.evaluate(
        input_fn, steps=FLAGS.eval_steps, checkpoint_path=checkpoint)
    tf.logging.info(metrics)
开发者ID:yyht,项目名称:BERT,代码行数:30,代码来源:t2t_eval.py


注:本文中的tensorflow.flags方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。