

Python cifar10.Cifar10DataSet Code Examples

This article collects typical usage examples of the cifar10.Cifar10DataSet class from Python's cifar10 module. If you are wondering what cifar10.Cifar10DataSet does, or how to use it in practice, the curated examples below should help. You can also explore other usages within the cifar10 module.


Three code examples of cifar10.Cifar10DataSet are presented below, sorted by popularity by default.

Example 1: input_fn

# Required import: import cifar10 [as alias]
# Or: from cifar10 import Cifar10DataSet [as alias]
# (Also assumes TensorFlow 1.x imported as `tf` and a module-level FLAGS object.)
def input_fn(subset, num_shards):
  """Create input graph for model.

  Args:
    subset: one of 'train', 'validate', or 'eval'.
    num_shards: number of towers participating in data-parallel training.
  Returns:
    two lists of tensors, for features and labels respectively, each of
    length num_shards.
  """
  if subset == 'train':
    batch_size = FLAGS.train_batch_size
  elif subset in ('validate', 'eval'):
    batch_size = FLAGS.eval_batch_size
  else:
    raise ValueError("subset must be one of 'train', 'validate', or 'eval'")
  with tf.device('/cpu:0'):
    use_distortion = subset == 'train' and FLAGS.use_distortion_for_training
    dataset = cifar10.Cifar10DataSet(FLAGS.data_dir, subset, use_distortion)
    image_batch, label_batch = dataset.make_batch(batch_size)
    if num_shards <= 1:
      # No GPU available or only 1 GPU.
      return [image_batch], [label_batch]

    # Note that passing num=batch_size is safe here, even though
    # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
    # examples. This is because it does so only when repeating for a limited
    # number of epochs, but our dataset repeats forever.
    image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
    label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
    feature_shards = [[] for _ in range(num_shards)]
    label_shards = [[] for _ in range(num_shards)]
    for i in range(batch_size):
      idx = i % num_shards
      feature_shards[idx].append(image_batch[i])
      label_shards[idx].append(label_batch[i])
    feature_shards = [tf.parallel_stack(x) for x in feature_shards]
    label_shards = [tf.parallel_stack(x) for x in label_shards]
    return feature_shards, label_shards 
Developer: ringringyi, Project: DOTA_models, Lines: 40, Source: cifar10_main.py
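
The loop above distributes the unstacked examples round-robin across the towers. Below is a framework-free sketch of the same pattern, using hypothetical values batch_size=8 and num_shards=3; the integers stand in for the per-example tensors produced by tf.unstack.

batch_size = 8
num_shards = 3
examples = list(range(batch_size))  # stand-ins for unstacked example tensors

shards = [[] for _ in range(num_shards)]
for i in range(batch_size):
    shards[i % num_shards].append(examples[i])

print(shards)  # [[0, 3, 6], [1, 4, 7], [2, 5]]

Note the uneven shard sizes when batch_size is not a multiple of num_shards, which is one reason main() in Example 3 requires train_batch_size and eval_batch_size to be multiples of num_gpus.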

Example 2: input_fn

# Required import: import cifar10 [as alias]
# Or: from cifar10 import Cifar10DataSet [as alias]
# (Also assumes TensorFlow 1.x imported as `tf`.)
def input_fn(data_dir,
             subset,
             num_shards,
             batch_size,
             use_distortion_for_training=True):
  """Create input graph for model.

  Args:
    data_dir: Directory where TFRecords representing the dataset are located.
    subset: one of 'train', 'validate', or 'eval'.
    num_shards: number of towers participating in data-parallel training.
    batch_size: total batch size for training, to be divided across the
      number of shards.
    use_distortion_for_training: True to use distortions.
  Returns:
    two lists of tensors, for features and labels respectively, each of
    length num_shards.
  """
  with tf.device('/cpu:0'):
    use_distortion = subset == 'train' and use_distortion_for_training
    dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
    image_batch, label_batch = dataset.make_batch(batch_size)
    if num_shards <= 1:
      # No GPU available or only 1 GPU.
      return [image_batch], [label_batch]

    # Note that passing num=batch_size is safe here, even though
    # dataset.batch(batch_size) can, in some cases, return fewer than batch_size
    # examples. This is because it does so only when repeating for a limited
    # number of epochs, but our dataset repeats forever.
    image_batch = tf.unstack(image_batch, num=batch_size, axis=0)
    label_batch = tf.unstack(label_batch, num=batch_size, axis=0)
    feature_shards = [[] for _ in range(num_shards)]
    label_shards = [[] for _ in range(num_shards)]
    for i in range(batch_size):
      idx = i % num_shards
      feature_shards[idx].append(image_batch[i])
      label_shards[idx].append(label_batch[i])
    feature_shards = [tf.parallel_stack(x) for x in feature_shards]
    label_shards = [tf.parallel_stack(x) for x in label_shards]
    return feature_shards, label_shards 
Developer: rky0930, Project: yolo_v2, Lines: 42, Source: cifar10_main.py
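
Unlike Example 1, this variant takes data_dir, subset, num_shards, and batch_size as arguments instead of reading global FLAGS, so it can be bound with functools.partial before being handed to an Estimator. A minimal usage sketch, where the directory path and sizes are hypothetical values for illustration:

import functools

train_input_fn = functools.partial(
    input_fn,
    data_dir='/tmp/cifar10_data',  # hypothetical TFRecord directory
    subset='train',
    num_shards=2,                  # hypothetical number of GPU towers
    batch_size=128)
# classifier.train(input_fn=train_input_fn, steps=1000)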

Example 3: main

# Required import: import cifar10 [as alias]
# Or: from cifar10 import Cifar10DataSet [as alias]
# (Also assumes: import functools, os; TensorFlow 1.x as `tf`; and module-level
#  FLAGS, input_fn, and _resnet_model_fn.)
def main(unused_argv):
  # This environment variable is on a deprecation path; default it to off.
  os.environ['TF_SYNC_ON_FINISH'] = '0'

  if FLAGS.num_gpus < 0:
    raise ValueError(
        'Invalid GPU count: "num_gpus" must be 0 or a positive integer.')
  if FLAGS.num_gpus == 0 and not FLAGS.is_cpu_ps:
    raise ValueError(
        'No GPU available for use, must use CPU as parameter server.')
  if (FLAGS.num_layers - 2) % 6 != 0:
    # CIFAR-10 ResNets have 6n + 2 layers, so (num_layers - 2) must be
    # divisible by 6 (e.g., 20, 32, 44, 56).
    raise ValueError('num_layers must be of the form 6n + 2.')
  if FLAGS.num_gpus != 0 and FLAGS.train_batch_size % FLAGS.num_gpus != 0:
    raise ValueError('train_batch_size must be a multiple of num_gpus.')
  if FLAGS.num_gpus != 0 and FLAGS.eval_batch_size % FLAGS.num_gpus != 0:
    raise ValueError('eval_batch_size must be a multiple of num_gpus.')

  num_eval_examples = cifar10.Cifar10DataSet.num_examples_per_epoch('eval')
  if num_eval_examples % FLAGS.eval_batch_size != 0:
    raise ValueError('validation set size must be a multiple of eval_batch_size')

  config = tf.estimator.RunConfig()
  sess_config = tf.ConfigProto()
  sess_config.allow_soft_placement = True
  sess_config.log_device_placement = FLAGS.log_device_placement
  sess_config.intra_op_parallelism_threads = FLAGS.num_intra_threads
  sess_config.inter_op_parallelism_threads = FLAGS.num_inter_threads
  sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
  config = config.replace(session_config=sess_config)

  classifier = tf.estimator.Estimator(
      model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)

  tensors_to_log = {'learning_rate': 'learning_rate'}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)

  print('Starting to train...')
  classifier.train(
      input_fn=functools.partial(
          input_fn, subset='train', num_shards=FLAGS.num_gpus),
      steps=FLAGS.train_steps,
      hooks=[logging_hook])

  print('Starting to evaluate...')
  eval_results = classifier.evaluate(
      input_fn=functools.partial(
          input_fn, subset='eval', num_shards=FLAGS.num_gpus),
      steps=num_eval_examples // FLAGS.eval_batch_size)
  print(eval_results) 
Developer: ringringyi, Project: DOTA_models, Lines: 52, Source: cifar10_main.py
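
main() relies on command-line flags defined elsewhere in cifar10_main.py. Below is a hedged sketch of the definitions it assumes, written with the TF 1.x tf.app.flags API; the flag names match the FLAGS attributes referenced above, but the defaults, help strings, and this particular flag-definition mechanism are assumptions for illustration only.

import tensorflow as tf

# Hypothetical defaults; the real cifar10_main.py defines its own.
tf.app.flags.DEFINE_string('data_dir', '', 'Directory holding the CIFAR-10 TFRecords.')
tf.app.flags.DEFINE_string('model_dir', '', 'Directory for checkpoints and summaries.')
tf.app.flags.DEFINE_integer('num_gpus', 1, 'Number of GPU towers; 0 means CPU only.')
tf.app.flags.DEFINE_boolean('is_cpu_ps', True, 'Use the CPU as the parameter server.')
tf.app.flags.DEFINE_integer('num_layers', 44, 'ResNet depth; must be of the form 6n + 2.')
tf.app.flags.DEFINE_integer('train_batch_size', 128, 'Total training batch size.')
tf.app.flags.DEFINE_integer('eval_batch_size', 100, 'Total evaluation batch size.')
tf.app.flags.DEFINE_integer('train_steps', 80000, 'Number of training steps.')
FLAGS = tf.app.flags.FLAGS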


Note: The cifar10.Cifar10DataSet examples in this article were collected from open-source projects and documentation on platforms such as GitHub and MSDocs. Copyright of the code snippets remains with their original authors; consult each project's license before redistributing or using the code.