本文整理汇总了 Python 中 datasets.get_dataset 方法的典型用法代码示例。如果您正苦于以下问题:Python datasets.get_dataset 方法的具体用法?Python datasets.get_dataset 怎么用?Python datasets.get_dataset 使用的例子?那么恭喜您,这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 datasets 的用法示例。
在下文中一共展示了datasets.get_dataset方法的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_CRUD_dataset
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def test_CRUD_dataset(capsys, crud_dataset_id):
    """Exercise create/get/list/delete on one dataset and check the printed output.

    The four ``datasets`` calls run in order against ``project_id`` /
    ``cloud_region`` for the dataset id supplied by the ``crud_dataset_id``
    fixture; everything they print is then read back through ``capsys``.
    """
    datasets.create_dataset(project_id, cloud_region, crud_dataset_id)
    datasets.get_dataset(project_id, cloud_region, crud_dataset_id)
    datasets.list_datasets(project_id, cloud_region)
    datasets.delete_dataset(project_id, cloud_region, crud_dataset_id)

    stdout, _stderr = capsys.readouterr()

    # Each marker must show up in the captured output for the
    # create/get/list/delete sequence to count as successful.
    expected_markers = (
        'Created dataset',
        'Time zone',
        'Dataset',
        'Deleted dataset',
    )
    for marker in expected_markers:
        assert marker in stdout
示例2: test_CRUD_dataset
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def test_CRUD_dataset(capsys):
    """Exercise create/get/list/delete on one dataset and check the printed output.

    The dataset is created first; the get and list calls then run inside a
    ``try`` block so that the delete call, which doubles as this test's
    clean-up, still runs when one of them raises.
    """
    datasets.create_dataset(
        service_account_json,
        project_id,
        cloud_region,
        dataset_id)

    try:
        datasets.get_dataset(
            service_account_json, project_id, cloud_region, dataset_id)

        datasets.list_datasets(
            service_account_json, project_id, cloud_region)
    finally:
        # Test and also clean up. Running this in ``finally`` keeps a failed
        # get/list call from leaving the created dataset behind.
        datasets.delete_dataset(
            service_account_json, project_id, cloud_region, dataset_id)

    out, _ = capsys.readouterr()

    # Check that create/get/list/delete worked
    assert 'Created dataset' in out
    assert 'Time zone' in out
    assert 'Dataset' in out
    assert 'Deleted dataset' in out
示例3: extract_reg_feat
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def extract_reg_feat(config):
    """Extract regional features.

    For every sample of the test set, run the regional model on
    ``data['image']`` and store the result as the ``'reg_feat'`` dataset of
    the HDF5 file at ``data['dump_path']``. A file that already holds
    ``'reg_feat'`` is skipped unless ``config['reg_feat']['overwrite']`` is
    truthy.

    Args:
        config (dict): Run configuration. Reads ``'data_name'``,
            ``'pretrained'['reg_model']`` and ``'reg_feat'``; the whole dict
            is also forwarded as keyword arguments to the dataset class.
            ``config['stage']`` is set to ``'reg'`` in place.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'reg'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('reg_model')(config['pretrained']['reg_model'], **(config['reg_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            # Open the dump file in a ``with`` block so the handle is closed
            # after each sample instead of piling up over the whole test set.
            with h5py.File(dump_path, 'a') as reg_f:
                if 'reg_feat' not in reg_f or config['reg_feat']['overwrite']:
                    reg_feat = model.run_test_data(data['image'])
                    # Drop the stale dataset first; creating one under an
                    # existing name would fail.
                    if 'reg_feat' in reg_f:
                        del reg_f['reg_feat']
                    _ = reg_f.create_dataset('reg_feat', data=reg_feat)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            # The dataset signals exhaustion of the test set with this exception.
            break
    model.close()
示例4: format_data
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def format_data(config):
    """Post-process every test sample through the dataset's own formatter.

    ``config['stage']`` is set to ``'post_format'`` before the dataset class
    looked up via ``config['data_name']`` is instantiated with the whole
    config. Each sample drawn from the test set is handed to
    ``dataset.format_data`` until ``dataset.end_set`` is raised.
    """
    progress = progressbar.ProgressBar()
    config['stage'] = 'post_format'
    dataset = get_dataset(config['data_name'])(**config)
    progress.max_value = dataset.data_length
    samples = dataset.get_test_set()

    processed = 0
    exhausted = False
    while not exhausted:
        try:
            sample = next(samples)
            dataset.format_data(sample)
            progress.update(processed)
        except dataset.end_set:
            exhausted = True
        else:
            processed += 1
示例5: evaluate_network
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def evaluate_network(network, dataset):
    """Build a model from its JSON definition, fit it, and score it.

    Args:
        network (dict): The JSON definition of the network
        dataset (string): The name of the dataset to use

    Returns:
        dict: ``'loss'`` and ``'accuracy'``, taken from the first two entries
        of what ``model.evaluate`` returns for the test data.
    """
    # Only the batch size and the four data arrays are needed from the loader.
    _, batch_size, _, x_train, x_test, y_train, y_test = get_dataset(dataset)

    model = model_from_json(network)
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    fit_options = {
        'batch_size': batch_size,
        'epochs': 10000,  # essentially infinite, uses early stopping
        'verbose': 1,
        'validation_data': (x_test, y_test),
        'callbacks': [early_stopper],
    }
    model.fit(x_train, y_train, **fit_options)

    test_scores = model.evaluate(x_test, y_test, verbose=0)
    return {'loss': test_scores[0], 'accuracy': test_scores[1]}
示例6: main
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def main(argv=None):  # pylint: disable=unused-argument
    """Build the network and dataset selected by ``args`` and start training.

    ``args.trunk`` picks the backbone class and depth, ``args.dataset`` picks
    the split names handed to ``get_dataset``. The resulting dataset and
    network are passed to ``train`` together with ``net_config``.

    Args:
        argv: Unused; kept for the entry-point calling convention.

    Raises:
        ValueError: If ``args.trunk`` or ``args.dataset`` is not one of the
            supported names.
    """
    assert args.detect or args.segment, "Either detect or segment should be True"

    if args.trunk == 'resnet50':
        net = ResNet
        depth = 50
    elif args.trunk == 'vgg16':
        net = VGG
        depth = 16
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``net`` on the next line.
        raise ValueError("Unsupported trunk: %r" % (args.trunk,))

    net = net(config=net_config, depth=depth, training=True, weight_decay=args.weight_decay)

    # Dataset name -> split names forwarded to ``get_dataset``.
    dataset_splits = {
        'voc07': ('voc07_trainval',),
        'voc12-trainval': ('voc12-train-segmentation', 'voc12-val'),
        'voc12-train': ('voc12-train-segmentation',),
        'voc12-val': ('voc12-val-segmentation',),
        'voc07+12': ('voc07_trainval', 'voc12_train', 'voc12_val'),
        'voc07+12-segfull': ('voc07-trainval-segmentation', 'voc12-train-segmentation', 'voc12-val'),
        'voc07+12-segmentation': ('voc07-trainval-segmentation', 'voc12-train-segmentation'),
        # support by default for coco trainval35k split
        'coco': ('coco-train2014-*', 'coco-valminusminival2014-*'),
        # support by default for coco trainval35k split
        'coco-seg': ('coco-seg-train2014-*', 'coco-seg-valminusminival2014-*'),
    }
    splits = dataset_splits.get(args.dataset)
    if splits is None:
        # An unknown name used to leave ``dataset`` unbound until ``train``.
        raise ValueError(
            "Unsupported dataset: %r (expected one of: %s)"
            % (args.dataset, ', '.join(sorted(dataset_splits))))
    dataset = get_dataset(*splits)

    train(dataset, net, net_config)
示例7: main
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def main():
    """Restore a model from a checkpoint, run ``test`` on it, and print the stats.

    All settings come from the module-level ``args``: mixed precision,
    RNG seed, dataset name, batch size, input directory and checkpoint path.
    """
    # Mixed precision has to be switched on before anything else is set up.
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # Seed both the CPU and the CUDA generators.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Only the middle element of what build_dataset returns is needed here.
    selected = get_dataset(args.dataset)
    _, test_loader, _ = build_dataset(
        dataset=selected,
        batch_size=args.batch_size,
        input_dir=args.input_dir,
    )

    device = torch.device('cuda')

    restored = Checkpointer().restore_model_from_checkpoint(args.checkpoint_path)
    model, _ = mixed_precision.initialize(restored.to(device), None)

    stats = AverageMeterSet()
    test(model, test_loader, device, stats)
    print(stats.pretty_string(ignore=model.tasks))
示例8: extract_loc_feat
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def extract_loc_feat(config):
    """Extract local features.

    For every sample of the test set, run the local model on
    ``data['image']`` and store the combined keypoint/feature array as the
    ``'loc_info'`` dataset of the HDF5 file at ``data['dump_path']``. A file
    that already holds ``'loc_info'`` or ``'kpt'`` is skipped unless
    ``config['loc_feat']['overwrite']`` is truthy.

    Args:
        config (dict): Run configuration. Reads ``'data_name'``,
            ``'pretrained'['loc_model']`` and ``'loc_feat'``; the whole dict
            is also forwarded as keyword arguments to the dataset class.
            ``config['stage']`` is set to ``'loc'`` in place.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'loc'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('loc_model')(config['pretrained']['loc_model'], **(config['loc_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            # Open the dump file in a ``with`` block so the handle is closed
            # after each sample instead of piling up over the whole test set.
            with h5py.File(dump_path, 'a') as loc_f:
                nothing_stored = 'loc_info' not in loc_f and 'kpt' not in loc_f
                if nothing_stored or config['loc_feat']['overwrite']:
                    # detect SIFT keypoints and crop image patches.
                    loc_feat, kpt_mb, npy_kpts, cv_kpts, _ = model.run_test_data(data['image'])
                    loc_info = np.concatenate((npy_kpts, loc_feat, kpt_mb), axis=-1)
                    # One row per keypoint: x, y, size, angle, response.
                    raw_kpts = [np.array((kpt.pt[0], kpt.pt[1], kpt.size, kpt.angle, kpt.response))
                                for kpt in cv_kpts]
                    raw_kpts = np.stack(raw_kpts, axis=0)
                    loc_info = np.concatenate((raw_kpts, loc_info), axis=-1)
                    # Delete only what is really there: a file that holds
                    # 'kpt' but no 'loc_info' used to raise KeyError here.
                    # NOTE(review): an existing 'kpt' dataset is left
                    # untouched, as before - confirm whether it should be
                    # removed on overwrite.
                    if 'loc_info' in loc_f:
                        del loc_f['loc_info']
                    _ = loc_f.create_dataset('loc_info', data=loc_info)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            # The dataset signals exhaustion of the test set with this exception.
            break
    model.close()
示例9: extract_aug_feat
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def extract_aug_feat(config):
    """Extract augmented features.

    For every sample of the test set, run the augmentation model on
    ``data['dump_data']`` and store the result as the ``'aug_feat'`` dataset
    of the HDF5 file at ``data['dump_path']``. A file that already holds
    ``'aug_feat'`` is skipped unless ``config['aug_feat']['overwrite']`` is
    truthy.

    Args:
        config (dict): Run configuration. Reads ``'data_name'``,
            ``'pretrained'['loc_model']`` and ``'aug_feat'``; the whole dict
            is also forwarded as keyword arguments to the dataset class.
            ``config['stage']`` is set to ``'aug'`` in place.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'aug'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    # NOTE(review): the aug model is built from the 'loc_model' checkpoint
    # entry - presumably both models share one checkpoint; confirm.
    model = get_model('aug_model')(config['pretrained']['loc_model'], **(config['aug_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            # Open the dump file in a ``with`` block so the handle is closed
            # after each sample instead of piling up over the whole test set.
            with h5py.File(dump_path, 'a') as aug_f:
                if 'aug_feat' not in aug_f or config['aug_feat']['overwrite']:
                    aug_feat, _ = model.run_test_data(data['dump_data'])
                    # Drop the stale dataset first; creating one under an
                    # existing name would fail.
                    if 'aug_feat' in aug_f:
                        del aug_f['aug_feat']
                    if aug_feat.dtype == np.uint8:
                        _ = aug_f.create_dataset('aug_feat', data=aug_feat, dtype='uint8')
                    else:
                        _ = aug_f.create_dataset('aug_feat', data=aug_feat)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            # The dataset signals exhaustion of the test set with this exception.
            break
    model.close()
示例10: main
# 需要导入模块: import datasets [as 别名]
# 或者: from datasets import get_dataset [as 别名]
def main():
    """Set up data, model and logging from ``args`` and launch training.

    Creates the output directory, optionally enables mixed precision, seeds
    the torch RNGs, builds the data loaders, then either restores a model
    from ``args.cpt_load_path`` or creates a fresh one. The training routine
    is ``train_classifiers`` when ``args.classifiers`` is truthy and
    ``train_self_supervised`` otherwise.
    """
    # create target output dir if it doesn't exist yet; makedirs also creates
    # missing parent directories and has no gap between check and creation
    os.makedirs(args.output_dir, exist_ok=True)

    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args.batch_size,
                      input_dir=args.input_dir,
                      labeled_only=args.classifiers)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args.output_dir)
    if args.cpt_load_path:
        model = checkpointer.restore_model_from_checkpoint(
            args.cpt_load_path,
            training_classifier=args.classifiers)
    else:
        # create new model with random parameters
        model = Model(ndf=args.ndf, n_classes=num_classes, n_rkhs=args.n_rkhs,
                      tclip=args.tclip, n_depth=args.n_depth, encoder_size=encoder_size,
                      use_bn=(args.use_bn == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised
    task(model, args.learning_rate, dataset, train_loader,
         test_loader, stat_tracker, checkpointer, args.output_dir, torch_device)