本文整理汇总了Python中datasets.dataset_factory.get_dataset方法的典型用法代码示例。如果您正苦于以下问题：Python dataset_factory.get_dataset方法的具体用法？Python dataset_factory.get_dataset怎么用？Python dataset_factory.get_dataset使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块datasets.dataset_factory的用法示例。
在下文中一共展示了dataset_factory.get_dataset方法的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
    """Export a binary inference GraphDef for ``FLAGS.model_name``.

    Builds a fresh graph holding one float32 placeholder named ``input`` of
    shape ``[1, image_size, image_size, 3]``, runs the selected network on it,
    and writes the serialized GraphDef to ``FLAGS.output_file``.

    ``image_size`` is the network function's ``default_image_size`` attribute
    when it has one, otherwise ``FLAGS.default_image_size``.

    Raises:
        ValueError: if ``--output_file`` was not supplied.
    """
    output_path = FLAGS.output_file
    if not output_path:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as export_graph:
        # The 'train' split is only consulted for its class count.
        train_split = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        class_count = train_split.num_classes - FLAGS.labels_offset
        build_net = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=class_count,
            is_training=FLAGS.is_training)

        # Only read the flag when the network does not advertise a size.
        image_size = (build_net.default_image_size
                      if hasattr(build_net, 'default_image_size')
                      else FLAGS.default_image_size)

        input_tensor = tf.placeholder(
            name='input', dtype=tf.float32,
            shape=[1, image_size, image_size, 3])
        build_net(input_tensor)

        exported = export_graph.as_graph_def()
        with gfile.GFile(output_path, 'wb') as sink:
            sink.write(exported.SerializeToString())
示例2: main
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
    """Write a binary inference GraphDef for ``FLAGS.model_name``.

    The graph holds a single float32 placeholder named ``input`` with shape
    ``[FLAGS.batch_size, image_size, image_size, 3]``. ``image_size`` is
    ``FLAGS.image_size`` when truthy, otherwise the network function's
    ``default_image_size``.

    Raises:
        ValueError: if ``--output_file`` is empty.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        train_ds = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        net = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=train_ds.num_classes - FLAGS.labels_offset,
            is_training=FLAGS.is_training)
        side = FLAGS.image_size or net.default_image_size
        batch_shape = [FLAGS.batch_size, side, side, 3]
        net(tf.placeholder(name='input', dtype=tf.float32, shape=batch_shape))
        graph_def = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as out_file:
            out_file.write(graph_def.SerializeToString())
示例3: config_initialization
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def config_initialization():
    """Initialise the global config for evaluation and return the dataset.

    Uses the eval image shape ``(FLAGS.eval_image_height,
    FLAGS.eval_image_width)`` with a batch size of 1, forwards the
    threshold/loss flags to ``config.init_config``, names the process
    ``eval_<model>_<dataset>``, and prints the config (not to a file).

    Returns:
        The dataset for ``FLAGS.dataset_split_name`` in ``FLAGS.dataset_dir``.

    Raises:
        ValueError: if ``--dataset_dir`` was not supplied.
    """
    eval_shape = (FLAGS.eval_image_height, FLAGS.eval_image_width)
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    tf.logging.set_verbosity(tf.logging.DEBUG)

    config.init_config(
        eval_shape,
        batch_size=1,
        seg_conf_threshold=FLAGS.seg_conf_threshold,
        link_conf_threshold=FLAGS.link_conf_threshold,
        train_with_ignored=FLAGS.train_with_ignored,
        seg_loc_loss_weight=FLAGS.seg_loc_loss_weight,
        link_cls_loss_weight=FLAGS.link_cls_loss_weight,
    )

    util.proc.set_proc_name(
        '_'.join(['eval', FLAGS.model_name, FLAGS.dataset_name]))
    eval_dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    config.print_config(FLAGS, eval_dataset, print_to_file=False)
    return eval_dataset
示例4: main
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
    """Export a (optionally quantized) inference GraphDef.

    Builds the network selected by ``FLAGS.model_name`` on a float32
    placeholder named ``input`` of shape
    ``[FLAGS.batch_size, image_size, image_size, 3]``. When ``FLAGS.quantize``
    is set, ``tf.contrib.quantize.create_eval_graph()`` rewrites the graph
    before it is serialized to ``FLAGS.output_file``.

    Raises:
        ValueError: if ``--output_file`` was not supplied.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        source = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        n_classes = source.num_classes - FLAGS.labels_offset
        model = nets_factory.get_network_fn(
            FLAGS.model_name, num_classes=n_classes,
            is_training=FLAGS.is_training)

        size = FLAGS.image_size or model.default_image_size
        inputs = tf.placeholder(
            name='input', dtype=tf.float32,
            shape=[FLAGS.batch_size, size, size, 3])
        model(inputs)

        # Quantization rewrite must happen after the network is built.
        if FLAGS.quantize:
            tf.contrib.quantize.create_eval_graph()

        proto = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as handle:
            handle.write(proto.SerializeToString())
示例5: main
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
    """Export an inference GraphDef, sizing classes from the validation split.

    The class count comes from the ``validation`` split of
    ``FLAGS.dataset_name`` minus ``FLAGS.labels_offset``. The input
    placeholder ``input`` has shape ``[1, image_size, image_size, 3]``, where
    ``image_size`` is the network's ``default_image_size`` if present, else
    ``FLAGS.default_image_size``.

    Raises:
        ValueError: if ``--output_file`` was not supplied.
    """
    destination = FLAGS.output_file
    if not destination:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as g:
        val_split = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'validation', FLAGS.dataset_dir)
        net_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=val_split.num_classes - FLAGS.labels_offset,
            is_training=FLAGS.is_training)

        if not hasattr(net_fn, 'default_image_size'):
            side = FLAGS.default_image_size
        else:
            side = net_fn.default_image_size

        net_fn(tf.placeholder(name='input', dtype=tf.float32,
                              shape=[1, side, side, 3]))

        graph_proto = g.as_graph_def()
        with gfile.GFile(destination, 'wb') as fp:
            fp.write(graph_proto.SerializeToString())
示例6: _select_dataset
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def _select_dataset(self):
    """Select and return the dataset(s) used for training/eval.

    Returns:
        The dataset chosen by the parent class's ``_select_dataset``. When
        ``FLAGS.unpaired_target_dataset_name`` is set, returns a
        ``(source, target)`` tuple instead. The target dataset is loaded from
        ``FLAGS.unpaired_target_dataset_dir`` with the same
        ``FLAGS.dataset_split_name``.
    """
    source = super(GanModel, self)._select_dataset()
    if not FLAGS.unpaired_target_dataset_name:
        return source
    target = dataset_factory.get_dataset(
        FLAGS.unpaired_target_dataset_name,
        FLAGS.dataset_split_name,
        FLAGS.unpaired_target_dataset_dir)
    return source, target
######################
# Select the network #
######################
示例7: _select_dataset
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def _select_dataset(self):
    """Select and return the dataset used for training/eval.

    Side effects:
        Sets ``self.num_samples`` from the dataset. Sets ``self.num_classes``
        from the dataset, or to 0 when the dataset has no ``num_classes``
        attribute.

    Returns:
        The dataset returned by ``dataset_factory.get_dataset`` for
        ``FLAGS.dataset_name`` / ``FLAGS.dataset_split_name``.

    Raises:
        ValueError: if the dataset holds fewer samples than one batch
            (``FLAGS.batch_size``).
    """
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if dataset.num_samples < FLAGS.batch_size:
        raise ValueError(
            'Dataset %s has %d samples, fewer than batch_size=%d'
            % (FLAGS.dataset_name, dataset.num_samples, FLAGS.batch_size))
    self.num_samples = dataset.num_samples
    self.num_classes = getattr(dataset, 'num_classes', 0)
    # Lazy %-args: formatting is deferred to the logger; message is unchanged.
    tf.logging.info('dataset %s number of classes:%d ,number of samples:%d',
                    FLAGS.dataset_name, self.num_classes, self.num_samples)
    return dataset
######################
# Select the network #
######################
示例8: main
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
    """Export a single-image inference GraphDef for ``FLAGS.model_name``.

    The placeholder ``input`` has shape ``[1, image_size, image_size, 3]``.
    ``image_size`` is ``FLAGS.image_size`` when truthy, otherwise the network
    function's ``default_image_size``. The binary GraphDef is written to
    ``FLAGS.output_file``.

    Raises:
        ValueError: if ``--output_file`` was not supplied.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        ds = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        network = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=ds.num_classes - FLAGS.labels_offset,
            is_training=FLAGS.is_training)
        resolution = FLAGS.image_size or network.default_image_size
        single_image_shape = [1, resolution, resolution, 3]
        network(tf.placeholder(name='input', dtype=tf.float32,
                               shape=single_image_shape))
        serialized_graph = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as writer:
            writer.write(serialized_graph.SerializeToString())
示例9: main
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
    """Export an inference GraphDef for a fixed 5-class model.

    The placeholder ``input`` has shape
    ``[FLAGS.batch_size, image_size, image_size, 3]``, where ``image_size`` is
    ``FLAGS.image_size`` if truthy, else the network's ``default_image_size``.

    Raises:
        ValueError: if ``--output_file`` was not supplied.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default() as graph:
        # NOTE(review): the dataset is still loaded but its class count is not
        # used; the call is kept in case get_dataset validates its arguments.
        dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        # NOTE(review): num_classes is hard-coded to 5 instead of
        # `dataset.num_classes - FLAGS.labels_offset`; confirm this is
        # intentional for the target dataset.
        build = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=5,
            is_training=FLAGS.is_training)
        edge = FLAGS.image_size or build.default_image_size
        build(tf.placeholder(name='input', dtype=tf.float32,
                             shape=[FLAGS.batch_size, edge, edge, 3]))
        definition = graph.as_graph_def()
        with gfile.GFile(FLAGS.output_file, 'wb') as out:
            out.write(definition.SerializeToString())
示例10: imagenet_input
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def imagenet_input(is_training):
    """Build a batched ImageNet input pipeline.

    Reads the ImageNet ``train`` split when ``is_training`` is truthy and the
    ``validation`` split otherwise. Records are shuffled only for training.
    Images get ``mobilenet_v1`` preprocessing at
    ``FLAGS.image_size`` x ``FLAGS.image_size`` and are batched with
    ``tf.train.batch`` using 4 threads.

    Args:
        is_training: bool specifying if train or validation dataset is needed.

    Returns:
        A batch of images and labels, as an ``(images, labels)`` tuple.
    """
    split_name = 'train' if is_training else 'validation'
    dataset = dataset_factory.get_dataset('imagenet', split_name,
                                          FLAGS.dataset_dir)
    batch_size = FLAGS.batch_size

    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=is_training,
        common_queue_capacity=2 * batch_size,
        common_queue_min=batch_size)
    image, label = provider.get(['image', 'label'])

    preprocess = preprocessing_factory.get_preprocessing(
        'mobilenet_v1', is_training=is_training)
    image = preprocess(image, FLAGS.image_size, FLAGS.image_size)

    images, labels = tf.train.batch(
        tensors=[image, label],
        batch_size=batch_size,
        num_threads=4,
        capacity=5 * batch_size)
    return images, labels
示例11: imagenet_input
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def imagenet_input(is_training):
    """Build a batched ImageNet input pipeline with one-hot labels.

    Selects the ImageNet ``train`` split (shuffled) when ``is_training`` is
    truthy, else the ``validation`` split. Applies ``mobilenet_v1``
    preprocessing at ``FLAGS.image_size``, batches with ``tf.train.batch``,
    then one-hot encodes the labels to ``FLAGS.num_classes`` classes.

    Args:
        is_training: bool specifying if train or validation dataset is needed.

    Returns:
        An ``(images, labels)`` tuple; ``labels`` is one-hot encoded.
    """
    dataset = dataset_factory.get_dataset(
        'imagenet', 'train' if is_training else 'validation',
        FLAGS.dataset_dir)

    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        shuffle=is_training,
        common_queue_capacity=2 * FLAGS.batch_size,
        common_queue_min=FLAGS.batch_size)
    raw_image, raw_label = provider.get(['image', 'label'])

    preprocessing_fn = preprocessing_factory.get_preprocessing(
        'mobilenet_v1', is_training=is_training)
    processed = preprocessing_fn(raw_image, FLAGS.image_size, FLAGS.image_size)

    images, sparse_labels = tf.train.batch(
        tensors=[processed, raw_label],
        batch_size=FLAGS.batch_size,
        num_threads=4,
        capacity=5 * FLAGS.batch_size)
    return images, slim.one_hot_encoding(sparse_labels, FLAGS.num_classes)
示例12: config_initialization
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def config_initialization():
    """Initialise the global SegLink training config and return the dataset.

    Steps:
      * Validate ``--dataset_dir``.
      * Call ``util.init_logger`` with a ``log_train_seglink_<h>_<w>.log`` file
        under ``FLAGS.train_dir`` (``stdout=False``, ``mode='a'``).
      * Initialise ``config`` from the training shape and flags.
      * Record ``batch_size`` / ``batch_size_per_gpu`` scalar summaries.
      * Name the process, then load and print the dataset config.

    Returns:
        The dataset for ``FLAGS.dataset_split_name`` in ``FLAGS.dataset_dir``.

    Raises:
        ValueError: if ``--dataset_dir`` was not supplied.
    """
    train_shape = (FLAGS.train_image_height, FLAGS.train_image_width)
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    tf.logging.set_verbosity(tf.logging.DEBUG)

    log_name = 'log_train_seglink_%d_%d.log' % train_shape
    util.init_logger(log_file=log_name, log_path=FLAGS.train_dir,
                     stdout=False, mode='a')

    config.init_config(
        train_shape,
        batch_size=FLAGS.batch_size,
        weight_decay=FLAGS.weight_decay,
        num_gpus=FLAGS.num_gpus,
        train_with_ignored=FLAGS.train_with_ignored,
        seg_loc_loss_weight=FLAGS.seg_loc_loss_weight,
        link_cls_loss_weight=FLAGS.link_cls_loss_weight,
    )

    # Read both values from config first, then emit the summaries in order.
    summaries = (('batch_size', config.batch_size),
                 ('batch_size_per_gpu', config.batch_size_per_gpu))
    for tag, value in summaries:
        tf.summary.scalar(tag, value)

    util.proc.set_proc_name('_'.join([FLAGS.model_name, FLAGS.dataset_name]))
    train_dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
    config.print_config(FLAGS, train_dataset)
    return train_dataset
示例13: config_initialization
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def config_initialization():
    """Initialise the global config from training flags and return the dataset.

    Only the training image shape and ``FLAGS.batch_size`` are passed to
    ``config.init_config``; the process is named ``<model>_<dataset>``.

    Returns:
        The dataset for ``FLAGS.dataset_split_name`` in ``FLAGS.dataset_dir``.

    Raises:
        ValueError: if ``--dataset_dir`` was not supplied.
    """
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    tf.logging.set_verbosity(tf.logging.DEBUG)

    config.init_config((FLAGS.train_image_height, FLAGS.train_image_width),
                       batch_size=FLAGS.batch_size)
    util.proc.set_proc_name('_'.join([FLAGS.model_name, FLAGS.dataset_name]))
    return dataset_factory.get_dataset(
        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
示例14: main
# 需要导入模块: from datasets import dataset_factory [as 别名]
# 或者: from datasets.dataset_factory import get_dataset [as 别名]
def main(_):
    """Export an inference GraphDef, supporting video models and text output.

    The placeholder ``input`` has shape
    ``[FLAGS.batch_size, image_size, image_size, 3]`` for image models. For
    video models (``FLAGS.is_video_model``) the shape is
    ``[FLAGS.batch_size, FLAGS.num_frames, image_size, image_size, 3]``.
    ``FLAGS.quantize`` rewrites the graph for quantized eval. The GraphDef is
    written as text via ``tf.io.write_graph`` when
    ``FLAGS.write_text_graphdef`` is set, otherwise as binary.

    Raises:
        ValueError: if ``--output_file`` is missing, or a video model is
            requested without ``--num_frames``.
    """
    if not FLAGS.output_file:
        raise ValueError('You must supply the path to save to with --output_file')
    if FLAGS.is_video_model and not FLAGS.num_frames:
        raise ValueError(
            'Number of frames must be specified for video models with --num_frames')
    tf.logging.set_verbosity(tf.logging.INFO)

    with tf.Graph().as_default() as graph:
        data = dataset_factory.get_dataset(
            FLAGS.dataset_name, 'train', FLAGS.dataset_dir)
        net = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=data.num_classes - FLAGS.labels_offset,
            is_training=FLAGS.is_training)

        size = FLAGS.image_size or net.default_image_size
        frame_dims = [size, size, 3]
        if FLAGS.is_video_model:
            input_shape = [FLAGS.batch_size, FLAGS.num_frames] + frame_dims
        else:
            input_shape = [FLAGS.batch_size] + frame_dims

        net(tf.placeholder(name='input', dtype=tf.float32, shape=input_shape))

        if FLAGS.quantize:
            tf.contrib.quantize.create_eval_graph()

        graph_def = graph.as_graph_def()
        if not FLAGS.write_text_graphdef:
            with gfile.GFile(FLAGS.output_file, 'wb') as f:
                f.write(graph_def.SerializeToString())
        else:
            # os.path.split yields exactly (dirname, basename).
            out_dir, out_name = os.path.split(FLAGS.output_file)
            tf.io.write_graph(graph_def, out_dir, out_name, as_text=True)