

Python registry.problem Method Code Examples

This article collects typical usage examples of the registry.problem method from tensor2tensor.utils.registry in Python. If you are wondering how registry.problem is used in practice, what it does, or what real calls to it look like, the curated examples below should help. You can also explore further usage examples for the module it belongs to, tensor2tensor.utils.registry.


The following presents 15 code examples of the registry.problem method, sorted by popularity by default.
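
Before diving into the examples, here is a minimal usage sketch of the typical pattern: look up a registered Problem by name, generate its data, and read it back as a dataset. This sketch was written for this article and is not taken from the projects below; the problem name "translate_ende_wmt32k" and the directory paths are illustrative placeholders, and it assumes TensorFlow 1.x and tensor2tensor are installed.

# Minimal sketch (assumption: placeholder problem name and directories).
import tensorflow as tf
from tensor2tensor import problems  # importing this registers the built-in problems
from tensor2tensor.utils import registry

data_dir = "/tmp/t2t_data"
tmp_dir = "/tmp/t2t_tmp"
tf.gfile.MakeDirs(data_dir)
tf.gfile.MakeDirs(tmp_dir)

# Look up the Problem instance by its registered name.
problem = registry.problem("translate_ende_wmt32k")

# Build the TFRecord files, then read them back as a training dataset.
problem.generate_data(data_dir, tmp_dir)
dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, data_dir=data_dir)

The examples that follow apply this same lookup pattern inside data-generation scripts, trainers, decoders, and model tests.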

Example 1: generate_real_env_data

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def generate_real_env_data(problem_name, agent_policy_path, hparams, data_dir,
                           tmp_dir, autoencoder_path=None, eval_phase=False):
  """Run the agent against the real environment and return mean reward."""
  tf.gfile.MakeDirs(data_dir)
  with temporary_flags({
      "problem": problem_name,
      "agent_policy_path": agent_policy_path,
      "autoencoder_path": autoencoder_path,
  }):
    gym_problem = registry.problem(problem_name)
    gym_problem.settable_num_steps = hparams.true_env_generator_num_steps
    gym_problem.settable_eval_phase = eval_phase
    gym_problem.generate_data(data_dir, tmp_dir)
    mean_reward = None
    if gym_problem.statistics.number_of_dones:
      mean_reward = (gym_problem.statistics.sum_of_rewards /
                     gym_problem.statistics.number_of_dones)

  return mean_reward 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 21, Source file: model_rl_experiment.py

Example 2: evaluate_world_model

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def evaluate_world_model(simulated_problem_name, problem_name, hparams,
                         world_model_dir, epoch_data_dir, tmp_dir):
  """Generate simulated environment data and return reward accuracy."""
  gym_simulated_problem = registry.problem(simulated_problem_name)
  sim_steps = hparams.simulated_env_generator_num_steps
  gym_simulated_problem.settable_num_steps = sim_steps
  with temporary_flags({
      "problem": problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "data_dir": epoch_data_dir,
      "output_dir": world_model_dir,
  }):
    gym_simulated_problem.generate_data(epoch_data_dir, tmp_dir)
  n = max(1., gym_simulated_problem.statistics.number_of_dones)
  model_reward_accuracy = (
      gym_simulated_problem.statistics.successful_episode_reward_predictions
      / float(n))
  old_path = os.path.join(epoch_data_dir, "debug_frames_sim")
  new_path = os.path.join(epoch_data_dir, "debug_frames_sim_eval")
  if not tf.gfile.Exists(new_path):
    tf.gfile.Rename(old_path, new_path)
  return model_reward_accuracy 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 25, Source file: model_rl_experiment.py

Example 3: combine_training_data

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def combine_training_data(problem, final_data_dir, old_data_dirs,
                          copy_last_eval_set=True):
  """Add training data from old_data_dirs into final_data_dir."""
  for i, data_dir in enumerate(old_data_dirs):
    suffix = os.path.basename(data_dir)
    # Glob train files in old data_dir
    old_train_files = tf.gfile.Glob(
        problem.filepattern(data_dir, tf.estimator.ModeKeys.TRAIN))
    if (i + 1) == len(old_data_dirs) and copy_last_eval_set:
      old_train_files += tf.gfile.Glob(
          problem.filepattern(data_dir, tf.estimator.ModeKeys.EVAL))
    for fname in old_train_files:
      # Move them to the new data_dir with a suffix
      # Since the data is read based on a prefix filepattern, adding the suffix
      # should be fine.
      new_fname = os.path.join(final_data_dir,
                               os.path.basename(fname) + "." + suffix)
      if not tf.gfile.Exists(new_fname):
        tf.gfile.Copy(fname, new_fname) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 21, Source file: model_rl_experiment.py

Example 4: generate_data_for_problem

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def generate_data_for_problem(problem):
  """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS."""
  training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]

  num_shards = FLAGS.num_shards or 10
  tf.logging.info("Generating training data for %s.", problem)
  train_output_files = generator_utils.train_data_filenames(
      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
  generator_utils.generate_files(training_gen(), train_output_files,
                                 FLAGS.max_cases)
  tf.logging.info("Generating development data for %s.", problem)
  dev_output_files = generator_utils.dev_data_filenames(
      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, 1)
  generator_utils.generate_files(dev_gen(), dev_output_files)
  all_output_files = train_output_files + dev_output_files
  generator_utils.shuffle_dataset(all_output_files) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 18, Source file: t2t_datagen.py

Example 5: generate_data_for_registered_problem

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def generate_data_for_registered_problem(problem_name):
  """Generate data for a registered problem."""
  tf.logging.info("Generating data for %s.", problem_name)
  if FLAGS.num_shards:
    raise ValueError("--num_shards should not be set for registered Problem.")
  problem = registry.problem(problem_name)
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  if task_id is None and problem.multiprocess_generate:
    if FLAGS.task_id_start != -1:
      assert FLAGS.task_id_end != -1
      task_id_start = FLAGS.task_id_start
      task_id_end = FLAGS.task_id_end
    else:
      task_id_start = 0
      task_id_end = problem.num_generate_tasks
    pool = multiprocessing.Pool(processes=FLAGS.num_concurrent_processes)
    problem.prepare_to_generate(data_dir, tmp_dir)
    args = [(problem_name, data_dir, tmp_dir, task_id)
            for task_id in range(task_id_start, task_id_end)]
    pool.map(generate_data_in_process, args)
  else:
    problem.generate_data(data_dir, tmp_dir, task_id) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 26, Source file: t2t_datagen.py

Example 6: create_experiment_fn

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def create_experiment_fn(**kwargs):
  return trainer_lib.create_experiment_fn(
      model_name=FLAGS.model,
      problem_name=FLAGS.problem,
      data_dir=os.path.expanduser(FLAGS.data_dir),
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps,
      min_eval_frequency=FLAGS.local_eval_frequency,
      schedule=FLAGS.schedule,
      eval_throttle_seconds=FLAGS.eval_throttle_seconds,
      export=FLAGS.export_saved_model,
      decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
      use_tfdbg=FLAGS.tfdbg,
      use_dbgprofile=FLAGS.dbgprofile,
      eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
      eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
      eval_early_stopping_metric_minimize=FLAGS.
      eval_early_stopping_metric_minimize,
      use_tpu=FLAGS.use_tpu,
      **kwargs) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 23, Source file: t2t_trainer.py

Example 7: decode

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def decode(estimator, hparams, decode_hp):
  """Decode from estimator. Interactive, from file, or from dataset."""
  if FLAGS.decode_interactive:
    if estimator.config.use_tpu:
      raise ValueError("TPU can only decode from dataset.")
    decoding.decode_interactively(estimator, hparams, decode_hp,
                                  checkpoint_path=FLAGS.checkpoint_path)
  elif FLAGS.decode_from_file:
    if estimator.config.use_tpu:
      raise ValueError("TPU can only decode from dataset.")
    decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                              decode_hp, FLAGS.decode_to_file,
                              checkpoint_path=FLAGS.checkpoint_path)
    if FLAGS.checkpoint_path and FLAGS.keep_timestamp:
      ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
      os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
  else:
    decoding.decode_from_dataset(
        estimator,
        FLAGS.problem,
        hparams,
        decode_hp,
        decode_to_file=FLAGS.decode_to_file,
        dataset_split="test" if FLAGS.eval_use_test_set else None) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 26, Source file: t2t_decoder.py

Example 8: testBasicExampleReading

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def testBasicExampleReading(self):
    dataset = self.problem.dataset(
        tf.estimator.ModeKeys.TRAIN,
        data_dir=self.data_dir,
        shuffle_files=False)
    examples = dataset.make_one_shot_iterator().get_next()
    with tf.train.MonitoredSession() as sess:
      # Check that there are multiple examples that have the right fields of the
      # right type (lists of int/float).
      for _ in range(10):
        ex_val = sess.run(examples)
        inputs, targets, floats = (ex_val["inputs"], ex_val["targets"],
                                   ex_val["floats"])
        self.assertEqual(np.int64, inputs.dtype)
        self.assertEqual(np.int64, targets.dtype)
        self.assertEqual(np.float32, floats.dtype)
        for field in [inputs, targets, floats]:
          self.assertGreater(len(field), 0) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 20, Source file: data_reader_test.py

Example 9: testMultiModel

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def testMultiModel(self):
    x = np.random.random_integers(0, high=255, size=(3, 5, 5, 3))
    y = np.random.random_integers(0, high=9, size=(3, 5, 1, 1))
    hparams = multimodel.multimodel_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_cifar10")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = multimodel.MultiModel(
          hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (3, 1, 1, 1, 10)) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 22, Source file: multimodel_test.py

Example 10: _test_img2img_transformer

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def _test_img2img_transformer(self, net):
    batch_size = 3
    hparams = image_transformer_2d.img2img_transformer2d_tiny()
    hparams.data_dir = ""
    p_hparams = registry.problem("image_celeba").get_hparams(hparams)
    inputs = np.random.random_integers(0, high=255, size=(3, 4, 4, 3))
    targets = np.random.random_integers(0, high=255, size=(3, 8, 8, 3))
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(inputs, dtype=tf.int32),
          "targets": tf.constant(targets, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = net(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, 8, 8, 3, 256)) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 20, Source file: image_transformer_2d_test.py

Example 11: testSliceNet

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def testSliceNet(self):
    x = np.random.random_integers(0, high=255, size=(3, 5, 5, 3))
    y = np.random.random_integers(0, high=9, size=(3, 5, 1, 1))
    hparams = slicenet.slicenet_params1_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_cifar10")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
      features = {
          "inputs": tf.constant(x, dtype=tf.int32),
          "targets": tf.constant(y, dtype=tf.int32),
          "target_space_id": tf.constant(1, dtype=tf.int32),
      }
      model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                p_hparams)
      logits, _ = model(features)
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    self.assertEqual(res.shape, (3, 1, 1, 1, 10)) 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 22, Source file: slicenet_test.py

Example 12: main

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  validate_flags()
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  problem = registry.problem(FLAGS.problem)
  hparams = tf.contrib.training.HParams(
      data_dir=os.path.expanduser(FLAGS.data_dir))
  problem.get_hparams(hparams)
  request_fn = make_request_fn()
  while True:
    inputs = FLAGS.inputs_once if FLAGS.inputs_once else input(">> ")
    outputs = serving_utils.predict([inputs], problem, request_fn)
    outputs, = outputs
    output, score = outputs
    print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
    """
    print(print_str.format(inputs=inputs, output=output, score=score))
    if FLAGS.inputs_once:
      break 
Developer ID: akzaidi, Project: fine-lm, Lines of code: 26, Source file: query.py

Example 13: main

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def main(_):

  tf.gfile.MakeDirs(FLAGS.data_dir)
  tf.gfile.MakeDirs(FLAGS.tmp_dir)

  # Create problem if not already defined
  problem_name = "gym_discrete_problem_with_agent_on_%s" % FLAGS.game
  if problem_name not in registry.Registries.problems:
    gym_env.register_game(FLAGS.game)

  # Generate
  tf.logging.info("Running %s environment for %d steps for trajectories.",
                  FLAGS.game, FLAGS.num_env_steps)
  problem = registry.problem(problem_name)
  problem.settable_num_steps = FLAGS.num_env_steps
  problem.settable_eval_phase = FLAGS.eval
  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

  # Log stats
  if problem.statistics.number_of_dones:
    mean_reward = (problem.statistics.sum_of_rewards /
                   problem.statistics.number_of_dones)
    tf.logging.info("Mean reward: %.2f, Num dones: %d",
                    mean_reward,
                    problem.statistics.number_of_dones) 
Developer ID: tensorflow, Project: tensor2tensor, Lines of code: 27, Source file: datagen_with_agent.py

Example 14: generate_data_for_env_problem

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def generate_data_for_env_problem(problem_name):
  """Generate data for `EnvProblem`s."""
  assert FLAGS.env_problem_max_env_steps > 0, ("--env_problem_max_env_steps "
                                               "should be greater than zero")
  assert FLAGS.env_problem_batch_size > 0, ("--env_problem_batch_size should be"
                                            " greater than zero")
  problem = registry.env_problem(problem_name)
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  # TODO(msaffar): Handle large values for env_problem_batch_size where we
  #  cannot create that many environments within the same process.
  problem.initialize(batch_size=FLAGS.env_problem_batch_size)
  env_problem_utils.play_env_problem_randomly(
      problem, num_steps=FLAGS.env_problem_max_env_steps)
  problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id) 
Developer ID: tensorflow, Project: tensor2tensor, Lines of code: 18, Source file: t2t_datagen.py

Example 15: main

# Required import: from tensor2tensor.utils import registry [as alias]
# Or: from tensor2tensor.utils.registry import problem [as alias]
def main(_):
  problem = registry.problem(FLAGS.problem)

  # We make the assumption that the problem is a subclass of Text2TextProblem.
  assert isinstance(problem, text_problems.Text2TextProblem)

  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)

  tf.gfile.MakeDirs(data_dir)
  tf.gfile.MakeDirs(tmp_dir)

  tf.logging.info("Saving vocabulary to data_dir: %s" % data_dir)

  problem.get_or_create_vocab(data_dir, tmp_dir)

  tf.logging.info("Saved vocabulary file: " +
                  os.path.join(data_dir, problem.vocab_filename)) 
Developer ID: tensorflow, Project: tensor2tensor, Lines of code: 20, Source file: build_vocab.py


Note: The tensor2tensor.utils.registry.problem examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective authors, and copyright of the source code remains with the original authors; please consult each project's license before distributing or using the code, and do not republish without permission.