当前位置: 首页>>代码示例>>Python>>正文


Python data_provider.provide_data方法代码示例

本文整理汇总了Python中data_provider.provide_data方法的典型用法代码示例。如果您正苦于以下问题：Python data_provider.provide_data方法的具体用法？Python data_provider.provide_data怎么用？Python data_provider.provide_data使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块data_provider的用法示例。


在下文中一共展示了data_provider.provide_data方法的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_cifar10_train_set

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_cifar10_train_set(self):
    """Reads a CIFAR-10 training batch and checks metadata, shapes and range.

    Asserts the dataset metadata reported by `provide_data` (50000 samples,
    10 classes), then evaluates one batch and checks the image and label
    shapes and that every pixel value has magnitude at most 1.
    """
    testdata_relpath = 'google3/third_party/tensorflow_models/gan/cifar/testdata'
    dataset_dir = os.path.join(tf.flags.FLAGS.test_srcdir, testdata_relpath)

    batch_size = 4
    (image_batch, label_batch, num_samples,
     num_classes) = data_provider.provide_data(batch_size, dataset_dir)
    self.assertEqual(50000, num_samples)
    self.assertEqual(10, num_classes)

    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        image_values, label_values = sess.run([image_batch, label_batch])
        self.assertEqual(image_values.shape, (batch_size, 32, 32, 3))
        self.assertEqual((batch_size, 10), label_values.shape)
        # The test expects every pixel to lie within [-1, 1].
        self.assertTrue(np.all(np.abs(image_values) <= 1))
开发者ID:rky0930,项目名称:yolo_v2,代码行数:20,代码来源:data_provider_test.py

示例2: _test_data_provider_helper

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def _test_data_provider_helper(self, split_name):
    """Checks that `provide_data` yields correctly shaped, bounded patches.

    Args:
      split_name: Dataset split forwarded to `data_provider.provide_data`.
    """
    dataset_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/image_compression/testdata/')

    batch_size = 3
    patch_size = 8
    # Pass the local `patch_size` instead of repeating the literal, so the
    # requested patch size and the expected shapes below cannot drift apart.
    images = data_provider.provide_data(
        split_name, batch_size, dataset_dir, patch_size=patch_size)
    expected_shape = [batch_size, patch_size, patch_size, 3]
    # Static (graph-time) shape check, before any session runs.
    self.assertListEqual(expected_shape, images.shape.as_list())

    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out = sess.run(images)
        # Runtime shape must match the static one.
        self.assertEqual(tuple(expected_shape), images_out.shape)
        # The test expects every pixel to lie within [-1, 1].
        self.assertTrue(np.all(np.abs(images_out) <= 1.0))
开发者ID:rky0930,项目名称:yolo_v2,代码行数:21,代码来源:data_provider_test.py

示例3: test_data_provider

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_data_provider(self, mock_provide_custom_data):
    """Checks per-domain outputs of `provide_data` with a mocked reader.

    `mock_provide_custom_data` (presumably a mock of
    `data_provider.provide_custom_data`, injected by a decorator outside this
    view) is set to return one all-zero image tensor per domain. The test then
    expects one image tensor and one label tensor per domain, with static
    shapes `[batch, patch, patch, 3]` and `[batch, num_domains]`.
    """
    batch_size, patch_size, num_domains = 2, 8, 3

    images_shape = [batch_size, patch_size, patch_size, 3]
    mock_provide_custom_data.return_value = [
        tf.zeros(images_shape) for _ in range(num_domains)]

    images, labels = data_provider.provide_data(
        image_file_patterns=None, batch_size=batch_size, patch_size=patch_size)

    self.assertEqual(num_domains, len(images))
    self.assertEqual(num_domains, len(labels))
    expected_label_shape = [batch_size, num_domains]
    for label_tensor in labels:
      self.assertListEqual(expected_label_shape, label_tensor.shape.as_list())
    for image_tensor in images:
      self.assertListEqual(images_shape, image_tensor.shape.as_list())
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:22,代码来源:data_provider_test.py

示例4: test_cifar10_train_set

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_cifar10_train_set(self):
    """Evaluates one CIFAR-10 training batch from the bundled test data.

    Checks that `provide_data` reports 50000 samples and 10 classes, that the
    evaluated batch has image shape (batch, 32, 32, 3) and label shape
    (batch, 10), and that every pixel value has magnitude at most 1.
    """
    srcdir = flags.FLAGS.test_srcdir
    dataset_dir = os.path.join(
        srcdir, 'google3/third_party/tensorflow_models/gan/cifar/testdata')
    batch_size = 4

    provided = data_provider.provide_data(batch_size, dataset_dir)
    images, labels, num_samples, num_classes = provided
    self.assertEqual(50000, num_samples)
    self.assertEqual(10, num_classes)

    expected_image_shape = (batch_size, 32, 32, 3)
    expected_label_shape = (batch_size, 10)
    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_np, labels_np = sess.run([images, labels])
        self.assertEqual(images_np.shape, expected_image_shape)
        self.assertEqual(expected_label_shape, labels_np.shape)
        # Magnitude bound: all pixels must lie within [-1, 1].
        self.assertTrue(np.all(np.abs(images_np) <= 1))
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:20,代码来源:data_provider_test.py

示例5: _provide_real_images

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def _provide_real_images(batch_size, **kwargs):
  """Provides real images."""
  dataset_name = kwargs.get('dataset_name')
  dataset_file_pattern = kwargs.get('dataset_file_pattern')
  colors = kwargs['colors']
  final_height, final_width = train.make_resolution_schedule(
      **kwargs).final_resolutions
  if dataset_name is not None:
    return data_provider.provide_data(
        dataset_name=dataset_name,
        split_name='train',
        batch_size=batch_size,
        patch_height=final_height,
        patch_width=final_width,
        colors=colors)
  elif dataset_file_pattern is not None:
    return data_provider.provide_data_from_image_files(
        file_pattern=dataset_file_pattern,
        batch_size=batch_size,
        patch_height=final_height,
        patch_width=final_width,
        colors=colors) 
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:24,代码来源:train_main.py

示例6: test_data_provider

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_data_provider(self, split_name):
    """Checks that `provide_data` yields correctly shaped, bounded patches.

    Args:
      split_name: Dataset split forwarded to `data_provider.provide_data`
        (presumably supplied by a test parameterization outside this view).
    """
    dataset_dir = os.path.join(
        flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/image_compression/testdata/')

    batch_size = 3
    patch_size = 8
    # Pass the local `patch_size` instead of repeating the literal, so the
    # requested patch size and the expected shapes below cannot drift apart.
    images = data_provider.provide_data(
        split_name, batch_size, dataset_dir, patch_size=patch_size)
    expected_shape = [batch_size, patch_size, patch_size, 3]
    # Static (graph-time) shape check, before any session runs.
    self.assertListEqual(expected_shape, images.shape.as_list())

    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_out = sess.run(images)
        # Runtime shape must match the static one.
        self.assertEqual(tuple(expected_shape), images_out.shape)
        # The test expects every pixel to lie within [-1, 1].
        self.assertTrue(np.all(np.abs(images_out) <= 1.0))
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:21,代码来源:data_provider_test.py

示例7: _get_real_data

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def _get_real_data(num_images_generated, dataset_dir):
  """Get real images.

  Args:
    num_images_generated: Batch size requested from `provide_data`.
    dataset_dir: Directory passed through to `provide_data`.

  Returns:
    A `(data, num_classes)` pair: the first and fourth of the four values
    returned by `data_provider.provide_data`. The other two are discarded.
  """
  outputs = data_provider.provide_data(num_images_generated, dataset_dir)
  real_data, _, _, class_count = outputs
  return (real_data, class_count)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:7,代码来源:eval.py

示例8: test_mnist_data_reading

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_mnist_data_reading(self):
    """Reads one MNIST test-split batch and checks sample count and shapes."""
    testdata_dir = os.path.join(
        tf.flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/mnist/testdata')

    batch_size = 5
    image_tensor, label_tensor, num_samples = data_provider.provide_data(
        'test', batch_size, testdata_dir)
    self.assertEqual(num_samples, 10000)

    with self.test_session() as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        # Keep the evaluated arrays separate from the graph tensors.
        image_batch, label_batch = sess.run([image_tensor, label_tensor])
        self.assertEqual(image_batch.shape, (batch_size, 28, 28, 1))
        self.assertEqual(label_batch.shape, (batch_size, 10))
开发者ID:rky0930,项目名称:yolo_v2,代码行数:17,代码来源:data_provider_test.py

示例9: main

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, run_eval_loop=True):
  """Builds the MNIST unconditional-GAN evaluation graph and optionally runs it.

  Summaries created with `tf.summary.scalar`:
    * If `FLAGS.eval_real_images`: `MNIST_Classifier_score` of the real images.
    * Otherwise: `MNIST_Frechet_distance` between real and generated images,
      and `MNIST_Classifier_score` of the generated images. When at least 100
      images are generated, the first 100 are also tiled (10 columns), encoded
      as PNG and written to `FLAGS.eval_dir/unconditional_gan.png`.

  Args:
    _: Ignored positional argument (presumably argv from the app runner).
    run_eval_loop: If False, return right after building the graph (used by
      unit tests). Otherwise evaluate checkpoints from `FLAGS.checkpoint_dir`.
  """
  # Fetch real images.
  with tf.name_scope('inputs'):
    # `provide_data` returns three values here; only the images are used.
    real_images, _, _ = data_provider.provide_data(
        'train', FLAGS.num_images_generated, FLAGS.dataset_dir)

  # Stays None (no extra eval op) unless the PNG-writing op is built below.
  image_write_ops = None
  if FLAGS.eval_real_images:
    tf.summary.scalar('MNIST_Classifier_score',
                      util.mnist_score(real_images, FLAGS.classifier_filename))
  else:
    # In order for variables to load, use the same variable scope as in the
    # train job.
    with tf.variable_scope('Generator'):
      # One noise row per generated image.
      # NOTE(review): no `is_training` argument is passed, so the generator
      # runs in its default mode. If it contains batch norm, confirm whether
      # evaluation should use inference mode instead.
      images = networks.unconditional_generator(
          tf.random_normal([FLAGS.num_images_generated, FLAGS.noise_dims]))
    tf.summary.scalar('MNIST_Frechet_distance',
                      util.mnist_frechet_distance(
                          real_images, images, FLAGS.classifier_filename))
    tf.summary.scalar('MNIST_Classifier_score',
                      util.mnist_score(images, FLAGS.classifier_filename))
    # The tiled image needs exactly 100 images (10 columns).
    if FLAGS.num_images_generated >= 100:
      reshaped_images = tfgan.eval.image_reshaper(
          images[:100, ...], num_cols=10)
      uint8_images = data_provider.float_image_to_uint8(reshaped_images)
      image_write_ops = tf.write_file(
          '%s/%s'% (FLAGS.eval_dir, 'unconditional_gan.png'),
          tf.image.encode_png(uint8_images[0]))

  # For unit testing, use `run_eval_loop=False`.
  if not run_eval_loop: return
  # Each evaluation runs the eval ops once (StopAfterNEvalsHook(1)) and then
  # writes summaries to FLAGS.eval_dir (SummaryAtEndHook).
  tf.contrib.training.evaluate_repeatedly(
      FLAGS.checkpoint_dir,
      hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
             tf.contrib.training.StopAfterNEvalsHook(1)],
      eval_ops=image_write_ops,
      max_number_of_evaluations=FLAGS.max_number_of_evaluations)
开发者ID:rky0930,项目名称:yolo_v2,代码行数:39,代码来源:eval.py

示例10: main

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, run_eval_loop=True):
  """Builds the MNIST unconditional-GAN evaluation graph and optionally runs it.

  Summaries created with `tf.summary.scalar`:
    * If `FLAGS.eval_real_images`: `MNIST_Classifier_score` of the real images.
    * Otherwise: `MNIST_Frechet_distance` between real and generated images,
      and `MNIST_Classifier_score` of the generated images. When at least 100
      images are generated and `FLAGS.write_to_disk` is set, the first 100 are
      also tiled (10 columns), encoded as PNG and written to
      `FLAGS.eval_dir/unconditional_gan.png`.

  Args:
    _: Ignored positional argument (presumably argv from the app runner).
    run_eval_loop: If False, return right after building the graph (used by
      unit tests). Otherwise evaluate checkpoints from `FLAGS.checkpoint_dir`.
  """
  # Fetch real images.
  with tf.name_scope('inputs'):
    # `provide_data` returns three values here; only the images are used.
    real_images, _, _ = data_provider.provide_data(
        'train', FLAGS.num_images_generated, FLAGS.dataset_dir)

  # Stays None (no extra eval op) unless the PNG-writing op is built below.
  image_write_ops = None
  if FLAGS.eval_real_images:
    tf.summary.scalar('MNIST_Classifier_score',
                      util.mnist_score(real_images, FLAGS.classifier_filename))
  else:
    # In order for variables to load, use the same variable scope as in the
    # train job.
    with tf.variable_scope('Generator'):
      # One noise row per generated image. The generator is built with
      # `is_training=False` because this is evaluation.
      images = networks.unconditional_generator(
          tf.random_normal([FLAGS.num_images_generated, FLAGS.noise_dims]),
          is_training=False)
    tf.summary.scalar('MNIST_Frechet_distance',
                      util.mnist_frechet_distance(
                          real_images, images, FLAGS.classifier_filename))
    tf.summary.scalar('MNIST_Classifier_score',
                      util.mnist_score(images, FLAGS.classifier_filename))
    # The tiled image needs exactly 100 images (10 columns), and disk output
    # can be switched off with FLAGS.write_to_disk.
    if FLAGS.num_images_generated >= 100 and FLAGS.write_to_disk:
      reshaped_images = tfgan.eval.image_reshaper(
          images[:100, ...], num_cols=10)
      uint8_images = data_provider.float_image_to_uint8(reshaped_images)
      image_write_ops = tf.write_file(
          '%s/%s'% (FLAGS.eval_dir, 'unconditional_gan.png'),
          tf.image.encode_png(uint8_images[0]))

  # For unit testing, use `run_eval_loop=False`.
  if not run_eval_loop: return
  # Each evaluation runs the eval ops once (StopAfterNEvalsHook(1)) and then
  # writes summaries to FLAGS.eval_dir (SummaryAtEndHook).
  tf.contrib.training.evaluate_repeatedly(
      FLAGS.checkpoint_dir,
      hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
             tf.contrib.training.StopAfterNEvalsHook(1)],
      eval_ops=image_write_ops,
      max_number_of_evaluations=FLAGS.max_number_of_evaluations)
开发者ID:itsamitgoel,项目名称:Gun-Detector,代码行数:40,代码来源:eval.py

示例11: test_mnist_data_reading

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_mnist_data_reading(self):
    """Checks the MNIST test split's sample count and one batch's shapes."""
    batch_size = 5
    dataset_dir = os.path.join(
        flags.FLAGS.test_srcdir,
        'google3/third_party/tensorflow_models/gan/mnist/testdata')

    images_t, labels_t, num_samples = data_provider.provide_data(
        'test', batch_size, dataset_dir)
    self.assertEqual(num_samples, 10000)

    with self.test_session() as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        # Evaluate into fresh names instead of shadowing the graph tensors.
        images_np, labels_np = sess.run([images_t, labels_t])
        self.assertEqual(images_np.shape, (batch_size, 28, 28, 1))
        self.assertEqual(labels_np.shape, (batch_size, 10))
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:17,代码来源:data_provider_test.py

示例12: test_provide_data

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def test_provide_data(self):
    """Checks static and evaluated shapes of a 2x3x3x1 MNIST training batch."""
    expected_shape = (2, 3, 3, 1)
    images = data_provider.provide_data(
        'mnist', 'train',
        dataset_dir=self.testdata_dir,
        batch_size=2,
        shuffle=False,
        patch_height=3,
        patch_width=3,
        colors=1)
    # Static shape is reported as a list; compare against the list form.
    self.assertEqual(images.shape.as_list(), list(expected_shape))
    with self.test_session(use_gpu=True) as sess:
      with tf.contrib.slim.queues.QueueRunners(sess):
        images_np = sess.run(images)
    self.assertEqual(images_np.shape, expected_shape)
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:17,代码来源:data_provider_test.py

示例13: main

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, run_eval_loop=True):
  """Builds the image-compression evaluation graph and optionally runs it.

  Reads a validation batch and reconstructs it with the compression model
  (`is_training=False`). Adds reconstruction summaries plus L1 pixel-loss
  summaries, and builds an op that writes a PNG (originals stacked with their
  reconstructions) to `FLAGS.eval_dir/compression.png`.

  Args:
    _: Ignored positional argument (presumably argv from the app runner).
    run_eval_loop: If False, return right after building the graph (used by
      unit tests). Otherwise evaluate checkpoints from `FLAGS.checkpoint_dir`.
  """
  with tf.name_scope('inputs'):
    images = data_provider.provide_data(
        'validation', FLAGS.batch_size, dataset_dir=FLAGS.dataset_dir,
        patch_size=FLAGS.patch_size)

  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('generator'):
    # The model returns three values; the middle one is unused here.
    reconstructions, _, prebinary = networks.compression_model(
        images,
        num_bits=FLAGS.bits_per_patch,
        depth=FLAGS.model_depth,
        is_training=False)
  summaries.add_reconstruction_summaries(images, reconstructions, prebinary)

  # Visualize losses.
  # Mean absolute pixel error per example, averaged over axes 1-3 (assumes a
  # rank-4 image batch).
  pixel_loss_per_example = tf.reduce_mean(
      tf.abs(images - reconstructions), axis=[1, 2, 3])
  pixel_loss = tf.reduce_mean(pixel_loss_per_example)
  tf.summary.histogram('pixel_l1_loss_hist', pixel_loss_per_example)
  tf.summary.scalar('pixel_l1_loss', pixel_loss)

  # Create ops to write images to disk.
  uint8_images = data_provider.float_image_to_uint8(images)
  uint8_reconstructions = data_provider.float_image_to_uint8(reconstructions)
  uint8_reshaped = summaries.stack_images(uint8_images, uint8_reconstructions)
  # Only the first element of the stacked result is encoded and written.
  image_write_ops = tf.write_file(
      '%s/%s'% (FLAGS.eval_dir, 'compression.png'),
      tf.image.encode_png(uint8_reshaped[0]))

  # For unit testing, use `run_eval_loop=False`.
  if not run_eval_loop: return
  # Each evaluation runs the eval ops once (StopAfterNEvalsHook(1)) and then
  # writes summaries to FLAGS.eval_dir (SummaryAtEndHook).
  tf.contrib.training.evaluate_repeatedly(
      FLAGS.checkpoint_dir,
      master=FLAGS.master,
      hooks=[tf.contrib.training.SummaryAtEndHook(FLAGS.eval_dir),
             tf.contrib.training.StopAfterNEvalsHook(1)],
      eval_ops=image_write_ops,
      max_number_of_evaluations=FLAGS.max_number_of_evaluations)
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:42,代码来源:eval.py

示例14: main

# 需要导入模块: import data_provider [as 别名]
# 或者: from data_provider import provide_data [as 别名]
def main(_, override_generator_fn=None, override_discriminator_fn=None):
  """Trains a StarGAN estimator and writes a summary image after each chunk.

  Training runs in chunks of `FLAGS.steps_per_eval` steps until
  `FLAGS.max_number_of_steps` is reached. After every chunk, a summary image
  computed from the CelebA test set is written to
  `FLAGS.output_dir/summary_image_<step>.png`.

  Args:
    _: Unused (presumably the argv list passed by the app runner).
    override_generator_fn: Optional replacement for `network.generator`.
    override_discriminator_fn: Optional replacement for
      `network.discriminator`.

  Raises:
    ValueError: If `steps_per_eval` is not positive, or if
      `max_number_of_steps` is not a multiple of `steps_per_eval`.
  """
  # Create directories if not exist.
  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # Make sure steps integers are consistent. A non-positive chunk size would
  # either divide by zero in the modulo below or stop the training loop from
  # ever advancing.
  if FLAGS.steps_per_eval <= 0:
    raise ValueError('`steps_per_eval` must be a positive integer.')
  if FLAGS.max_number_of_steps % FLAGS.steps_per_eval != 0:
    raise ValueError('`max_number_of_steps` must be divisible by '
                     '`steps_per_eval`.')

  # Create optimizers.
  gen_opt, dis_opt = _get_optimizer(FLAGS.generator_lr, FLAGS.discriminator_lr)

  # Create estimator.
  # TODO(joelshor): Add optional distribution strategy here.
  stargan_estimator = tfgan.estimator.StarGANEstimator(
      generator_fn=override_generator_fn or network.generator,
      discriminator_fn=override_discriminator_fn or network.discriminator,
      loss_fn=tfgan.stargan_loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      get_hooks_fn=tfgan.get_sequential_train_hooks(_define_train_step()),
      add_summaries=tfgan.estimator.SummaryType.IMAGES)

  # Get input function for training and test images.
  train_input_fn = lambda: data_provider.provide_data(  # pylint:disable=g-long-lambda
      FLAGS.image_file_patterns, FLAGS.batch_size, FLAGS.patch_size)
  test_images_np, _ = data_provider.provide_celeba_test_set()

  # Periodically train and write prediction output to disk.
  cur_step = 0
  while cur_step < FLAGS.max_number_of_steps:
    stargan_estimator.train(train_input_fn, steps=FLAGS.steps_per_eval)
    cur_step += FLAGS.steps_per_eval
    summary_img = _get_summary_image(stargan_estimator, test_images_np)
    # Format only the basename. Applying `%` to the joined path would raise
    # (or substitute wrongly) if `output_dir` itself contains a '%'.
    filename = os.path.join(FLAGS.output_dir,
                            'summary_image_%i.png' % cur_step)
    _write_to_disk(summary_img, filename)
开发者ID:generalized-iou,项目名称:g-tensorflow-models,代码行数:39,代码来源:train.py


注:本文中的data_provider.provide_data方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。