本文整理汇总了Python中utils.util.ParseConfigsToLuaTable方法的典型用法代码示例。如果您正苦于以下问题：Python util.ParseConfigsToLuaTable方法的具体用法？Python util.ParseConfigsToLuaTable怎么用？Python util.ParseConfigsToLuaTable使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块utils.util的用法示例。
在下文中一共展示了util.ParseConfigsToLuaTable方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from utils import util [as 别名]
# 或者: from utils.util import ParseConfigsToLuaTable [as 别名]
def main(_):
    """Evaluate one requested checkpoint, or each checkpoint that is yielded.

    The config is parsed from ``FLAGS.config_paths`` and
    ``FLAGS.model_params``. When ``FLAGS.checkpoint_iter`` is set, only
    ``model.ckpt-<checkpoint_iter>`` inside ``FLAGS.checkpointdir`` is
    evaluated; otherwise every path produced by
    ``tf.contrib.training.checkpoints_iterator`` is evaluated in turn.

    Args:
        _: Unused positional argument (presumably argv supplied by the app
            runner -- confirm against the caller).
    """
    # Parse config dict from yaml config files / command line flags.
    config = util.ParseConfigsToLuaTable(
        FLAGS.config_paths, FLAGS.model_params)
    num_views = config.data.num_views
    validation_records = util.GetFilesRecursively(config.data.validation)
    batch_size = config.data.batch_size
    checkpointdir = FLAGS.checkpointdir

    if FLAGS.checkpoint_iter:
        # Pass the directory and the file name as separate components so that
        # os.path.join actually performs the join (a single pre-formatted
        # string makes it a no-op and mishandles empty / trailing-slash dirs).
        checkpoint_paths = [
            os.path.join(
                checkpointdir, 'model.ckpt-%s' % FLAGS.checkpoint_iter)
        ]
    else:
        checkpoint_paths = tf.contrib.training.checkpoints_iterator(
            checkpointdir)

    # Both modes share the same evaluation call; only the source of
    # checkpoint paths differs.
    for checkpoint_path in checkpoint_paths:
        evaluate_once(
            config, checkpointdir, validation_records, checkpoint_path,
            batch_size, num_views)
示例2: main
# 需要导入模块: from utils import util [as 别名]
# 或者: from utils.util import ParseConfigsToLuaTable [as 别名]
def main(_):
    """Wait for a checkpoint in the log directory, then evaluate in a loop.

    Args:
        _: Unused positional argument.
    """
    eval_dir = FLAGS.logdir
    # Build the config from yaml config files / command line flags.
    config = util.ParseConfigsToLuaTable(
        FLAGS.config_paths, FLAGS.model_params)
    # The estimator is selected from the config (training strategy).
    estimator = get_estimator(config, eval_dir)

    # Poll every 10 seconds until a checkpoint shows up in the log directory.
    while True:
        if tf.train.latest_checkpoint(eval_dir):
            break
        tf.logging.info('Waiting for a checkpoint file...')
        time.sleep(10)

    # Validation runs repeatedly; this loop has no exit condition.
    while True:
        estimator.evaluate()
示例3: main
# 需要导入模块: from utils import util [as 别名]
# 或者: from utils.util import ParseConfigsToLuaTable [as 别名]
def main(_):
    """Dispatch to the video generator selected by ``FLAGS.mode``.

    Supported modes are 'multi', 'single' and 'same'.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.mode`` is not one of the supported modes.
    """
    # Build the config from yaml config files / command line flags.
    config = util.ParseConfigsToLuaTable(
        FLAGS.config_paths, FLAGS.model_params)

    # Collect the record files to embed for the query and target sets.
    query_records = util.GetFilesRecursively(FLAGS.query_records_dir)
    target_records = util.GetFilesRecursively(FLAGS.target_records_dir)

    height, width = config.data.raw_height, config.data.raw_width
    mode = FLAGS.mode

    if mode == 'multi':
        # Target set is composed of multiple videos.
        MultiImitationVideos(
            query_records, target_records, config, height, width)
        return
    if mode == 'single':
        # Target set is a single video.
        SingleImitationVideos(
            query_records, target_records, config, height, width)
        return
    if mode == 'same':
        # Target set is the same as the query, but from a different view.
        SameSequenceVideos(query_records, config, height, width)
        return
    raise ValueError('Unknown mode %s' % mode)
示例4: main
# 需要导入模块: from utils import util [as 别名]
# 或者: from utils.util import ParseConfigsToLuaTable [as 别名]
def main(_):
    """Parse the config, build the estimator and start training.

    Args:
        _: Unused positional argument.
    """
    # Build the config from yaml config files / command line flags; the
    # parser is called with save=True and the log directory.
    config = util.ParseConfigsToLuaTable(
        FLAGS.config_paths,
        FLAGS.model_params,
        save=True,
        logdir=FLAGS.logdir,
    )
    # The estimator is selected from the config (training strategy).
    trainer = get_estimator(config, FLAGS.logdir)
    trainer.train()
示例5: main
# 需要导入模块: from utils import util [as 别名]
# 或者: from utils.util import ParseConfigsToLuaTable [as 别名]
def main(_):
    """Runs main labeled eval loop.

    When ``FLAGS.checkpoint_iter`` is set, only ``model.ckpt-<checkpoint_iter>``
    inside ``FLAGS.checkpointdir`` is evaluated; otherwise every path produced
    by ``tf.contrib.training.checkpoints_iterator`` is evaluated in turn.

    Args:
        _: Unused positional argument (presumably argv supplied by the app
            runner -- confirm against the caller).
    """
    # Parse config dict from yaml config files / command line flags.
    config = util.ParseConfigsToLuaTable(
        FLAGS.config_paths, FLAGS.model_params)

    # Choose an estimator based on training strategy.
    checkpointdir = FLAGS.checkpointdir
    estimator = get_estimator(config, checkpointdir)

    # Get data configs.
    image_attr_keys = config.data.labeled.image_attr_keys
    label_attr_keys = config.data.labeled.label_attr_keys
    embedding_size = config.embedding_size
    num_views = config.data.num_views
    k_list = config.val.recall_at_k_list
    batch_size = config.data.batch_size

    # Get either labeled validation or test tables.
    labeled_tables = get_labeled_tables(config)

    def input_fn_by_view(view_index):
        """Returns an input_fn for use with a tf.Estimator by view."""
        def input_fn():
            # Get raw labeled images.
            (preprocessed_images, labels,
             tasks) = data_providers.labeled_data_provider(
                 labeled_tables,
                 estimator.preprocess_data, view_index, image_attr_keys,
                 label_attr_keys, batch_size=batch_size)
            # Features dict plus None in the labels slot; the labels travel
            # inside the features under 'classification_labels'.
            return {
                'batch_preprocessed': preprocessed_images,
                'tasks': tasks,
                'classification_labels': labels,
            }, None
        return input_fn

    if FLAGS.checkpoint_iter:
        # Pass the directory and the file name as separate components so that
        # os.path.join actually performs the join (a single pre-formatted
        # string makes it a no-op and mishandles empty / trailing-slash dirs).
        checkpoint_paths = [
            os.path.join(
                checkpointdir, 'model.ckpt-%s' % FLAGS.checkpoint_iter)
        ]
    else:
        checkpoint_paths = tf.contrib.training.checkpoints_iterator(
            checkpointdir)

    # Both modes share the same evaluation call; only the source of
    # checkpoint paths differs.
    for checkpoint_path in checkpoint_paths:
        evaluate_once(
            estimator, input_fn_by_view, batch_size, checkpoint_path,
            label_attr_keys, embedding_size, num_views, k_list)