当前位置: 首页>>代码示例>>Python>>正文


Python dataset_factory.get_dataset方法代码示例

本文整理汇总了Python中datasets.dataset_factory.get_dataset方法的典型用法代码示例。如果您正苦于以下问题:Python dataset_factory.get_dataset方法的具体用法?Python dataset_factory.get_dataset怎么用?Python dataset_factory.get_dataset使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在datasets.dataset_factory的用法示例。


在下文中一共展示了dataset_factory.get_dataset方法的14个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: main

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
  """Export an inference GraphDef for the network chosen by --model_name.

  The network is sized with the 'train' split's class count minus
  --labels_offset and applied to a float32 placeholder named 'input' of shape
  [1, size, size, 3]. Here size is the network function's
  ``default_image_size`` attribute when present, and --default_image_size
  otherwise. The serialized GraphDef is written to --output_file.

  Raises:
    ValueError: if --output_file is empty.
  """
  output_path = FLAGS.output_file
  if not output_path:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as export_graph:
    # The dataset is only consulted for its number of classes.
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    build_net = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=train_split.num_classes - FLAGS.labels_offset,
        is_training=FLAGS.is_training)
    # Only touch the flag when the network does not declare its own size.
    side = (build_net.default_image_size
            if hasattr(build_net, 'default_image_size')
            else FLAGS.default_image_size)
    input_tensor = tf.placeholder(
        name='input', dtype=tf.float32, shape=[1, side, side, 3])
    build_net(input_tensor)
    exported = export_graph.as_graph_def()
    with gfile.GFile(output_path, 'wb') as sink:
      sink.write(exported.SerializeToString())
开发者ID:ringringyi,项目名称:DOTA_models,代码行数:23,代码来源:export_inference_graph.py

示例2: main

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
  """Export an inference GraphDef for the network chosen by --model_name.

  The input placeholder 'input' is float32 with shape
  [--batch_size, size, size, 3]. Here size is --image_size when it is truthy,
  and the network function's ``default_image_size`` otherwise. The serialized
  GraphDef is written to --output_file.

  Raises:
    ValueError: if --output_file is empty.
  """
  destination = FLAGS.output_file
  if not destination:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as export_graph:
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    class_count = train_split.num_classes - FLAGS.labels_offset
    build_net = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=class_count,
        is_training=FLAGS.is_training)
    side = FLAGS.image_size or build_net.default_image_size
    input_shape = [FLAGS.batch_size, side, side, 3]
    build_net(tf.placeholder(name='input', dtype=tf.float32,
                             shape=input_shape))
    exported = export_graph.as_graph_def()
    with gfile.GFile(destination, 'wb') as sink:
      sink.write(exported.SerializeToString())
开发者ID:yuantailing,项目名称:ctw-baseline,代码行数:21,代码来源:export_inference_graph.py

示例3: config_initialization

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def config_initialization():
    """Validate flags, initialise the global config for evaluation, and load the dataset.

    The config is initialised with batch size 1 and the seg/link thresholds and
    loss weights taken from flags. The process name is set to
    ``eval_<model_name>_<dataset_name>``.

    Returns:
        The dataset built by ``dataset_factory.get_dataset`` from
        --dataset_name, --dataset_split_name and --dataset_dir.

    Raises:
        ValueError: if --dataset_dir is not set.
    """
    # Evaluation image shape (used for image and feature-layer shape inference).
    eval_shape = (FLAGS.eval_image_height, FLAGS.eval_image_width)

    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    tf.logging.set_verbosity(tf.logging.DEBUG)

    config.init_config(
        eval_shape,
        batch_size=1,
        seg_conf_threshold=FLAGS.seg_conf_threshold,
        link_conf_threshold=FLAGS.link_conf_threshold,
        train_with_ignored=FLAGS.train_with_ignored,
        seg_loc_loss_weight=FLAGS.seg_loc_loss_weight,
        link_cls_loss_weight=FLAGS.link_cls_loss_weight,
    )

    util.proc.set_proc_name('_'.join(['eval', FLAGS.model_name, FLAGS.dataset_name]))
    eval_dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    config.print_config(FLAGS, eval_dataset, print_to_file=False)

    return eval_dataset
开发者ID:dengdan,项目名称:seglink,代码行数:25,代码来源:eval_seglink.py

示例4: main

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
  """Export an inference GraphDef, optionally with a quantized eval rewrite.

  The input placeholder 'input' is float32 with shape
  [--batch_size, size, size, 3]. Here size is --image_size when it is truthy,
  and the network function's ``default_image_size`` otherwise. When
  --quantize is set, ``tf.contrib.quantize.create_eval_graph()`` is applied
  before the GraphDef is serialized to --output_file.

  Raises:
    ValueError: if --output_file is empty.
  """
  destination = FLAGS.output_file
  if not destination:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as export_graph:
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    build_net = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=train_split.num_classes - FLAGS.labels_offset,
        is_training=FLAGS.is_training)
    side = FLAGS.image_size or build_net.default_image_size
    input_tensor = tf.placeholder(
        name='input', dtype=tf.float32,
        shape=[FLAGS.batch_size, side, side, 3])
    build_net(input_tensor)

    if FLAGS.quantize:
      tf.contrib.quantize.create_eval_graph()

    exported = export_graph.as_graph_def()
    with gfile.GFile(destination, 'wb') as sink:
      sink.write(exported.SerializeToString())
开发者ID:andrewekhalel,项目名称:edafa,代码行数:25,代码来源:export_inference_graph.py

示例5: main

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
  """Export an inference GraphDef sized from the 'validation' split.

  The network's class count is the validation split's ``num_classes`` minus
  --labels_offset. The input placeholder 'input' has shape
  [1, size, size, 3], where size is the network function's
  ``default_image_size`` if it has one and --default_image_size otherwise.

  Raises:
    ValueError: if --output_file is empty.
  """
  output_path = FLAGS.output_file
  if not output_path:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as export_graph:
    eval_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'validation', FLAGS.dataset_dir)
    build_net = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=eval_split.num_classes - FLAGS.labels_offset,
        is_training=FLAGS.is_training)
    # The flag is read only when the network lacks its own default size.
    side = (build_net.default_image_size
            if hasattr(build_net, 'default_image_size')
            else FLAGS.default_image_size)
    build_net(tf.placeholder(name='input', dtype=tf.float32,
                             shape=[1, side, side, 3]))
    exported = export_graph.as_graph_def()
    with gfile.GFile(output_path, 'wb') as sink:
      sink.write(exported.SerializeToString())
开发者ID:anthonyhu,项目名称:tumblr-emotions,代码行数:23,代码来源:export_inference_graph.py

示例6: _select_dataset

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def _select_dataset(self):
    """Select and return the dataset(s) used for training/eval.

    :return: The base class's dataset when --unpaired_target_dataset_name is
      empty. Otherwise, a ``(dataset, target_dataset)`` tuple, where the
      target dataset is loaded with the same --dataset_split_name from
      --unpaired_target_dataset_dir.
    """
    base_dataset = super(GanModel, self)._select_dataset()
    if not FLAGS.unpaired_target_dataset_name:
      return base_dataset
    unpaired_dataset = dataset_factory.get_dataset(
        FLAGS.unpaired_target_dataset_name,
        FLAGS.dataset_split_name,
        FLAGS.unpaired_target_dataset_dir)
    return (base_dataset, unpaired_dataset)

  ######################
  # Select the network #
  ###################### 
开发者ID:jerryli27,项目名称:TwinGAN,代码行数:18,代码来源:image_generation.py

示例7: _select_dataset

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def _select_dataset(self):
    """Select and return the dataset used for training/eval.

    Side effects: sets ``self.num_samples`` to the dataset's sample count and
    ``self.num_classes`` to its ``num_classes`` attribute (0 when the dataset
    has none).

    :return: The dataset built from --dataset_name, --dataset_split_name and
      --dataset_dir.
    :raises ValueError: if the dataset holds fewer samples than --batch_size.
    """
    dataset = dataset_factory.get_dataset(
      FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    # An explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if dataset.num_samples < FLAGS.batch_size:
      raise ValueError(
          'Dataset %s has %d samples, fewer than --batch_size=%d'
          % (FLAGS.dataset_name, dataset.num_samples, FLAGS.batch_size))
    self.num_samples = dataset.num_samples
    self.num_classes = getattr(dataset, 'num_classes', 0)
    tf.logging.info('dataset %s number of classes:%d ,number of samples:%d',
                    FLAGS.dataset_name, self.num_classes, self.num_samples)
    return dataset

  ######################
  # Select the network #
  ###################### 
开发者ID:jerryli27,项目名称:TwinGAN,代码行数:22,代码来源:model_inheritor.py

示例8: main

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
  """Export an inference GraphDef for a single-image input.

  The input placeholder 'input' is float32 with shape [1, size, size, 3].
  Here size is --image_size when it is truthy, and the network function's
  ``default_image_size`` otherwise. The serialized GraphDef is written to
  --output_file.

  Raises:
    ValueError: if --output_file is empty.
  """
  output_path = FLAGS.output_file
  if not output_path:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as export_graph:
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    class_count = train_split.num_classes - FLAGS.labels_offset
    build_net = nets_factory.get_network_fn(
        FLAGS.model_name, num_classes=class_count,
        is_training=FLAGS.is_training)
    side = FLAGS.image_size or build_net.default_image_size
    input_tensor = tf.placeholder(
        name='input', dtype=tf.float32, shape=[1, side, side, 3])
    build_net(input_tensor)
    exported = export_graph.as_graph_def()
    with gfile.GFile(output_path, 'wb') as sink:
      sink.write(exported.SerializeToString())
开发者ID:loicmarie,项目名称:hands-detection,代码行数:20,代码来源:export_inference_graph.py

示例9: main

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
  """Export an inference GraphDef for the network chosen by --model_name.

  The number of output classes is derived from the 'train' split
  (``dataset.num_classes - FLAGS.labels_offset``) instead of a hard-coded
  constant, so the exported graph follows --dataset_name. The input
  placeholder 'input' is float32 with shape [--batch_size, size, size, 3].
  Here size is --image_size when it is truthy, and the network function's
  ``default_image_size`` otherwise.

  Raises:
    ValueError: if --output_file is empty.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        # Previously hard-coded to 5, which ignored the loaded dataset and
        # produced a wrongly-sized classifier for any other dataset.
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = FLAGS.image_size or network_fn.default_image_size
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=[FLAGS.batch_size, image_size,
                                        image_size, 3])
    network_fn(placeholder)
    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())
开发者ID:yeephycho,项目名称:nasnet-tensorflow,代码行数:22,代码来源:export_inference_graph.py

示例10: imagenet_input

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def imagenet_input(is_training):
  """Build a batched ImageNet input pipeline for evaluation or training.

  Args:
    is_training: when True, read the 'train' split with shuffling and
      training-mode 'mobilenet_v1' preprocessing. Otherwise, read the
      'validation' split.

  Returns:
    A tuple ``(images, labels)`` of batched tensors from ``tf.train.batch``
    with batch size --batch_size. Images are resized to --image_size.
  """
  split_name = 'train' if is_training else 'validation'
  dataset = dataset_factory.get_dataset('imagenet', split_name,
                                        FLAGS.dataset_dir)

  batch = FLAGS.batch_size
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=is_training,
      common_queue_capacity=2 * batch,
      common_queue_min=batch)
  image, label = provider.get(['image', 'label'])

  preprocess = preprocessing_factory.get_preprocessing(
      'mobilenet_v1', is_training=is_training)
  side = FLAGS.image_size
  image = preprocess(image, side, side)

  images, labels = tf.train.batch(
      tensors=[image, label],
      batch_size=batch,
      num_threads=4,
      capacity=5 * batch)
  return images, labels
开发者ID:leimao,项目名称:DeepLab_v3,代码行数:37,代码来源:mobilenet_v1_eval.py

示例11: imagenet_input

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def imagenet_input(is_training):
  """Build a batched ImageNet input pipeline with one-hot labels.

  Args:
    is_training: when True, read the 'train' split with shuffling and
      training-mode 'mobilenet_v1' preprocessing. Otherwise, read the
      'validation' split.

  Returns:
    A tuple ``(images, labels)``. Images are batched (batch size
    --batch_size) and resized to --image_size. Labels are one-hot encoded
    over --num_classes.
  """
  split_name = 'train' if is_training else 'validation'
  dataset = dataset_factory.get_dataset('imagenet', split_name,
                                        FLAGS.dataset_dir)

  batch = FLAGS.batch_size
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=is_training,
      common_queue_capacity=2 * batch,
      common_queue_min=batch)
  image, label = provider.get(['image', 'label'])

  preprocess = preprocessing_factory.get_preprocessing(
      'mobilenet_v1', is_training=is_training)
  side = FLAGS.image_size
  image = preprocess(image, side, side)

  images, sparse_labels = tf.train.batch(
      [image, label],
      batch_size=batch,
      num_threads=4,
      capacity=5 * batch)
  return images, slim.one_hot_encoding(sparse_labels, FLAGS.num_classes)
开发者ID:leimao,项目名称:DeepLab_v3,代码行数:38,代码来源:mobilenet_v1_train.py

示例12: config_initialization

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def config_initialization():
    """Validate flags, set up logging and the training config, and load the dataset.

    The function does the following:
    - Appends logs to ``log_train_seglink_<H>_<W>.log`` under --train_dir.
    - Initialises the config from the training flags.
    - Records ``config.batch_size`` and ``config.batch_size_per_gpu`` as
      scalar summaries.
    - Sets the process name to ``<model_name>_<dataset_name>``.

    Returns:
        The dataset built from --dataset_name, --dataset_split_name and
        --dataset_dir.

    Raises:
        ValueError: if --dataset_dir is not set.
    """
    # Training image shape (used for image and feature-layer shape inference).
    train_shape = (FLAGS.train_image_height, FLAGS.train_image_width)

    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    tf.logging.set_verbosity(tf.logging.DEBUG)
    log_name = 'log_train_seglink_%d_%d.log' % train_shape
    util.init_logger(log_file=log_name, log_path=FLAGS.train_dir,
                     stdout=False, mode='a')

    config.init_config(
        train_shape,
        batch_size=FLAGS.batch_size,
        weight_decay=FLAGS.weight_decay,
        num_gpus=FLAGS.num_gpus,
        train_with_ignored=FLAGS.train_with_ignored,
        seg_loc_loss_weight=FLAGS.seg_loc_loss_weight,
        link_cls_loss_weight=FLAGS.link_cls_loss_weight,
    )

    total_batch = config.batch_size
    per_gpu_batch = config.batch_size_per_gpu
    tf.summary.scalar('batch_size', total_batch)
    tf.summary.scalar('batch_size_per_gpu', per_gpu_batch)

    util.proc.set_proc_name('_'.join([FLAGS.model_name, FLAGS.dataset_name]))

    train_dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    config.print_config(FLAGS, train_dataset)
    return train_dataset
开发者ID:dengdan,项目名称:seglink,代码行数:32,代码来源:train_seglink.py

示例13: config_initialization

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def config_initialization():
    """Validate flags, initialise the config with the training shape, and load the dataset.

    Returns:
        The dataset built from --dataset_name, --dataset_split_name and
        --dataset_dir.

    Raises:
        ValueError: if --dataset_dir is not set.
    """
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    tf.logging.set_verbosity(tf.logging.DEBUG)

    # Training image shape (used for image and feature-layer shape inference).
    config.init_config((FLAGS.train_image_height, FLAGS.train_image_width),
                       batch_size=FLAGS.batch_size)

    util.proc.set_proc_name('_'.join([FLAGS.model_name, FLAGS.dataset_name]))

    return dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
开发者ID:dengdan,项目名称:seglink,代码行数:17,代码来源:test_batch_and_gt.py

示例14: main

# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
  """Export an inference GraphDef (binary or text) for image or video models.

  The input placeholder 'input' is float32 with shape
  [--batch_size, size, size, 3] for image models, or
  [--batch_size, --num_frames, size, size, 3] when --is_video_model is set.
  Here size is --image_size when it is truthy, and the network function's
  ``default_image_size`` otherwise. With --quantize,
  ``tf.contrib.quantize.create_eval_graph()`` is applied before export.
  With --write_text_graphdef, the graph is written as text via
  ``tf.io.write_graph``; otherwise it is written as a binary serialized
  GraphDef.

  Raises:
    ValueError: if --output_file is empty, or if --is_video_model is set
      without --num_frames.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  if FLAGS.is_video_model and not FLAGS.num_frames:
    raise ValueError(
        'Number of frames must be specified for video models with --num_frames')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as export_graph:
    train_split = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
    build_net = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=train_split.num_classes - FLAGS.labels_offset,
        is_training=FLAGS.is_training)
    side = FLAGS.image_size or build_net.default_image_size
    # Video models carry an extra frame axis between batch and height.
    frame_axis = [FLAGS.num_frames] if FLAGS.is_video_model else []
    input_shape = [FLAGS.batch_size] + frame_axis + [side, side, 3]
    build_net(tf.placeholder(name='input', dtype=tf.float32,
                             shape=input_shape))

    if FLAGS.quantize:
      tf.contrib.quantize.create_eval_graph()

    exported = export_graph.as_graph_def()
    if not FLAGS.write_text_graphdef:
      with gfile.GFile(FLAGS.output_file, 'wb') as sink:
        sink.write(exported.SerializeToString())
    else:
      # os.path.split yields (dirname, basename) of the output path.
      out_dir, out_name = os.path.split(FLAGS.output_file)
      tf.io.write_graph(exported, out_dir, out_name, as_text=True)
开发者ID:google-research,项目名称:morph-net,代码行数:39,代码来源:export_inference_graph.py


注:本文中的datasets.dataset_factory.get_dataset方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。