本文整理汇总了Python中common_flags.get_crop_size方法的典型用法代码示例。如果您正苦于以下问题：Python common_flags.get_crop_size方法的具体用法是什么？Python common_flags.get_crop_size怎么用？有没有Python common_flags.get_crop_size的使用示例？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块common_flags的用法示例。
在下文中一共展示了common_flags.get_crop_size方法的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import get_crop_size [as 别名]
def main(_):
    """Run the continuous evaluation job for the model.

    Builds an evaluation graph from the command-line flags, with input
    augmentation disabled and the central crop size taken from
    ``common_flags.get_crop_size()``. It then delegates to
    ``slim.evaluation.evaluation_loop``, which reads checkpoints from
    ``FLAGS.train_log_dir`` and writes its output to ``FLAGS.eval_log_dir``.

    Args:
      _: Unused positional argument. Presumably this is the argv list passed
        by the application runner; confirm against the caller.
    """
    eval_dir = FLAGS.eval_log_dir
    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        dataset.num_char_classes,
        dataset.max_sequence_length,
        dataset.num_of_views,
        dataset.null_code,
    )

    # Evaluation input: no augmentation; the central crop size comes from
    # the shared flags helper.
    crop_size = common_flags.get_crop_size()
    data = data_provider.get_data(
        dataset, FLAGS.batch_size, augment=False, central_crop_size=crop_size
    )

    # The model is built without one-hot labels. `data` is still passed to
    # the loss and summary builders below.
    endpoints = model.create_base(data.images, labels_one_hot=None)
    # The return value is discarded. Presumably this is called for its
    # graph side effects (loss ops the summaries may rely on) -- confirm.
    model.create_loss(data, endpoints)
    eval_ops = model.create_summaries(
        data, endpoints, dataset.charset, is_training=False
    )
    slim.get_or_create_global_step()

    # Expose zero GPUs to the evaluation session so it runs on CPU.
    cpu_only = tf.ConfigProto(device_count={"GPU": 0})
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_log_dir,
        logdir=eval_dir,
        eval_op=eval_ops,
        num_evals=FLAGS.num_batches,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.number_of_steps,
        session_config=cpu_only,
    )
示例2: main
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import get_crop_size [as 别名]
def main(_):
    """Build the training graph and run the training loop.

    Prepares the training directory and creates the dataset and model from
    the command-line flags. Inside the replica device scope, it builds the
    input pipeline, model endpoints, loss, summaries and checkpoint-restore
    function, then calls ``train``.

    Args:
      _: Unused positional argument. Presumably this is the argv list passed
        by the application runner; confirm against the caller.
    """
    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        dataset.num_char_classes,
        dataset.max_sequence_length,
        dataset.num_of_views,
        dataset.null_code,
    )
    hparams = get_training_hparams()

    # When ps_tasks is zero, the local device is used. With several
    # (non-local) replicas, the ReplicaDeviceSetter distributes the
    # variables across the different devices.
    with tf.device(
        tf.train.replica_device_setter(FLAGS.ps_tasks, merge_devices=True)
    ):
        crop_size = common_flags.get_crop_size()
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=crop_size,
        )
        endpoints = model.create_base(data.images, data.labels_one_hot)
        total_loss = model.create_loss(data, endpoints)
        model.create_summaries(
            data, endpoints, dataset.charset, is_training=True
        )
        init_fn = model.create_init_fn_to_restore(
            FLAGS.checkpoint, FLAGS.checkpoint_inception
        )

        if FLAGS.show_graph_stats:
            logging.info(
                'Total number of weights in the graph: %s',
                calculate_graph_metrics(),
            )

        train(total_loss, init_fn, hparams)