本文整理汇总了Python中datasets.create_dataset方法的典型用法代码示例。如果您正苦于以下问题：Python datasets.create_dataset方法的具体用法？Python datasets.create_dataset怎么用？Python datasets.create_dataset使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块datasets的用法示例。
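在下文中一共展示了datasets.create_dataset方法的10个代码示例，这些例子默认根据受欢迎程度排序。请注意：这些示例来自不同的项目，各项目中的datasets模块并不相同，create_dataset的参数签名也各不相同（例如示例2传入数据目录、数据集名称和子集，而示例3传入项目ID、区域和数据集ID），请以所用项目的实际定义为准。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。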
示例1: __get_input
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def __get_input(self):
    """Build the COCO validation input batches for this benchmark.

    Configures a ``COCOPreprocessor`` for a single split of float32 images
    shaped ``[batch_size, input_size, input_size, 3]`` (evaluation mode,
    distortions enabled, no resize method, zero shift ratio). It then loads
    the 'coco' dataset from ``self.args.data_location`` into ``self.dataset``
    and stores a minimal params object in ``self.params``.

    Returns:
      Whatever ``preprocessor.minibatch`` produces for the 'validation'
      subset.
    """
    batch_size = self.args.batch_size
    side = self.args.input_size
    image_shape = [batch_size, side, side, 3]

    preprocessor = COCOPreprocessor(
        batch_size=batch_size,
        output_shapes=[image_shape],
        num_splits=1,
        dtype=tf.float32,
        train=False,
        distortions=True,
        resize_method=None,
        shift_ratio=0,
    )

    # Lightweight stand-in for the full benchmark params object. Presumably
    # minibatch() only needs this one field here; confirm against the
    # preprocessor implementation.
    class params:
        datasets_repeat_cached_sample = False

    self.params = params()
    self.dataset = datasets.create_dataset(self.args.data_location, 'coco')
    return preprocessor.minibatch(
        self.dataset, subset='validation', params=self.params, shift_ratio=0)
示例2: main
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def main(extra_flags):
    """Train a CNN model configured entirely from command-line flags.

    Args:
      extra_flags: Positional arguments left over after flag parsing. At
        least one entry must be present; any entry after the first is
        rejected as an unknown flag.

    Raises:
      ValueError: If unparsed flags remain beyond the first entry.
    """
    # The leftover list always carries at least one entry (presumably the
    # program name; confirm with the flag parser in use).
    assert len(extra_flags) >= 1
    leftover = extra_flags[1:]
    if leftover:
        raise ValueError('Received unknown flags: %s' % leftover)

    # Build run parameters from flags, prepare the environment, and save the
    # parameters under params.train_dir.
    params = parameters.make_params_from_flags()
    deploy.setup_env(params)
    parameters.save_params(params, params.train_dir)

    # Log the major.minor TensorFlow version.
    version = deploy.tensorflow_version_tuple()
    deploy.log_fn('TensorFlow: %i.%i' % (version[0], version[1]))

    # The model is created against the dataset, so build the dataset first.
    dataset = datasets.create_dataset(
        params.data_dir, params.data_name, params.data_subset)
    model = models.create_model(params.model, dataset)
    set_model_params(model, params)

    trainer = deploy.TrainerCNN(dataset, model, params)
    trainer.print_info()
    trainer.run()
示例3: test_CRUD_dataset
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def test_CRUD_dataset(capsys, crud_dataset_id):
    """Exercise create, get, list and delete for a single dataset.

    Once creation succeeds, deletion runs in a ``finally`` clause. A failing
    get/list call therefore still removes the dataset instead of leaking it
    into the project. Creation itself is kept outside the ``try`` so a failed
    create is not masked by a follow-up delete error.
    """
    datasets.create_dataset(
        project_id,
        cloud_region,
        crud_dataset_id)
    try:
        datasets.get_dataset(
            project_id, cloud_region, crud_dataset_id)
        datasets.list_datasets(
            project_id, cloud_region)
    finally:
        # Always clean up after a successful create.
        datasets.delete_dataset(
            project_id, cloud_region, crud_dataset_id)

    out, _ = capsys.readouterr()
    # Check that create/get/list/delete worked
    assert 'Created dataset' in out
    assert 'Time zone' in out
    assert 'Dataset' in out
    assert 'Deleted dataset' in out
示例4: test_CRUD_dataset
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def test_CRUD_dataset(capsys):
    """Exercise create, get, list and delete for a single dataset.

    Once creation succeeds, deletion runs in a ``finally`` clause. A failing
    get/list call therefore still removes the dataset instead of leaving it
    behind, where it could conflict with later runs that reuse
    ``dataset_id``.
    """
    datasets.create_dataset(
        service_account_json,
        project_id,
        cloud_region,
        dataset_id)
    try:
        datasets.get_dataset(
            service_account_json, project_id, cloud_region, dataset_id)
        datasets.list_datasets(
            service_account_json, project_id, cloud_region)
    finally:
        # Test the delete path and also clean up, whatever happened above.
        datasets.delete_dataset(
            service_account_json, project_id, cloud_region, dataset_id)

    out, _ = capsys.readouterr()
    # Check that create/get/list/delete worked
    assert 'Created dataset' in out
    assert 'Time zone' in out
    assert 'Dataset' in out
    assert 'Deleted dataset' in out
示例5: test_patch_dataset
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def test_patch_dataset(capsys):
    """Create a dataset, patch its time zone, then delete it.

    Deletion runs in a ``finally`` clause, so a failing patch call does not
    leave the dataset behind.
    """
    datasets.create_dataset(
        service_account_json,
        project_id,
        cloud_region,
        dataset_id)
    try:
        datasets.patch_dataset(
            service_account_json,
            project_id,
            cloud_region,
            dataset_id,
            time_zone)
    finally:
        # Clean up whether or not the patch succeeded.
        datasets.delete_dataset(
            service_account_json, project_id, cloud_region, dataset_id)

    out, _ = capsys.readouterr()
    # Check that the patch to the time zone worked
    assert 'UTC' in out
示例6: main
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def main(extra_flags):
    """Run the CNN trainer for one pass over the evaluation dataset.

    Parameters come from command-line flags and are then overridden by
    ``replace_with_train_params``. The number of batches is derived from the
    dataset size.

    Args:
      extra_flags: Positional arguments left over after flag parsing. At
        least one entry must be present; any entry after the first is
        rejected as an unknown flag.

    Raises:
      ValueError: If unparsed flags remain beyond the first entry.
    """
    # The leftover list always carries at least one entry (presumably the
    # program name; confirm with the flag parser in use).
    assert len(extra_flags) >= 1
    leftover = extra_flags[1:]
    if leftover:
        raise ValueError('Received unknown flags: %s' % leftover)

    params = parameters.make_params_from_flags()
    deploy.setup_env(params)
    # Override with the training parameters (loaded from a JSON file by
    # replace_with_train_params).
    params = replace_with_train_params(params)

    # Log the major.minor TensorFlow version.
    version = deploy.tensorflow_version_tuple()
    deploy.log_fn('TensorFlow: %i.%i' % (version[0], version[1]))

    dataset = datasets.create_dataset(
        params.data_dir, params.data_name, params.data_subset)
    model = models.create_model(params.model, dataset)
    train.set_model_params(model, params)

    # Cover the evaluation set once: examples divided by the global batch
    # (batch_size is presumably per GPU), truncated to an int, so a trailing
    # partial batch is not run.
    examples = dataset.num_examples_per_epoch()
    global_batch = params.batch_size * params.num_gpus
    params = params._replace(num_batches=int(examples / global_batch))

    trainer = deploy.TrainerCNN(dataset, model, params)
    trainer.print_info()
    trainer.run()
示例7: test_deidentify_dataset
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def test_deidentify_dataset(capsys):
    """De-identify a freshly created dataset into a destination dataset.

    The source dataset is deleted in a ``finally`` clause, so a failing
    de-identify call does not leak it. The destination dataset is deleted
    only after de-identification returned successfully, because only then
    is it known to exist.
    """
    datasets.create_dataset(
        service_account_json,
        project_id,
        cloud_region,
        dataset_id)
    try:
        datasets.deidentify_dataset(
            service_account_json,
            project_id,
            cloud_region,
            dataset_id,
            destination_dataset_id,
            keeplist_tags)
    finally:
        # Clean up the source dataset whether or not de-identify succeeded.
        datasets.delete_dataset(
            service_account_json, project_id, cloud_region, dataset_id)
    # Reached only on success: clean up the destination dataset.
    datasets.delete_dataset(
        service_account_json,
        project_id,
        cloud_region,
        destination_dataset_id)

    out, _ = capsys.readouterr()
    # Check that de-identify worked
    assert 'De-identified data written to' in out
示例8: __init__
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def __init__(self, params):
    """Initialize BenchmarkCNN.

    Args:
      params: Params tuple, typically created by make_params or
        make_params_from_flags.

    Raises:
      ValueError: Unsupported params settings.
    """
    self.params = params
    if FLAGS.deterministic:
        # Deterministic mode requires that no data_dir was given.
        assert self.params.data_dir is None
        data_dir = None
    else:
        data_dir = self.params.data_dir
    self.dataset = datasets.create_dataset(data_dir, self.params.data_name)
    self.model = model_config.get_model_config(self.params.model,
                                               self.dataset)
    self.data_format = self.params.data_format
    self.resize_method = self.params.resize_method
    self.use_synthetic_gpu_images = self.dataset.use_synthetic_gpu_images()
    self.num_batches_for_eval = self.params.num_batches_for_eval

    # Both learning-rate-decay checks below require the full schedule:
    # learning_rate, num_epochs_per_decay and learning_rate_decay_factor.
    lr_schedule_complete = (
        self.params.learning_rate and self.params.num_epochs_per_decay and
        self.params.learning_rate_decay_factor)
    if ((self.params.num_epochs_per_decay or
         self.params.learning_rate_decay_factor) and
            not lr_schedule_complete):
        # Adjacent literals are concatenated; the trailing spaces keep the
        # words separated in the final message.
        raise ValueError('If one of num_epochs_per_decay or '
                         'learning_rate_decay_factor is set, both must be set '
                         'and learning_rate must be set')
    if self.params.minimum_learning_rate and not lr_schedule_complete:
        raise ValueError('minimum_learning_rate requires learning_rate, '
                         'num_epochs_per_decay, and '
                         'learning_rate_decay_factor to be set')

    # Use the batch size from the command line if specified, otherwise use the
    # model's default batch size. Scale the benchmark's batch size by the
    # number of GPUs.
    if self.params.batch_size > 0:
        self.model.set_batch_size(self.params.batch_size)
    self.batch_size = self.model.get_batch_size()
    self.batch_group_size = self.params.batch_group_size
    self.loss_scale = None
    self.loss_scale_normal_steps = None
    self.image_preprocessor = self.get_image_preprocessor()
示例9: test_dataset
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def test_dataset():
    """Create the shared test dataset, yield to the test, then delete it.

    Both the create and the delete are wrapped in ``retry`` with exponential
    back-off (multiplier 1000, cap 10000, at most 10 attempts), retrying only
    when ``retry_if_server_exception`` accepts the error.
    """
    @retry(
        wait_exponential_multiplier=1000,
        wait_exponential_max=10000,
        stop_max_attempt_number=10,
        retry_on_exception=retry_if_server_exception)
    def _create_once():
        try:
            datasets.create_dataset(project_id, cloud_region, dataset_id)
        except HttpError as err:
            status = err.resp.status
            # A 409 conflict most likely means an earlier attempt failed on
            # the client side after the creation succeeded on the server
            # side, so it is tolerated. Anything else propagates.
            if status != 409:
                raise
            print('Got exception {} while creating dataset'.format(status))

    _create_once()

    yield

    # Clean up
    @retry(
        wait_exponential_multiplier=1000,
        wait_exponential_max=10000,
        stop_max_attempt_number=10,
        retry_on_exception=retry_if_server_exception)
    def _delete_once():
        try:
            datasets.delete_dataset(project_id, cloud_region, dataset_id)
        except HttpError as err:
            status = err.resp.status
            # The API returns 403 (as well as 404) when the dataset doesn't
            # exist, so both mean "nothing to delete". Anything else
            # propagates.
            if status not in (404, 403):
                raise
            print('Got exception {} while deleting dataset'.format(status))

    _delete_once()
示例10: init_fn
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import create_dataset [as 别名]
def init_fn(self):
    """Set up everything the trainer needs before the first step.

    Creates, in order:
      - the training dataset;
      - the mesh and its faces on ``self.device``;
      - the GraphCNN and the SMPL parameter regressor;
      - one Adam optimizer over both networks;
      - the SMPL model and the loss criteria;
      - the checkpoint dictionaries and a renderer.

    If ``options.pretrained_checkpoint`` is set, weights are then loaded
    from it. That is distinct from resuming training, which uses
    ``--resume``.
    """
    opts = self.options
    device = self.device

    # Training dataset.
    self.train_ds = create_dataset(opts.dataset, opts)

    # Mesh topology; its faces tensor is moved to the training device.
    self.mesh = Mesh()
    self.faces = self.mesh.faces.to(device)

    # Networks. Construction order is kept (GraphCNN first) because
    # parameter initialisation may draw from a shared random generator.
    self.graph_cnn = GraphCNN(
        self.mesh.adjmat,
        self.mesh.ref_vertices.t(),
        num_channels=opts.num_channels,
        num_layers=opts.num_layers,
    ).to(device)
    self.smpl_param_regressor = SMPLParamRegressor().to(device)

    # A single Adam optimizer jointly updates both networks.
    joint_params = (list(self.graph_cnn.parameters())
                    + list(self.smpl_param_regressor.parameters()))
    self.optimizer = torch.optim.Adam(
        params=joint_params,
        lr=opts.lr,
        betas=(opts.adam_beta1, 0.999),
        weight_decay=opts.wd,
    )

    # SMPL model.
    self.smpl = SMPL().to(device)

    # Loss functions; the keypoint loss is unreduced (reduction='none').
    self.criterion_shape = nn.L1Loss().to(device)
    self.criterion_keypoints = nn.MSELoss(reduction='none').to(device)
    self.criterion_regr = nn.MSELoss().to(device)

    # Models and optimizers packed in dicts, as needed for checkpointing.
    self.models_dict = {
        'graph_cnn': self.graph_cnn,
        'smpl_param_regressor': self.smpl_param_regressor,
    }
    self.optimizers_dict = {'optimizer': self.optimizer}

    # Renderer for visualization, given the SMPL faces as a NumPy array.
    self.renderer = Renderer(faces=self.smpl.faces.cpu().numpy())

    # LSP indices from the full list of keypoints: the first 14 entries.
    self.to_lsp = list(range(14))

    # Optionally start from a pretrained checkpoint (not the same as
    # resuming training; use --resume for that).
    checkpoint = opts.pretrained_checkpoint
    if checkpoint is not None:
        self.load_pretrained(checkpoint_file=checkpoint)