本文整理匯總了Python中data_provider.get_data方法的典型用法代碼示例。如果您正苦於以下問題：Python data_provider.get_data方法的具體用法？Python data_provider.get_data怎麼用？Python data_provider.get_data使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊data_provider的用法示例。
在下文中一共展示了data_provider.get_data方法的4個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: test_provided_data_has_correct_shape
# 需要導入模塊: import data_provider [as 別名]
# 或者: from data_provider import get_data [as 別名]
def test_provided_data_has_correct_shape(self):
    """Check the shapes of the image and one-hot label batches.

    Calls ``data_provider.get_data`` on the split returned by
    ``datasets.fsns_test.get_test_split()`` with augmentation enabled and no
    central crop, evaluates one batch inside a test session with queue
    runners started, and compares the resulting array shapes with the
    expected constants.
    """
    # NOTE(review): indentation was reconstructed from a flattened listing;
    # the assertions are assumed to follow the session block.
    examples_per_batch = 4
    provided = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=examples_per_batch,
        augment=True,
        central_crop_size=None)

    with self.test_session() as sess, queues.QueueRunners(sess):
        image_batch, label_batch = sess.run(
            [provided.images, provided.labels_one_hot])

    expected_image_shape = (examples_per_batch, 150, 600, 3)
    expected_label_shape = (examples_per_batch, 37, 134)
    self.assertEqual(image_batch.shape, expected_image_shape)
    self.assertEqual(label_batch.shape, expected_label_shape)
示例2: test_optionally_applies_central_crop
# 需要導入模塊: import data_provider [as 別名]
# 或者: from data_provider import get_data [as 別名]
def test_optionally_applies_central_crop(self):
    """Check that passing ``central_crop_size`` changes the image batch shape.

    Requests data with ``central_crop_size=(500, 100)`` and asserts that the
    evaluated image batch has shape ``(batch_size, 100, 500, 3)``.
    """
    # NOTE(review): indentation was reconstructed from a flattened listing;
    # the assertion is assumed to follow the session block.
    # The crop tuple and the asserted shape suggest (width, height) ordering
    # for central_crop_size -- TODO confirm against data_provider.
    examples_per_batch = 4
    crop_size = (500, 100)
    provided = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=examples_per_batch,
        augment=True,
        central_crop_size=crop_size)

    with self.test_session() as sess, queues.QueueRunners(sess):
        image_batch = sess.run(provided.images)

    expected_shape = (examples_per_batch, 100, 500, 3)
    self.assertEqual(image_batch.shape, expected_shape)
示例3: main
# 需要導入模塊: import data_provider [as 別名]
# 或者: from data_provider import get_data [as 別名]
def main(_):
    """Build the evaluation graph and hand it to slim's evaluation loop.

    Creates the evaluation log directory when it does not exist, builds the
    dataset, model and input data from command-line flags, wires up the
    model's loss and summaries, and then calls
    ``slim.evaluation.evaluation_loop`` with the flag-derived settings.

    Args:
      _: Ignored positional argument.
    """
    # NOTE(review): indentation was reconstructed from a flattened listing.
    eval_dir = FLAGS.eval_log_dir
    if not tf.gfile.Exists(eval_dir):
        tf.gfile.MakeDirs(eval_dir)

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        dataset.num_char_classes,
        dataset.max_sequence_length,
        dataset.num_of_views,
        dataset.null_code)

    # Augmentation is switched off for evaluation.
    eval_data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=False,
        central_crop_size=common_flags.get_crop_size())

    # No one-hot labels are passed to the base network here.
    net_endpoints = model.create_base(eval_data.images, labels_one_hot=None)
    # The return value is not used here; presumably called for its effect on
    # the graph -- confirm against the model class.
    model.create_loss(eval_data, net_endpoints)
    summary_ops = model.create_summaries(
        eval_data, net_endpoints, dataset.charset, is_training=False)
    slim.get_or_create_global_step()

    # Passes a GPU device count of 0, presumably to keep evaluation off the
    # GPU -- confirm intent.
    cpu_only_config = tf.ConfigProto(device_count={"GPU": 0})

    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.train_log_dir,
        logdir=eval_dir,
        eval_op=summary_ops,
        num_evals=FLAGS.num_batches,
        eval_interval_secs=FLAGS.eval_interval_secs,
        max_number_of_evaluations=FLAGS.number_of_steps,
        session_config=cpu_only_config)
示例4: main
# 需要導入模塊: import data_provider [as 別名]
# 或者: from data_provider import get_data [as 別名]
def main(_):
    """Build the training graph from flags and start training.

    Prepares the training directory, creates the dataset, model and
    hyper-parameters, then builds the input pipeline, network, loss,
    summaries and checkpoint-restore function under a replica device setter
    before calling ``train``.

    Args:
      _: Ignored positional argument.
    """
    # NOTE(review): indentation was reconstructed from a flattened listing;
    # every statement from get_data through train() is assumed to sit inside
    # the tf.device block -- confirm against the original script.
    prepare_training_dir()

    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        dataset.num_char_classes,
        dataset.max_sequence_length,
        dataset.num_of_views,
        dataset.null_code)
    hparams = get_training_hparams()

    # If ps_tasks is zero, the local device is used. When using multiple
    # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
    # across the different devices.
    placement_fn = tf.train.replica_device_setter(
        FLAGS.ps_tasks, merge_devices=True)

    with tf.device(placement_fn):
        train_data = data_provider.get_data(
            dataset,
            FLAGS.batch_size,
            augment=hparams.use_augment_input,
            central_crop_size=common_flags.get_crop_size())
        net_endpoints = model.create_base(
            train_data.images, train_data.labels_one_hot)
        total_loss = model.create_loss(train_data, net_endpoints)
        model.create_summaries(
            train_data, net_endpoints, dataset.charset, is_training=True)
        init_fn = model.create_init_fn_to_restore(
            FLAGS.checkpoint, FLAGS.checkpoint_inception)

        if FLAGS.show_graph_stats:
            logging.info('Total number of weights in the graph: %s',
                         calculate_graph_metrics())

        train(total_loss, init_fn, hparams)