本文整理汇总了Python中common_flags.create_model方法的典型用法代码示例。如果您正苦于以下问题：Python common_flags.create_model方法的具体用法？Python common_flags.create_model怎么用？Python common_flags.create_model使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块common_flags的用法示例。
在下文中一共展示了common_flags.create_model方法的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def main(_):
    """Build the evaluation graph and run the slim evaluation loop.

    Creates ``FLAGS.eval_log_dir`` if it does not exist, builds the dataset
    and model from ``common_flags``, wires up the loss and summary ops, and
    then hands control to ``slim.evaluation.evaluation_loop``, which reads
    checkpoints from ``FLAGS.train_log_dir``.

    Args:
        _: Unused positional argument (presumably argv from the app runner).
    """
    if not tf.gfile.Exists(FLAGS.eval_log_dir):
        tf.gfile.MakeDirs(FLAGS.eval_log_dir)

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    # Evaluation input: augmentation is explicitly switched off.
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=False,
        central_crop_size=common_flags.get_crop_size())
    # No one-hot labels are passed to the base network during evaluation.
    endpoints = model.create_base(data.images, labels_one_hot=None)
    model.create_loss(data, endpoints)
    eval_ops = model.create_summaries(
        data, endpoints, dataset.charset, is_training=False)
    slim.get_or_create_global_step()
    # device_count={"GPU": 0} requests a session with no GPU devices.
    session_config = tf.ConfigProto(device_count={"GPU": 0})
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_log_dir,
        logdir=FLAGS.eval_log_dir,
        eval_op=eval_ops,
        num_evals=FLAGS.num_batches,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.number_of_steps,
        session_config=session_config)
示例2: main
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def main(_):
    """Build the training graph and start training.

    Prepares the training directory, creates the dataset and model from
    ``common_flags``, builds the input pipeline, loss and summaries under a
    replica device setter, and finally calls ``train``.

    Args:
        _: Unused positional argument (presumably argv from the app runner).
    """
    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(dataset.num_char_classes,
                                      dataset.max_sequence_length,
                                      dataset.num_of_views, dataset.null_code)
    hparams = get_training_hparams()

    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    device_setter = tf.train.replica_device_setter(
        FLAGS.ps_tasks, merge_devices=True)
    with tf.device(device_setter):
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())
        endpoints = model.create_base(data.images, data.labels_one_hot)
        total_loss = model.create_loss(data, endpoints)
        model.create_summaries(
            data, endpoints, dataset.charset, is_training=True)
        # Function used to restore weights from the checkpoints named in
        # the flags.
        init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
                                                  FLAGS.checkpoint_inception)
        if FLAGS.show_graph_stats:
            logging.info('Total number of weights in the graph: %s',
                         calculate_graph_metrics())
        train(total_loss, init_fn, hparams)
示例3: create_model
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def create_model(batch_size, dataset_name):
    """Build an inference graph fed by a uint8 image placeholder.

    Args:
        batch_size: Size of the first dimension of the image placeholder.
        dataset_name: Name passed to ``get_dataset_image_size`` to look up
            the image width and height.

    Returns:
        A ``(raw_images, endpoints)`` tuple: the ``tf.uint8`` placeholder of
        shape ``[batch_size, height, width, 3]`` and the result of
        ``model.create_base`` applied to the preprocessed images.
    """
    width, height = get_dataset_image_size(dataset_name)
    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        num_char_classes=dataset.num_char_classes,
        seq_length=dataset.max_sequence_length,
        num_views=dataset.num_of_views,
        null_code=dataset.null_code,
        charset=dataset.charset)
    raw_images = tf.placeholder(
        tf.uint8, shape=[batch_size, height, width, 3])
    # Apply the data provider's preprocessing to each image in the batch,
    # producing float32 tensors for the network.
    images = tf.map_fn(data_provider.preprocess_image, raw_images,
                       dtype=tf.float32)
    endpoints = model.create_base(images, labels_one_hot=None)
    return raw_images, endpoints
示例4: run
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def run(checkpoint, batch_size, dataset_name, image_path_pattern):
    """Run inference on a batch of images and return predictions as a list.

    Args:
        checkpoint: Checkpoint path handed to ``ChiefSessionCreator`` as
            ``checkpoint_filename_with_path``.
        batch_size: Number of images in the batch.
        dataset_name: Dataset name forwarded to ``create_model`` and
            ``load_images``.
        image_path_pattern: Path pattern forwarded to ``load_images``.

    Returns:
        ``endpoints.predicted_text`` evaluated on the loaded images,
        converted with ``.tolist()``.
    """
    images_placeholder, endpoints = create_model(batch_size,
                                                 dataset_name)
    images_data = load_images(image_path_pattern, batch_size,
                              dataset_name)
    session_creator = monitored_session.ChiefSessionCreator(
        checkpoint_filename_with_path=checkpoint)
    with monitored_session.MonitoredSession(
            session_creator=session_creator) as sess:
        predictions = sess.run(endpoints.predicted_text,
                               feed_dict={images_placeholder: images_data})
    return predictions.tolist()
示例5: load_model
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def load_model(checkpoint, batch_size, dataset_name):
    """Build an inference graph fed by a float32 image placeholder.

    Args:
        checkpoint: Checkpoint passed to ``model.create_init_fn_to_restore``.
        batch_size: Size of the first dimension of the image placeholder.
        dataset_name: Name passed to ``get_dataset_image_size`` to look up
            the image width and height.

    Returns:
        A ``(images_placeholder, endpoints, init_fn)`` tuple: the
        ``tf.float32`` placeholder of shape
        ``[batch_size, height, width, 3]``, the result of
        ``model.create_base`` on that placeholder, and the init function
        returned by ``model.create_init_fn_to_restore``.
    """
    width, height = get_dataset_image_size(dataset_name)
    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        num_char_classes=dataset.num_char_classes,
        seq_length=dataset.max_sequence_length,
        num_views=dataset.num_of_views,
        null_code=dataset.null_code,
        charset=dataset.charset)
    # The placeholder is fed to the network directly; no preprocessing op
    # is added here, so callers presumably supply preprocessed floats.
    images_placeholder = tf.placeholder(
        tf.float32, shape=[batch_size, height, width, 3])
    endpoints = model.create_base(images_placeholder, labels_one_hot=None)
    init_fn = model.create_init_fn_to_restore(checkpoint)
    return images_placeholder, endpoints, init_fn
示例6: run
# 需要导入模块: import common_flags [as 别名]
# 或者: from common_flags import create_model [as 别名]
def run(checkpoint, batch_size, dataset_name, image_path_pattern):
    """Run inference on a batch of images and return decoded text.

    Args:
        checkpoint: Checkpoint path handed to ``ChiefSessionCreator`` as
            ``checkpoint_filename_with_path``.
        batch_size: Number of images in the batch.
        dataset_name: Dataset name forwarded to ``create_model`` and
            ``load_images``.
        image_path_pattern: Path pattern forwarded to ``load_images``.

    Returns:
        A list of ``str``: each element of the evaluated
        ``endpoints.predicted_text`` decoded from bytes as UTF-8.
    """
    images_placeholder, endpoints = create_model(batch_size,
                                                 dataset_name)
    images_data = load_images(image_path_pattern, batch_size,
                              dataset_name)
    session_creator = monitored_session.ChiefSessionCreator(
        checkpoint_filename_with_path=checkpoint)
    with monitored_session.MonitoredSession(
            session_creator=session_creator) as sess:
        predictions = sess.run(endpoints.predicted_text,
                               feed_dict={images_placeholder: images_data})
    # The session returns byte strings; decode them for the caller.
    return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]