

Python data_provider.get_data Method Code Examples

This article collects typical usage examples of the Python method data_provider.get_data. If you are unsure how to call data_provider.get_data, or want to see how it is used in practice, the curated examples below may help. You can also explore other usages in the data_provider module.


Four code examples of the data_provider.get_data method are shown below, ordered by popularity by default.

Example 1: test_provided_data_has_correct_shape

# Required import: import data_provider [as alias]
# Or: from data_provider import get_data [as alias]
def test_provided_data_has_correct_shape(self):
    batch_size = 4
    data = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=batch_size,
        augment=True,
        central_crop_size=None)

    with self.test_session() as sess, queues.QueueRunners(sess):
      images_np, labels_np = sess.run([data.images, data.labels_one_hot])

    self.assertEqual(images_np.shape, (batch_size, 150, 600, 3))
    self.assertEqual(labels_np.shape, (batch_size, 37, 134)) 
Developer: ringringyi | Project: DOTA_models | Lines of code: 15 | Source file: data_provider_test.py
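The shapes asserted above follow from the FSNS dataset layout: 150x600 RGB images, transcriptions padded to a maximum sequence length of 37 characters drawn from 134 character classes. The following is a minimal NumPy sketch of that one-hot label contract, not the library's implementation; `to_one_hot` is a hypothetical helper introduced only for illustration.

```python
import numpy as np

def to_one_hot(labels, num_classes):
    """One-hot encode integer class ids, appending a class axis.

    labels: int array of shape (batch, seq_len) with values in [0, num_classes).
    Returns a float array of shape (batch, seq_len, num_classes).
    """
    return np.eye(num_classes, dtype=np.float32)[labels]

# (batch=4, seq_len=37) padded character ids, mirroring the test above.
labels = np.zeros((4, 37), dtype=np.int64)
one_hot = to_one_hot(labels, 134)
print(one_hot.shape)  # (4, 37, 134)
```

Each position in the 37-step sequence carries exactly one active entry across the 134 classes, which is why `labels_one_hot` in the test has shape `(batch_size, 37, 134)`.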

Example 2: test_optionally_applies_central_crop

# Required import: import data_provider [as alias]
# Or: from data_provider import get_data [as alias]
def test_optionally_applies_central_crop(self):
    batch_size = 4
    data = data_provider.get_data(
        dataset=datasets.fsns_test.get_test_split(),
        batch_size=batch_size,
        augment=True,
        central_crop_size=(500, 100))

    with self.test_session() as sess, queues.QueueRunners(sess):
      images_np = sess.run(data.images)

    self.assertEqual(images_np.shape, (batch_size, 100, 500, 3)) 
Developer: ringringyi | Project: DOTA_models | Lines of code: 14 | Source file: data_provider_test.py
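Note the axis order in this test: `central_crop_size` is given as `(width, height) = (500, 100)`, while image tensors are laid out `(batch, height, width, channels)`, so the cropped shape comes out as `(batch_size, 100, 500, 3)`. The geometry can be sketched in NumPy as follows; this is an illustrative stand-in, not the library's crop routine, and `central_crop` is a hypothetical helper.

```python
import numpy as np

def central_crop(images, crop_size):
    """Crop a batch of images around the center.

    images: array of shape (batch, height, width, channels).
    crop_size: (width, height) pair, matching the (500, 100) argument
    in the test above -- note the order is swapped relative to the
    tensor layout.
    """
    crop_w, crop_h = crop_size
    _, h, w, _ = images.shape
    top = (h - crop_h) // 2
    left = (w - crop_w) // 2
    return images[:, top:top + crop_h, left:left + crop_w, :]

# FSNS images are 150x600x3; a (500, 100) crop keeps the central region.
batch = np.zeros((4, 150, 600, 3), dtype=np.float32)
cropped = central_crop(batch, (500, 100))
print(cropped.shape)  # (4, 100, 500, 3)
```

Passing `central_crop_size=None`, as in Example 1, skips the crop and leaves the full 150x600 images intact.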

Example 3: main

# Required import: import data_provider [as alias]
# Or: from data_provider import get_data [as alias]
def main(_):
  if not tf.gfile.Exists(FLAGS.eval_log_dir):
    tf.gfile.MakeDirs(FLAGS.eval_log_dir)

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  data = data_provider.get_data(
      dataset,
      FLAGS.batch_size,
      augment=False,
      central_crop_size=common_flags.get_crop_size())
  endpoints = model.create_base(data.images, labels_one_hot=None)
  model.create_loss(data, endpoints)
  eval_ops = model.create_summaries(
      data, endpoints, dataset.charset, is_training=False)
  slim.get_or_create_global_step()
  session_config = tf.ConfigProto(device_count={"GPU": 0})
  slim.evaluation.evaluation_loop(
      master=FLAGS.master,
      checkpoint_dir=FLAGS.train_log_dir,
      logdir=FLAGS.eval_log_dir,
      eval_op=eval_ops,
      num_evals=FLAGS.num_batches,
      eval_interval_secs=FLAGS.eval_interval_secs,
      max_number_of_evaluations=FLAGS.number_of_steps,
      session_config=session_config) 
Developer: ringringyi | Project: DOTA_models | Lines of code: 30 | Source file: eval.py

Example 4: main

# Required import: import data_provider [as alias]
# Or: from data_provider import get_data [as alias]
def main(_):
  prepare_training_dir()

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  hparams = get_training_hparams()

  # If ps_tasks is zero, the local device is used. When using multiple
  # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
  # across the different devices.
  device_setter = tf.train.replica_device_setter(
      FLAGS.ps_tasks, merge_devices=True)
  with tf.device(device_setter):
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=hparams.use_augment_input,
        central_crop_size=common_flags.get_crop_size())
    endpoints = model.create_base(data.images, data.labels_one_hot)
    total_loss = model.create_loss(data, endpoints)
    model.create_summaries(data, endpoints, dataset.charset, is_training=True)
    init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
                                              FLAGS.checkpoint_inception)
    if FLAGS.show_graph_stats:
      logging.info('Total number of weights in the graph: %s',
                   calculate_graph_metrics())
    train(total_loss, init_fn, hparams) 
Developer: ringringyi | Project: DOTA_models | Lines of code: 31 | Source file: train.py


Note: the data_provider.get_data examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects; copyright remains with the original authors. For redistribution and use, refer to each project's License. Do not reproduce without permission.