本文整理汇总了Python中data_provider.provide_data方法的典型用法代码示例。如果您正苦于以下问题：Python data_provider.provide_data方法的具体用法？Python data_provider.provide_data怎么用？Python data_provider.provide_data使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块data_provider的用法示例。
在下文中一共展示了data_provider.provide_data方法的14个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_cifar10_train_set
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_cifar10_train_set(self):
    """Reads one CIFAR-10 training batch and checks metadata, shapes and range."""
    batch_size = 4
    testdata_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/cifar/testdata')
    provided = data_provider.provide_data(batch_size, testdata_dir)
    images, labels, num_samples, num_classes = provided

    # Dataset-level metadata reported by the provider.
    self.assertEqual(50000, num_samples)
    self.assertEqual(10, num_classes)

    with self.test_session(use_gpu=True) as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
            images_out, labels_out = sess.run([images, labels])
            self.assertEqual(images_out.shape, (batch_size, 32, 32, 3))
            # Labels are expected as (batch_size, 10) -- presumably one-hot.
            self.assertEqual((batch_size, 10), labels_out.shape)
            # Every pixel value must lie within [-1, 1].
            self.assertTrue(np.all(np.abs(images_out) <= 1))
示例2: _test_data_provider_helper
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def _test_data_provider_helper(self, split_name):
    """Checks `provide_data` output shape and value range for one split.

    Args:
      split_name: Dataset split name, forwarded as the first positional
        argument of `data_provider.provide_data`.
    """
    dataset_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/image_compression/testdata/')

    batch_size = 3
    patch_size = 8
    # Pass the local `patch_size` instead of repeating the literal, so the
    # requested patch size and the expected shapes below cannot drift apart.
    images = data_provider.provide_data(
        split_name, batch_size, dataset_dir, patch_size=patch_size)
    self.assertListEqual([batch_size, patch_size, patch_size, 3],
                         images.shape.as_list())

    with self.test_session(use_gpu=True) as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
            images_out = sess.run(images)
            self.assertEqual((batch_size, patch_size, patch_size, 3),
                             images_out.shape)
            # Check range: every pixel value must lie within [-1, 1].
            self.assertTrue(np.all(np.abs(images_out) <= 1.0))
示例3: test_data_provider
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_data_provider(self, mock_provide_custom_data):
    """Checks that `provide_data` yields one image and one label per domain.

    `mock_provide_custom_data` is a mock (presumably injected by a patch
    decorator outside this view) whose return value stands in for the
    per-domain image tensors.
    """
    num_domains = 3
    batch_size = 2
    patch_size = 8
    images_shape = [batch_size, patch_size, patch_size, 3]
    fake_domain_images = [tf.zeros(images_shape) for _ in range(num_domains)]
    mock_provide_custom_data.return_value = fake_domain_images

    images, labels = data_provider.provide_data(
        image_file_patterns=None, batch_size=batch_size, patch_size=patch_size)

    self.assertEqual(num_domains, len(images))
    self.assertEqual(num_domains, len(labels))
    expected_label_shape = [batch_size, num_domains]
    for domain_label in labels:
        self.assertListEqual(expected_label_shape, domain_label.shape.as_list())
    for domain_image in images:
        self.assertListEqual(images_shape, domain_image.shape.as_list())
示例4: test_cifar10_train_set
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_cifar10_train_set(self):
    """Reads a CIFAR-10 training batch and validates its metadata and contents."""
    batch_size = 4
    cifar_testdata = os.path.join(
        flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/cifar/testdata')
    (images, labels,
     num_samples, num_classes) = data_provider.provide_data(batch_size,
                                                            cifar_testdata)
    self.assertEqual(50000, num_samples)
    self.assertEqual(10, num_classes)

    with self.test_session(use_gpu=True) as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
            image_batch, label_batch = sess.run([images, labels])
            self.assertEqual(image_batch.shape, (batch_size, 32, 32, 3))
            # Labels come back as (batch_size, 10) -- presumably one-hot.
            self.assertEqual((batch_size, 10), label_batch.shape)
            # Range check: |pixel| <= 1 for every element.
            self.assertTrue(np.all(np.abs(image_batch) <= 1))
示例5: _provide_real_images
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def _provide_real_images(batch_size, **kwargs):
"""Provides real images."""
dataset_name = kwargs.get('dataset_name')
dataset_file_pattern = kwargs.get('dataset_file_pattern')
colors = kwargs['colors']
final_height, final_width = train.make_resolution_schedule(
**kwargs).final_resolutions
if dataset_name is not None:
return data_provider.provide_data(
dataset_name=dataset_name,
split_name='train',
batch_size=batch_size,
patch_height=final_height,
patch_width=final_width,
colors=colors)
elif dataset_file_pattern is not None:
return data_provider.provide_data_from_image_files(
file_pattern=dataset_file_pattern,
batch_size=batch_size,
patch_height=final_height,
patch_width=final_width,
colors=colors)
示例6: test_data_provider
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_data_provider(self, split_name):
    """Checks `provide_data` output shape and value range for `split_name`.

    Args:
      split_name: Dataset split name, forwarded as the first positional
        argument of `data_provider.provide_data`.
    """
    dataset_dir = os.path.join(
        flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/image_compression/testdata/')

    batch_size = 3
    patch_size = 8
    # Use the local `patch_size` rather than a duplicated literal, so the
    # requested patch size always matches the shapes asserted below.
    images = data_provider.provide_data(
        split_name, batch_size, dataset_dir, patch_size=patch_size)
    self.assertListEqual([batch_size, patch_size, patch_size, 3],
                         images.shape.as_list())

    with self.test_session(use_gpu=True) as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
            images_out = sess.run(images)
            self.assertEqual((batch_size, patch_size, patch_size, 3),
                             images_out.shape)
            # Check range: every pixel value must lie within [-1, 1].
            self.assertTrue(np.all(np.abs(images_out) <= 1.0))
示例7: _get_real_data
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def _get_real_data(num_images_generated, dataset_dir):
    """Get real images.

    Calls `data_provider.provide_data(num_images_generated, dataset_dir)`,
    which must return exactly four values. Only the first (the images) and the
    fourth (the number of classes) are kept; the middle two are discarded.

    Returns:
      A `(images, num_classes)` tuple.
    """
    real_images, _labels, _num_samples, num_classes = data_provider.provide_data(
        num_images_generated, dataset_dir)
    return real_images, num_classes
示例8: test_mnist_data_reading
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_mnist_data_reading(self):
    """Reads one MNIST test-split batch and checks sample count and shapes."""
    batch_size = 5
    mnist_testdata = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/mnist/testdata')
    images, labels, num_samples = data_provider.provide_data(
        'test', batch_size, mnist_testdata)
    self.assertEqual(num_samples, 10000)

    with self.test_session() as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
            # Keep the evaluated numpy arrays distinct from the graph tensors.
            images_np, labels_np = sess.run([images, labels])
            self.assertEqual(images_np.shape, (batch_size, 28, 28, 1))
            self.assertEqual(labels_np.shape, (batch_size, 10))
示例9: main
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, run_eval_loop=True):
    """Builds the unconditional MNIST GAN evaluation graph and optionally runs it.

    If `FLAGS.eval_real_images` is set, only the classifier score of real
    images is summarized. Otherwise, images are sampled from the generator,
    and Frechet-distance and classifier-score summaries are added. When at
    least 100 images are generated, a 10-column grid of the first 100 samples
    is also written to `unconditional_gan.png` inside `FLAGS.eval_dir`.

    Args:
      _: Unused positional argument.
      run_eval_loop: If False, return right after building the graph. Intended
        for unit testing.
    """
    import os  # Local import: the module-level imports are not visible here.

    # Fetch real images.
    with tf.name_scope('inputs'):
        real_images, _, _ = data_provider.provide_data(
            'train', FLAGS.num_images_generated, FLAGS.dataset_dir)

    image_write_ops = None
    if FLAGS.eval_real_images:
        tf.summary.scalar('MNIST_Classifier_score',
                          util.mnist_score(real_images, FLAGS.classifier_filename))
    else:
        # In order for variables to load, use the same variable scope as in the
        # train job.
        with tf.variable_scope('Generator'):
            images = networks.unconditional_generator(
                tf.random_normal([FLAGS.num_images_generated, FLAGS.noise_dims]))
        tf.summary.scalar('MNIST_Frechet_distance',
                          util.mnist_frechet_distance(
                              real_images, images, FLAGS.classifier_filename))
        tf.summary.scalar('MNIST_Classifier_score',
                          util.mnist_score(images, FLAGS.classifier_filename))
        if FLAGS.num_images_generated >= 100:
            reshaped_images = tfgan.eval.image_reshaper(
                images[:100, ...], num_cols=10)
            uint8_images = data_provider.float_image_to_uint8(reshaped_images)
            # os.path.join avoids a doubled separator when eval_dir ends in '/'.
            image_write_ops = tf.write_file(
                os.path.join(FLAGS.eval_dir, 'unconditional_gan.png'),
                tf.image.encode_png(uint8_images[0]))

    # For unit testing, use `run_eval_loop=False`.
    if not run_eval_loop:
        return
    tf.contrib.training.evaluate_repeatedly(
        FLAGS.checkpoint_dir,
        hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
               tf.contrib.training.StopAfterNEvalsHook(1)],
        eval_ops=image_write_ops,
        max_number_of_evaluations=FLAGS.max_number_of_evaluations)
示例10: main
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, run_eval_loop=True):
    """Builds the unconditional MNIST GAN evaluation graph and optionally runs it.

    If `FLAGS.eval_real_images` is set, only the classifier score of real
    images is summarized. Otherwise, images are sampled from the generator
    (built with `is_training=False`), and Frechet-distance and
    classifier-score summaries are added. When at least 100 images are
    generated and `FLAGS.write_to_disk` is set, a 10-column grid of the first
    100 samples is written to `unconditional_gan.png` inside `FLAGS.eval_dir`.

    Args:
      _: Unused positional argument.
      run_eval_loop: If False, return right after building the graph. Intended
        for unit testing.
    """
    import os  # Local import: the module-level imports are not visible here.

    # Fetch real images.
    with tf.name_scope('inputs'):
        real_images, _, _ = data_provider.provide_data(
            'train', FLAGS.num_images_generated, FLAGS.dataset_dir)

    image_write_ops = None
    if FLAGS.eval_real_images:
        tf.summary.scalar('MNIST_Classifier_score',
                          util.mnist_score(real_images, FLAGS.classifier_filename))
    else:
        # In order for variables to load, use the same variable scope as in the
        # train job.
        with tf.variable_scope('Generator'):
            images = networks.unconditional_generator(
                tf.random_normal([FLAGS.num_images_generated, FLAGS.noise_dims]),
                is_training=False)
        tf.summary.scalar('MNIST_Frechet_distance',
                          util.mnist_frechet_distance(
                              real_images, images, FLAGS.classifier_filename))
        tf.summary.scalar('MNIST_Classifier_score',
                          util.mnist_score(images, FLAGS.classifier_filename))
        if FLAGS.num_images_generated >= 100 and FLAGS.write_to_disk:
            reshaped_images = tfgan.eval.image_reshaper(
                images[:100, ...], num_cols=10)
            uint8_images = data_provider.float_image_to_uint8(reshaped_images)
            # os.path.join avoids a doubled separator when eval_dir ends in '/'.
            image_write_ops = tf.write_file(
                os.path.join(FLAGS.eval_dir, 'unconditional_gan.png'),
                tf.image.encode_png(uint8_images[0]))

    # For unit testing, use `run_eval_loop=False`.
    if not run_eval_loop:
        return
    tf.contrib.training.evaluate_repeatedly(
        FLAGS.checkpoint_dir,
        hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
               tf.contrib.training.StopAfterNEvalsHook(1)],
        eval_ops=image_write_ops,
        max_number_of_evaluations=FLAGS.max_number_of_evaluations)
示例11: test_mnist_data_reading
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_mnist_data_reading(self):
    """Pulls one MNIST test batch through the input queue and checks its shapes."""
    batch_size = 5
    testdata_path = os.path.join(
        flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/mnist/testdata')
    image_tensor, label_tensor, num_samples = data_provider.provide_data(
        'test', batch_size, testdata_path)
    self.assertEqual(num_samples, 10000)

    with self.test_session() as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
            image_values, label_values = sess.run([image_tensor, label_tensor])
            self.assertEqual(image_values.shape, (batch_size, 28, 28, 1))
            self.assertEqual(label_values.shape, (batch_size, 10))
示例12: test_provide_data
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_provide_data(self):
    """Reads an unshuffled MNIST batch of 3x3 single-channel patches and checks shapes."""
    expected_shape = [2, 3, 3, 1]
    images = data_provider.provide_data(
        'mnist',
        'train',
        dataset_dir=self.testdata_dir,
        batch_size=2,
        shuffle=False,
        patch_height=3,
        patch_width=3,
        colors=1)
    # Static (graph-time) shape.
    self.assertEqual(images.shape.as_list(), expected_shape)

    with self.test_session(use_gpu=True) as sess:
        with tf.contrib.slim.queues.QueueRunners(sess):
            images_np = sess.run(images)
            # Runtime shape of the evaluated batch.
            self.assertEqual(images_np.shape, tuple(expected_shape))
示例13: main
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, run_eval_loop=True):
    """Builds the image-compression evaluation graph and optionally runs it.

    Validation images are passed through the compression model. L1
    pixel-loss summaries are added, and an op is created that writes stacked
    originals/reconstructions to `compression.png` inside `FLAGS.eval_dir`.

    Args:
      _: Unused positional argument.
      run_eval_loop: If False, return right after building the graph. Intended
        for unit testing.
    """
    import os  # Local import: the module-level imports are not visible here.

    with tf.name_scope('inputs'):
        images = data_provider.provide_data(
            'validation', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
            patch_size=FLAGS.patch_size)

    # In order for variables to load, use the same variable scope as in the
    # train job.
    with tf.variable_scope('generator'):
        reconstructions, _, prebinary = networks.compression_model(
            images,
            num_bits=FLAGS.bits_per_patch,
            depth=FLAGS.model_depth,
            is_training=False)
    summaries.add_reconstruction_summaries(images, reconstructions, prebinary)

    # Visualize losses: mean absolute error per example (over axes 1-3),
    # then averaged over the batch.
    pixel_loss_per_example = tf.reduce_mean(
        tf.abs(images - reconstructions), axis=[1, 2, 3])
    pixel_loss = tf.reduce_mean(pixel_loss_per_example)
    tf.summary.histogram('pixel_l1_loss_hist', pixel_loss_per_example)
    tf.summary.scalar('pixel_l1_loss', pixel_loss)

    # Create ops to write images to disk.
    uint8_images = data_provider.float_image_to_uint8(images)
    uint8_reconstructions = data_provider.float_image_to_uint8(reconstructions)
    uint8_reshaped = summaries.stack_images(uint8_images, uint8_reconstructions)
    # os.path.join avoids a doubled separator when eval_dir ends in '/'.
    image_write_ops = tf.write_file(
        os.path.join(FLAGS.eval_dir, 'compression.png'),
        tf.image.encode_png(uint8_reshaped[0]))

    # For unit testing, use `run_eval_loop=False`.
    if not run_eval_loop:
        return
    tf.contrib.training.evaluate_repeatedly(
        FLAGS.checkpoint_dir,
        master=FLAGS.master,
        hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
               tf.contrib.training.StopAfterNEvalsHook(1)],
        eval_ops=image_write_ops,
        max_number_of_evaluations=FLAGS.max_number_of_evaluations)
示例14: main
# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, override_generator_fn=None, override_discriminator_fn=None):
    """Trains a StarGAN estimator in chunks, writing a summary image after each.

    Training runs for `FLAGS.max_number_of_steps` steps in chunks of
    `FLAGS.steps_per_eval`. After every chunk, a summary image built from the
    CelebA test set is written to `FLAGS.output_dir`.

    Args:
      _: Unused positional argument.
      override_generator_fn: Optional replacement for `network.generator`.
      override_discriminator_fn: Optional replacement for
        `network.discriminator`.

    Raises:
      ValueError: If `steps_per_eval` is not positive, or if
        `max_number_of_steps` is not divisible by `steps_per_eval`.
    """
    # Create directories if not exist.
    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    # Make sure steps integers are consistent. A non-positive chunk size would
    # either divide by zero below or never move `cur_step` toward the limit.
    if FLAGS.steps_per_eval <= 0:
        raise ValueError('`steps_per_eval` must be positive.')
    if FLAGS.max_number_of_steps % FLAGS.steps_per_eval != 0:
        raise ValueError('`max_number_of_steps` must be divisible by '
                         '`steps_per_eval`.')

    # Create optimizers.
    gen_opt, dis_opt = _get_optimizer(FLAGS.generator_lr, FLAGS.discriminator_lr)

    # Create estimator.
    # TODO(joelshor): Add optional distribution strategy here.
    stargan_estimator = tfgan.estimator.StarGANEstimator(
        generator_fn=override_generator_fn or network.generator,
        discriminator_fn=override_discriminator_fn or network.discriminator,
        loss_fn=tfgan.stargan_loss,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt,
        get_hooks_fn=tfgan.get_sequential_train_hooks(_define_train_step()),
        add_summaries=tfgan.estimator.SummaryType.IMAGES)

    # Get input function for training and test images.
    def train_input_fn():
        """Estimator input function; reads FLAGS when the estimator calls it."""
        return data_provider.provide_data(
            FLAGS.image_file_patterns, FLAGS.batch_size, FLAGS.patch_size)

    test_images_np, _ = data_provider.provide_celeba_test_set()
    filename_str = os.path.join(FLAGS.output_dir, 'summary_image_%i.png')

    # Periodically train and write prediction output to disk.
    cur_step = 0
    while cur_step < FLAGS.max_number_of_steps:
        stargan_estimator.train(train_input_fn, steps=FLAGS.steps_per_eval)
        cur_step += FLAGS.steps_per_eval
        summary_img = _get_summary_image(stargan_estimator, test_images_np)
        _write_to_disk(summary_img, filename_str % cur_step)