本文整理汇总了Python中tensor2tensor.utils.registry.problem方法的典型用法代码示例。如果您正苦于以下问题：Python registry.problem方法的具体用法？Python registry.problem怎么用？Python registry.problem使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块tensor2tensor.utils.registry的用法示例。
在下文中一共展示了registry.problem方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: generate_real_env_data
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def generate_real_env_data(problem_name, agent_policy_path, hparams, data_dir,
                           tmp_dir, autoencoder_path=None, eval_phase=False):
    """Play the agent in the real environment and report its mean reward.

    Args:
        problem_name: registered name of the gym problem to generate data with.
        agent_policy_path: value of the `agent_policy_path` flag during
            generation.
        hparams: hyperparameters; only `true_env_generator_num_steps` is read.
        data_dir: output directory; created if it does not exist.
        tmp_dir: scratch directory forwarded to `generate_data`.
        autoencoder_path: optional value of the `autoencoder_path` flag.
        eval_phase: forwarded to the problem as `settable_eval_phase`.

    Returns:
        `statistics.sum_of_rewards / statistics.number_of_dones`, or None when
        `number_of_dones` is falsy (no episode finished).
    """
    tf.gfile.MakeDirs(data_dir)
    flag_overrides = {
        "problem": problem_name,
        "agent_policy_path": agent_policy_path,
        "autoencoder_path": autoencoder_path,
    }
    # NOTE(review): temporary_flags presumably restores the previous flag
    # values on exit — confirm against its definition.
    with temporary_flags(flag_overrides):
        env_problem = registry.problem(problem_name)
        env_problem.settable_num_steps = hparams.true_env_generator_num_steps
        env_problem.settable_eval_phase = eval_phase
        env_problem.generate_data(data_dir, tmp_dir)
        stats = env_problem.statistics
        if not stats.number_of_dones:
            return None
        return stats.sum_of_rewards / stats.number_of_dones
示例2: evaluate_world_model
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def evaluate_world_model(simulated_problem_name, problem_name, hparams,
                         world_model_dir, epoch_data_dir, tmp_dir):
    """Roll out the world model as a simulated env and score its rewards.

    Args:
        simulated_problem_name: registered name of the simulated gym problem.
        problem_name: value of the `problem` flag during generation.
        hparams: reads `simulated_env_generator_num_steps`,
            `generative_model` and `generative_model_params`.
        world_model_dir: value of the `output_dir` flag during generation.
        epoch_data_dir: where generated data is written; also the value of
            the `data_dir` flag during generation.
        tmp_dir: scratch directory forwarded to `generate_data`.

    Returns:
        `successful_episode_reward_predictions / max(1, number_of_dones)` as a
        float.

    Side effects:
        Renames `<epoch_data_dir>/debug_frames_sim` to
        `<epoch_data_dir>/debug_frames_sim_eval` unless the latter exists.
    """
    sim_problem = registry.problem(simulated_problem_name)
    sim_problem.settable_num_steps = hparams.simulated_env_generator_num_steps

    flag_overrides = {
        "problem": problem_name,
        "model": hparams.generative_model,
        "hparams_set": hparams.generative_model_params,
        "data_dir": epoch_data_dir,
        "output_dir": world_model_dir,
    }
    with temporary_flags(flag_overrides):
        sim_problem.generate_data(epoch_data_dir, tmp_dir)

    stats = sim_problem.statistics
    # Clamp the denominator to at least 1 so the division is safe when no
    # episode finished; float() keeps the division true division.
    episodes = float(max(1., stats.number_of_dones))
    accuracy = stats.successful_episode_reward_predictions / episodes

    # Keep the simulated debug frames under an eval-specific name, unless that
    # name is already taken.
    frames_dir = os.path.join(epoch_data_dir, "debug_frames_sim")
    eval_frames_dir = os.path.join(epoch_data_dir, "debug_frames_sim_eval")
    if not tf.gfile.Exists(eval_frames_dir):
        tf.gfile.Rename(frames_dir, eval_frames_dir)
    return accuracy
示例3: combine_training_data
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def combine_training_data(problem, final_data_dir, old_data_dirs,
                          copy_last_eval_set=True):
    """Copy training data from `old_data_dirs` into `final_data_dir`.

    Each copied file keeps its basename with "." + <source directory name>
    appended. Data is read by a prefix filepattern, so the suffixed copies are
    still matched. Destination files that already exist are left untouched.

    Args:
        problem: object providing `filepattern(data_dir, mode)`.
        final_data_dir: destination directory.
        old_data_dirs: sequence of source directories.
        copy_last_eval_set: if True, the EVAL files of the last directory in
            `old_data_dirs` are copied as well.
    """
    last_index = len(old_data_dirs) - 1
    for i, data_dir in enumerate(old_data_dirs):
        # Strip trailing separators first: os.path.basename("a/b/") is "",
        # which would give every such directory the same empty suffix, and
        # later copies would then be silently skipped as "already existing".
        suffix = os.path.basename(data_dir.rstrip(os.sep + "/"))
        # Glob source files before copying anything.
        old_files = tf.gfile.Glob(
            problem.filepattern(data_dir, tf.estimator.ModeKeys.TRAIN))
        if i == last_index and copy_last_eval_set:
            old_files += tf.gfile.Glob(
                problem.filepattern(data_dir, tf.estimator.ModeKeys.EVAL))
        for fname in old_files:
            new_fname = os.path.join(final_data_dir,
                                     os.path.basename(fname) + "." + suffix)
            if not tf.gfile.Exists(new_fname):
                tf.gfile.Copy(fname, new_fname)
示例4: generate_data_for_problem
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def generate_data_for_problem(problem):
    """Write and shuffle train/dev files for a _SUPPORTED_PROBLEM_GENERATORS key.

    Training data goes to FLAGS.num_shards files (10 when the flag is falsy),
    dev data to a single file, both under FLAGS.data_dir with the
    generator_utils.UNSHUFFLED_SUFFIX name suffix. FLAGS.max_cases is passed
    only to the training `generate_files` call. All written files are then
    passed to `generator_utils.shuffle_dataset`.

    Args:
        problem: key into _SUPPORTED_PROBLEM_GENERATORS whose value is a pair
            of zero-argument callables (training, dev) returning generators.
    """
    train_gen_fn, dev_gen_fn = _SUPPORTED_PROBLEM_GENERATORS[problem]
    shard_count = FLAGS.num_shards or 10
    base_name = problem + generator_utils.UNSHUFFLED_SUFFIX

    tf.logging.info("Generating training data for %s.", problem)
    train_files = generator_utils.train_data_filenames(
        base_name, FLAGS.data_dir, shard_count)
    generator_utils.generate_files(train_gen_fn(), train_files,
                                   FLAGS.max_cases)

    tf.logging.info("Generating development data for %s.", problem)
    dev_files = generator_utils.dev_data_filenames(base_name, FLAGS.data_dir, 1)
    generator_utils.generate_files(dev_gen_fn(), dev_files)

    generator_utils.shuffle_dataset(train_files + dev_files)
示例5: generate_data_for_registered_problem
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def generate_data_for_registered_problem(problem_name):
    """Generate data for a problem registered in the registry.

    When --task_id is negative and the problem sets `multiprocess_generate`,
    tasks [task_id_start, task_id_end) are fanned out over a multiprocessing
    pool. The range defaults to [0, problem.num_generate_tasks) unless
    --task_id_start is given. Otherwise `generate_data` runs in-process with
    the single task id (or None).

    Args:
        problem_name: registered problem name.

    Raises:
        ValueError: if --num_shards is set, or if --task_id_start is set
            without --task_id_end.
    """
    tf.logging.info("Generating data for %s.", problem_name)
    if FLAGS.num_shards:
        raise ValueError("--num_shards should not be set for registered Problem.")
    problem = registry.problem(problem_name)
    task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
    data_dir = os.path.expanduser(FLAGS.data_dir)
    tmp_dir = os.path.expanduser(FLAGS.tmp_dir)

    if task_id is not None or not problem.multiprocess_generate:
        problem.generate_data(data_dir, tmp_dir, task_id)
        return

    if FLAGS.task_id_start != -1:
        # Explicit check rather than `assert`, which `python -O` strips.
        if FLAGS.task_id_end == -1:
            raise ValueError(
                "--task_id_end must be set when --task_id_start is set.")
        task_id_start = FLAGS.task_id_start
        task_id_end = FLAGS.task_id_end
    else:
        task_id_start = 0
        task_id_end = problem.num_generate_tasks

    pool = multiprocessing.Pool(processes=FLAGS.num_concurrent_processes)
    try:
        problem.prepare_to_generate(data_dir, tmp_dir)
        args = [(problem_name, data_dir, tmp_dir, tid)
                for tid in range(task_id_start, task_id_end)]
        pool.map(generate_data_in_process, args)
    finally:
        # Always shut the pool down so its worker processes are not leaked.
        pool.close()
        pool.join()
示例6: create_experiment_fn
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def create_experiment_fn(**kwargs):
    """Build an experiment function configured from command-line FLAGS.

    All FLAGS-derived settings are forwarded to
    `trainer_lib.create_experiment_fn`. Extra keyword arguments are passed
    through unchanged; one that duplicates a FLAGS-derived keyword raises
    TypeError ("got multiple values").
    """
    data_dir = os.path.expanduser(FLAGS.data_dir)
    decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
    minimize_metric = FLAGS.eval_early_stopping_metric_minimize
    return trainer_lib.create_experiment_fn(
        model_name=FLAGS.model,
        problem_name=FLAGS.problem,
        data_dir=data_dir,
        train_steps=FLAGS.train_steps,
        eval_steps=FLAGS.eval_steps,
        min_eval_frequency=FLAGS.local_eval_frequency,
        schedule=FLAGS.schedule,
        eval_throttle_seconds=FLAGS.eval_throttle_seconds,
        export=FLAGS.export_saved_model,
        decode_hparams=decode_hp,
        use_tfdbg=FLAGS.tfdbg,
        use_dbgprofile=FLAGS.dbgprofile,
        eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
        eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
        eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
        eval_early_stopping_metric_minimize=minimize_metric,
        use_tpu=FLAGS.use_tpu,
        **kwargs)
示例7: decode
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def decode(estimator, hparams, decode_hp):
    """Decode with `estimator`: interactively, from a file, or from a dataset.

    The mode is chosen by flags in priority order: --decode_interactive, then
    --decode_from_file, otherwise the dataset of --problem. The dataset mode
    uses split "test" when --eval_use_test_set is set, else None.

    When decoding from a file with --checkpoint_path and --keep_timestamp, the
    output file's mtime/atime are set to those of "<checkpoint>.index". This
    only happens when --decode_to_file names the output file explicitly.

    Raises:
        ValueError: if the estimator runs on TPU and the mode is interactive
            or from-file.
    """
    if FLAGS.decode_interactive:
        if estimator.config.use_tpu:
            raise ValueError("TPU can only decode from dataset.")
        decoding.decode_interactively(estimator, hparams, decode_hp,
                                      checkpoint_path=FLAGS.checkpoint_path)
    elif FLAGS.decode_from_file:
        if estimator.config.use_tpu:
            raise ValueError("TPU can only decode from dataset.")
        decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
                                  decode_hp, FLAGS.decode_to_file,
                                  checkpoint_path=FLAGS.checkpoint_path)
        # Require an explicit output path: os.utime(None, ...) raises
        # TypeError, which would crash after decoding already succeeded.
        # NOTE(review): when --decode_to_file is unset, decode_from_file may
        # write to a derived filename that is not known here — confirm.
        if (FLAGS.checkpoint_path and FLAGS.keep_timestamp
                and FLAGS.decode_to_file):
            ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + ".index")
            os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
    else:
        decoding.decode_from_dataset(
            estimator,
            FLAGS.problem,
            hparams,
            decode_hp,
            decode_to_file=FLAGS.decode_to_file,
            dataset_split="test" if FLAGS.eval_use_test_set else None)
示例8: testBasicExampleReading
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def testBasicExampleReading(self):
    """Read TRAIN examples and check their field dtypes and non-emptiness."""
    dataset = self.problem.dataset(
        tf.estimator.ModeKeys.TRAIN,
        data_dir=self.data_dir,
        shuffle_files=False)
    next_example = dataset.make_one_shot_iterator().get_next()
    expected = (("inputs", np.int64),
                ("targets", np.int64),
                ("floats", np.float32))
    with tf.train.MonitoredSession() as sess:
        # Pull several examples; each must carry all three fields with the
        # expected element type and at least one element.
        for _ in range(10):
            values = sess.run(next_example)
            fields = [values[key] for key, _dtype in expected]
            for (_key, dtype), field in zip(expected, fields):
                self.assertEqual(dtype, field.dtype)
            for field in fields:
                self.assertGreater(len(field), 0)
示例9: testMultiModel
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def testMultiModel(self):
    """Run a tiny MultiModel on random inputs and check the logits shape."""
    batch_size = 3
    # np.random.random_integers is deprecated. randint's upper bound is
    # exclusive, so 256 / 10 keep the inclusive ranges [0, 255] / [0, 9].
    x = np.random.randint(0, 256, size=(batch_size, 5, 5, 3))
    y = np.random.randint(0, 10, size=(batch_size, 5, 1, 1))
    hparams = multimodel.multimodel_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_cifar10")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
        features = {
            "inputs": tf.constant(x, dtype=tf.int32),
            "targets": tf.constant(y, dtype=tf.int32),
            "target_space_id": tf.constant(1, dtype=tf.int32),
        }
        model = multimodel.MultiModel(
            hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
        logits, _ = model(features)
        session.run(tf.global_variables_initializer())
        res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, 1, 1, 1, 10))
示例10: _test_img2img_transformer
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def _test_img2img_transformer(self, net):
    """Run `net` on random 4x4 inputs / 8x8 targets and check logits shape.

    Args:
        net: model class constructed as `net(hparams, mode, problem_hparams)`.
    """
    batch_size = 3
    hparams = image_transformer_2d.img2img_transformer2d_tiny()
    hparams.data_dir = ""
    p_hparams = registry.problem("image_celeba").get_hparams(hparams)
    # np.random.random_integers is deprecated; randint(0, 256) draws from the
    # same inclusive range [0, 255]. Sizes use batch_size so the final shape
    # assertion stays consistent with the inputs.
    inputs = np.random.randint(0, 256, size=(batch_size, 4, 4, 3))
    targets = np.random.randint(0, 256, size=(batch_size, 8, 8, 3))
    with self.test_session() as session:
        features = {
            "inputs": tf.constant(inputs, dtype=tf.int32),
            "targets": tf.constant(targets, dtype=tf.int32),
            "target_space_id": tf.constant(1, dtype=tf.int32),
        }
        model = net(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
        logits, _ = model(features)
        session.run(tf.global_variables_initializer())
        res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, 8, 8, 3, 256))
示例11: testSliceNet
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def testSliceNet(self):
    """Run a tiny SliceNet on random inputs and check the logits shape."""
    batch_size = 3
    # np.random.random_integers is deprecated. randint's upper bound is
    # exclusive, so 256 / 10 keep the inclusive ranges [0, 255] / [0, 9].
    x = np.random.randint(0, 256, size=(batch_size, 5, 5, 3))
    y = np.random.randint(0, 10, size=(batch_size, 5, 1, 1))
    hparams = slicenet.slicenet_params1_tiny()
    hparams.add_hparam("data_dir", "")
    problem = registry.problem("image_cifar10")
    p_hparams = problem.get_hparams(hparams)
    hparams.problem_hparams = p_hparams
    with self.test_session() as session:
        features = {
            "inputs": tf.constant(x, dtype=tf.int32),
            "targets": tf.constant(y, dtype=tf.int32),
            "target_space_id": tf.constant(1, dtype=tf.int32),
        }
        model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                  p_hparams)
        logits, _ = model(features)
        session.run(tf.global_variables_initializer())
        res = session.run(logits)
    self.assertEqual(res.shape, (batch_size, 1, 1, 1, 10))
示例12: main
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def main(_):
    """Query a serving model with prompted (or --inputs_once) input and print it.

    Loops forever reading ">> " prompts, unless --inputs_once is set, in which
    case that single input is sent once.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    validate_flags()
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    problem = registry.problem(FLAGS.problem)
    hparams = tf.contrib.training.HParams(
        data_dir=os.path.expanduser(FLAGS.data_dir))
    # The return value is discarded; presumably this call initializes state
    # on `problem` that serving_utils.predict relies on — confirm.
    problem.get_hparams(hparams)
    request_fn = make_request_fn()
    template = """
Input:
{inputs}
Output (Score {score:.3f}):
{output}
"""
    while True:
        inputs = FLAGS.inputs_once or input(">> ")
        # predict returns exactly one (output, score) pair for one input.
        [(output, score)] = serving_utils.predict([inputs], problem, request_fn)
        print(template.format(inputs=inputs, output=output, score=score))
        if FLAGS.inputs_once:
            break
示例13: main
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def main(_):
    """Generate gym trajectories for --game and log the mean episode reward."""
    tf.gfile.MakeDirs(FLAGS.data_dir)
    tf.gfile.MakeDirs(FLAGS.tmp_dir)

    problem_name = "gym_discrete_problem_with_agent_on_%s" % FLAGS.game
    # Register the game's problem only if it is not registered yet.
    if problem_name not in registry.Registries.problems:
        gym_env.register_game(FLAGS.game)

    tf.logging.info("Running %s environment for %d steps for trajectories.",
                    FLAGS.game, FLAGS.num_env_steps)
    gym_problem = registry.problem(problem_name)
    gym_problem.settable_num_steps = FLAGS.num_env_steps
    gym_problem.settable_eval_phase = FLAGS.eval
    gym_problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)

    stats = gym_problem.statistics
    # Nothing to report when no episode finished.
    if not stats.number_of_dones:
        return
    tf.logging.info("Mean reward: %.2f, Num dones: %d",
                    stats.sum_of_rewards / stats.number_of_dones,
                    stats.number_of_dones)
示例14: generate_data_for_env_problem
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def generate_data_for_env_problem(problem_name):
    """Generate data for a registered `EnvProblem`.

    The steps are:
      1. Initialize the env problem with --env_problem_batch_size environments.
      2. Play it randomly for --env_problem_max_env_steps steps.
      3. Call `generate_data` into --data_dir.

    Args:
        problem_name: name registered via the env-problem registry.

    Raises:
        ValueError: if --env_problem_max_env_steps or --env_problem_batch_size
            is not greater than zero.
    """
    # Explicit checks rather than `assert`, which `python -O` strips.
    if not FLAGS.env_problem_max_env_steps > 0:
        raise ValueError("--env_problem_max_env_steps should be greater than zero")
    if not FLAGS.env_problem_batch_size > 0:
        raise ValueError("--env_problem_batch_size should be greater than zero")
    problem = registry.env_problem(problem_name)
    task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
    data_dir = os.path.expanduser(FLAGS.data_dir)
    tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
    # TODO(msaffar): Handle large values for env_problem_batch_size where we
    # cannot create that many environments within the same process.
    problem.initialize(batch_size=FLAGS.env_problem_batch_size)
    env_problem_utils.play_env_problem_randomly(
        problem, num_steps=FLAGS.env_problem_max_env_steps)
    problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id)
示例15: main
# 需要导入模块: from tensor2tensor.utils import registry [as 别名]
# 或者: from tensor2tensor.utils.registry import problem [as 别名]
def main(_):
    """Create (or load) the vocabulary of --problem and save it in --data_dir.

    Raises:
        ValueError: if --problem is not a text_problems.Text2TextProblem.
    """
    problem = registry.problem(FLAGS.problem)
    # The vocabulary code below assumes a Text2TextProblem. Check explicitly:
    # an `assert` would be stripped under `python -O`.
    if not isinstance(problem, text_problems.Text2TextProblem):
        raise ValueError("--problem must be a Text2TextProblem, got %s"
                         % type(problem).__name__)
    data_dir = os.path.expanduser(FLAGS.data_dir)
    tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
    tf.gfile.MakeDirs(data_dir)
    tf.gfile.MakeDirs(tmp_dir)
    # Lazy %-style logging args: the message is formatted only when emitted.
    tf.logging.info("Saving vocabulary to data_dir: %s", data_dir)
    problem.get_or_create_vocab(data_dir, tmp_dir)
    tf.logging.info("Saved vocabulary file: %s",
                    os.path.join(data_dir, problem.vocab_filename))