

Python data_provider.get_data usage examples

This article collects typical usage examples of the data_provider.get_data method in Python. If you are unsure what data_provider.get_data does or how to call it, the hand-picked code examples below may help. You can also explore further usage examples from the data_provider module.


The sections below show 4 code examples of the data_provider.get_data method, sorted by popularity by default. You can upvote the examples you find useful; your feedback helps the system recommend better Python code examples.
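Taken together, the four examples share one call pattern. The sketch below is a hypothetical stand-in that only mirrors that signature: the keyword names and the `images`/`labels_one_hot` fields are taken from the examples, while `InputEndpoints`, `get_data_stub`, and the placeholder values are assumptions of this sketch. The real get_data builds TensorFlow input queues and returns tensors, not plain Python values.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the get_data call signature used in the
# examples below; the real function returns TensorFlow tensors backed by
# input queues rather than plain Python strings.
InputEndpoints = namedtuple('InputEndpoints', ['images', 'labels_one_hot'])

def get_data_stub(dataset, batch_size, augment=False, central_crop_size=None):
    # Placeholder strings stand in for the image and one-hot label tensors.
    return InputEndpoints(images='images_tensor',
                          labels_one_hot='labels_one_hot_tensor')

data = get_data_stub(dataset=None, batch_size=4, augment=True,
                     central_crop_size=None)
print(data.images)  # images_tensor
```

The returned object is accessed by attribute (data.images, data.labels_one_hot) in every example, which is why a namedtuple is a reasonable model of the interface.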

Example 1: test_provided_data_has_correct_shape

# Required module: import data_provider [as alias]
# Alternatively: from data_provider import get_data [as alias]
def test_provided_data_has_correct_shape(self):
    batch_size = 4
    data = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=batch_size,
        augment=True,
        central_crop_size=None)

    with self.test_session() as sess, queues.QueueRunners(sess):
      images_np, labels_np = sess.run([data.images, data.labels_one_hot])

    self.assertEqual(images_np.shape, (batch_size, 150, 600, 3))
    self.assertEqual(labels_np.shape, (batch_size, 37, 134)) 
Author: ringringyi | Project: DOTA_models | Lines of code: 15 | Source: data_provider_test.py

Example 2: test_optionally_applies_central_crop

# Required module: import data_provider [as alias]
# Alternatively: from data_provider import get_data [as alias]
def test_optionally_applies_central_crop(self):
    batch_size = 4
    data = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=batch_size,
        augment=True,
        central_crop_size=(500, 100))

    with self.test_session() as sess, queues.QueueRunners(sess):
      images_np = sess.run(data.images)

    self.assertEqual(images_np.shape, (batch_size, 100, 500, 3)) 
Author: ringringyi | Project: DOTA_models | Lines of code: 14 | Source: data_provider_test.py
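Note the axis order in the assertion above: central_crop_size=(500, 100) yields images of shape (4, 100, 500, 3), which suggests the crop size is given as (width, height) while image batches are laid out NHWC. A small sketch of that relationship, where the helper name is an assumption of this sketch and not part of the library:

```python
def cropped_batch_shape(batch_size, central_crop_size, channels=3):
    # central_crop_size appears to be (width, height), while image batches
    # are NHWC, so the two values swap positions in the resulting shape.
    width, height = central_crop_size
    return (batch_size, height, width, channels)

# Matches the expectation in test_optionally_applies_central_crop:
assert cropped_batch_shape(4, (500, 100)) == (4, 100, 500, 3)
```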

Example 3: main

# Required module: import data_provider [as alias]
# Alternatively: from data_provider import get_data [as alias]
def main(_):
  if not tf.gfile.Exists(FLAGS.eval_log_dir):
    tf.gfile.MakeDirs(FLAGS.eval_log_dir)

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  data = data_provider.get_data(
      dataset,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())
  endpoints = model.create_base(data.images, labels_one_hot=None)
  model.create_loss(data, endpoints)
  eval_ops = model.create_summaries(
      data, endpoints, dataset.charset, is_training=False)
  slim.get_or_create_global_step()
  session_config = tf.ConfigProto(device_count={"GPU": 0})
  slim.evaluation.evaluation_loop(
      master=FLAGS.master,
      checkpoint_dir=FLAGS.train_log_dir,
      logdir=FLAGS.eval_log_dir,
      eval_op=eval_ops,
      num_evals=FLAGS.num_batches,
      eval_interval_secs=FLAGS.eval_interval_secs,
      max_number_of_evaluations=FLAGS.number_of_steps,
      session_config=session_config) 
Author: ringringyi | Project: DOTA_models | Lines of code: 30 | Source: eval.py

Example 4: main

# Required module: import data_provider [as alias]
# Alternatively: from data_provider import get_data [as alias]
def main(_):
  prepare_training_dir()

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  hparams = get_training_hparams()

  # If ps_tasks is zero, the local device is used. When using multiple
  # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
  # across the different devices.
  device_setter = tf.train.replica_device_setter(
      FLAGS.ps_tasks, merge_devices=True)
  with tf.device(device_setter):
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=hparams.use_augment_input,
        central_crop_size=common_flags.get_crop_size())
    endpoints = model.create_base(data.images, data.labels_one_hot)
    total_loss = model.create_loss(data, endpoints)
    model.create_summaries(data, endpoints, dataset.charset, is_training=True)
    init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
                                              FLAGS.checkpoint_inception)
    if FLAGS.show_graph_stats:
      logging.info('Total number of weights in the graph: %s',
                   calculate_graph_metrics())
    train(total_loss, init_fn, hparams) 
Author: ringringyi | Project: DOTA_models | Lines of code: 31 | Source: train.py


Note: The data_provider.get_data examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers, and copyright remains with the original authors. Please consult each project's license before distributing or using the code, and do not republish without permission.