本文整理汇总了 Python 中 data_provider.get_data 方法的典型用法代码示例。如果您正苦于以下问题：Python data_provider.get_data 方法的具体用法？Python data_provider.get_data 怎么用？Python data_provider.get_data 使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 data_provider 的用法示例。
在下文中一共展示了 data_provider.get_data 方法的 4 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。
示例1: test_provided_data_has_correct_shape
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import get_data [as 别名]
def test_provided_data_has_correct_shape(self):
    """Check the shapes of one batch produced by data_provider.get_data.

    Requests an augmented, uncropped batch of 4 examples from the FSNS
    test split and asserts the shapes of the evaluated images and one-hot
    labels.
    """
    batch_size = 4
    expected_images_shape = (batch_size, 150, 600, 3)
    expected_labels_shape = (batch_size, 37, 134)

    test_split = datasets.fsns_test.get_test_split()
    data = data_provider.get_data(
        dataset=test_split,
        batch_size=batch_size,
        augment=True,
        central_crop_size=None)

    # Queue runners are active only while the batch is being evaluated.
    with self.test_session() as sess:
        with queues.QueueRunners(sess):
            fetched = sess.run([data.images, data.labels_one_hot])
    images_np, labels_np = fetched

    self.assertEqual(images_np.shape, expected_images_shape)
    self.assertEqual(labels_np.shape, expected_labels_shape)
示例2: test_optionally_applies_central_crop
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import get_data [as 别名]
def test_optionally_applies_central_crop(self):
    """Check that passing central_crop_size changes the image batch shape.

    The crop is requested as (500, 100) and the evaluated image batch is
    expected to have shape (batch_size, 100, 500, 3), i.e. the two crop
    values appear in swapped order in the resulting shape.
    """
    batch_size = 4
    crop_size = (500, 100)

    test_split = datasets.fsns_test.get_test_split()
    data = data_provider.get_data(
        dataset=test_split,
        batch_size=batch_size,
        augment=True,
        central_crop_size=crop_size)

    # Queue runners are active only while the batch is being evaluated.
    with self.test_session() as sess:
        with queues.QueueRunners(sess):
            images_np = sess.run(data.images)

    self.assertEqual(images_np.shape, (batch_size, 100, 500, 3))
示例3: main
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import get_data [as 别名]
def main(_):
    """Build the evaluation graph and run slim's evaluation loop.

    Creates the evaluation log directory if it is missing, builds the
    dataset, model and input pipeline from command-line flags, wires up
    the loss and summaries, then hands control to
    slim.evaluation.evaluation_loop.

    Args:
      _: Unused positional argument.
    """
    eval_log_dir = FLAGS.eval_log_dir
    if not tf.gfile.Exists(eval_log_dir):
        tf.gfile.MakeDirs(eval_log_dir)

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        dataset.num_char_classes,
        dataset.max_sequence_length,
        dataset.num_of_views,
        dataset.null_code)

    # Evaluation reads un-augmented inputs.
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=False,
        central_crop_size=common_flags.get_crop_size())

    # No labels are passed to the base network here.
    endpoints = model.create_base(data.images, labels_one_hot=None)
    model.create_loss(data, endpoints)
    eval_ops = model.create_summaries(
        data, endpoints, dataset.charset, is_training=False)
    slim.get_or_create_global_step()

    # A GPU device count of 0 keeps this session off the GPUs.
    session_config = tf.ConfigProto(device_count={"GPU": 0})

    loop_kwargs = dict(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_log_dir,
        logdir=FLAGS.eval_log_dir,
        eval_op=eval_ops,
        num_evals=FLAGS.num_batches,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.number_of_steps,
        session_config=session_config)
    slim.evaluation.evaluation_loop(**loop_kwargs)
示例4: main
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import get_data [as 别名]
def main(_):
    """Build the training graph and start training.

    Prepares the training directory, creates the dataset, model and
    hyper-parameters from command-line flags, builds the input pipeline,
    loss, summaries and checkpoint-restore function under a replica
    device setter, and finally calls train().

    Args:
      _: Unused positional argument.
    """
    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        dataset.num_char_classes,
        dataset.max_sequence_length,
        dataset.num_of_views,
        dataset.null_code)
    hparams = get_training_hparams()

    # With ps_tasks == 0 the local device is used. With multiple
    # (non-local) replicas the ReplicaDeviceSetter spreads the variables
    # over the available devices.
    device_setter = tf.train.replica_device_setter(
        FLAGS.ps_tasks, merge_devices=True)

    # NOTE(review): the original indentation was lost, so the exact extent
    # of this device scope is not visible; everything below is kept inside
    # it -- confirm against the upstream file.
    with tf.device(device_setter):
        crop_size = common_flags.get_crop_size()
        data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=crop_size)

        endpoints = model.create_base(data.images, data.labels_one_hot)
        total_loss = model.create_loss(data, endpoints)
        model.create_summaries(
            data, endpoints, dataset.charset, is_training=True)
        init_fn = model.create_init_fn_to_restore(
            FLAGS.checkpoint, FLAGS.checkpoint_inception)

        if FLAGS.show_graph_stats:
            logging.info('Total number of weights in the graph: %s',
                         calculate_graph_metrics())

        train(total_loss, init_fn, hparams)