

Python datasets.get_dataset Method Code Examples

This article collects typical usage examples of the Python datasets.get_dataset method from open-source projects. If you are wondering what datasets.get_dataset does, how to call it, or want to see it in real code, the curated examples below should help. You can also explore further usage examples from the datasets module.


Ten code examples of the datasets.get_dataset method are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help the system recommend better Python code examples.
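
One pattern worth noting before the examples: in several of them (Examples 3, 4, 8, and 9), get_dataset(name) returns a dataset class that the caller then instantiates with **config, while in others it returns data or a dataset object directly. Below is a minimal sketch of the class-returning variant, assuming a simple name-to-class registry; the names _REGISTRY, register, and ToyDataset are illustrative assumptions, not taken from any of the quoted projects.

_REGISTRY = {}

def register(name):
    """Class decorator that files a dataset class under a string name."""
    def wrapper(cls):
        _REGISTRY[name] = cls
        return cls
    return wrapper

def get_dataset(name):
    # Return the dataset *class*; callers instantiate it themselves, e.g.
    # dataset = get_dataset(config['data_name'])(**config)
    return _REGISTRY[name]

@register('toy')
class ToyDataset:
    def __init__(self, **config):
        self.config = config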

Example 1: test_CRUD_dataset

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def test_CRUD_dataset(capsys, crud_dataset_id):
    datasets.create_dataset(
        project_id,
        cloud_region,
        crud_dataset_id)

    datasets.get_dataset(
        project_id, cloud_region, crud_dataset_id)

    datasets.list_datasets(
        project_id, cloud_region)

    datasets.delete_dataset(
        project_id, cloud_region, crud_dataset_id)

    out, _ = capsys.readouterr()

    # Check that create/get/list/delete worked
    assert 'Created dataset' in out
    assert 'Time zone' in out
    assert 'Dataset' in out
    assert 'Deleted dataset' in out 
Author: GoogleCloudPlatform, Project: python-docs-samples, Lines of code: 24, Source file: datasets_test.py
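
The crud_dataset_id argument above is presumably a pytest fixture supplied by the test suite's conftest.py. A minimal sketch of what such a fixture might look like; the implementation here is an assumption, not the actual GoogleCloudPlatform code.

import uuid

import pytest

@pytest.fixture(scope='module')
def crud_dataset_id():
    # Randomized ID so concurrent test runs do not collide (assumption).
    yield 'test-dataset-{}'.format(uuid.uuid4().hex[:12])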

Example 2: test_CRUD_dataset

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def test_CRUD_dataset(capsys):
    datasets.create_dataset(
        service_account_json,
        project_id,
        cloud_region,
        dataset_id)

    datasets.get_dataset(
        service_account_json, project_id, cloud_region, dataset_id)

    datasets.list_datasets(
        service_account_json, project_id, cloud_region)

    # Test and also clean up
    datasets.delete_dataset(
        service_account_json, project_id, cloud_region, dataset_id)

    out, _ = capsys.readouterr()

    # Check that create/get/list/delete worked
    assert 'Created dataset' in out
    assert 'Time zone' in out
    assert 'Dataset' in out
    assert 'Deleted dataset' in out 
Author: GoogleCloudPlatform, Project: python-docs-samples, Lines of code: 26, Source file: datasets_test.py

Example 3: extract_reg_feat

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def extract_reg_feat(config):
    """Extract regional features."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'reg'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('reg_model')(config['pretrained']['reg_model'], **(config['reg_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            reg_f = h5py.File(dump_path, 'a')
            if 'reg_feat' not in reg_f or config['reg_feat']['overwrite']:
                reg_feat = model.run_test_data(data['image'])
                if 'reg_feat' in reg_f:
                    del reg_f['reg_feat']
                _ = reg_f.create_dataset('reg_feat', data=reg_feat)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break
    model.close() 
Author: luigifreda, Project: pyslam, Lines of code: 27, Source file: evaluations.py

Example 4: format_data

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def format_data(config):
    """Post-processing and generate custom files."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'post_format'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    idx = 0
    while True:
        try:
            data = next(test_set)
            dataset.format_data(data)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break 
Author: luigifreda, Project: pyslam, Lines of code: 19, Source file: evaluations.py
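
Examples 3 and 4 (and Examples 8 and 9 below) share the same iteration contract: get_test_set() returns a generator-like object, and the dataset class exposes an end_set exception that signals exhaustion. Here is a self-contained sketch of that contract, with StopIteration standing in for whatever exception class pyslam actually uses (an assumption on our part):

class ToySet:
    # The exception raised when next() runs out of data (assumed to be
    # StopIteration here; the real pyslam dataset class may differ).
    end_set = StopIteration

    def __init__(self, **config):
        self.data_length = 3

    def get_test_set(self):
        return ({'image': i} for i in range(self.data_length))

dataset = ToySet()
test_set = dataset.get_test_set()
while True:
    try:
        data = next(test_set)
        print(data['image'])
    except dataset.end_set:
        break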

Example 5: evaluate_network

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def evaluate_network(network, dataset):
    """Spawn a training session.

    Args:
        network (dict): The JSON definition of the network
        dataset (string): The name of the dataset to use
    """
    # Get the dataset.
    _, batch_size, _, x_train, x_test, y_train, y_test = get_dataset(dataset)

    model = model_from_json(network)
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])

    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=10000,  # essentially infinite, uses early stopping
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=[early_stopper])

    score = model.evaluate(x_test, y_test, verbose=0)

    metrics = {'loss': score[0], 'accuracy': score[1]}

    return metrics
Author: harvitronix, Project: super-simple-distributed-keras, Lines of code: 28, Source file: worker.py
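
The seven-element tuple unpacked at the top of evaluate_network suggests a helper that bundles the data splits with training metadata. The following is a hypothetical get_dataset for MNIST that would satisfy this signature; the real super-simple-distributed-keras helper may differ, particularly in its first and third return values.

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

def get_dataset(name):
    """Return (nb_classes, batch_size, input_shape, x_train, x_test, y_train, y_test)."""
    assert name == 'mnist'  # this sketch only knows one dataset
    nb_classes, batch_size, input_shape = 10, 128, (784,)
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Flatten to vectors and scale pixel values into [0, 1].
    x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
    x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
    # One-hot encode the labels for categorical_crossentropy.
    y_train = to_categorical(y_train, nb_classes)
    y_test = to_categorical(y_test, nb_classes)
    return nb_classes, batch_size, input_shape, x_train, x_test, y_train, y_test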

Example 6: main

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def main(argv=None):  # pylint: disable=unused-argument
    assert args.detect or args.segment, "Either detect or segment should be True"
    if args.trunk == 'resnet50':
        net = ResNet
        depth = 50
    if args.trunk == 'vgg16':
        net = VGG
        depth = 16

    net = net(config=net_config, depth=depth, training=True, weight_decay=args.weight_decay)

    if args.dataset == 'voc07':
        dataset = get_dataset('voc07_trainval')
    if args.dataset == 'voc12-trainval':
        dataset = get_dataset('voc12-train-segmentation', 'voc12-val')
    if args.dataset == 'voc12-train':
        dataset = get_dataset('voc12-train-segmentation')
    if args.dataset == 'voc12-val':
        dataset = get_dataset('voc12-val-segmentation')
    if args.dataset == 'voc07+12':
        dataset = get_dataset('voc07_trainval', 'voc12_train', 'voc12_val')
    if args.dataset == 'voc07+12-segfull':
        dataset = get_dataset('voc07-trainval-segmentation', 'voc12-train-segmentation', 'voc12-val')
    if args.dataset == 'voc07+12-segmentation':
        dataset = get_dataset('voc07-trainval-segmentation', 'voc12-train-segmentation')
    if args.dataset == 'coco':
        # support by default for coco trainval35k split
        dataset = get_dataset('coco-train2014-*', 'coco-valminusminival2014-*')
    if args.dataset == 'coco-seg':
        # support by default for coco trainval35k split
        dataset = get_dataset('coco-seg-train2014-*', 'coco-seg-valminusminival2014-*')

    train(dataset, net, net_config) 
Author: dvornikita, Project: blitznet, Lines of code: 35, Source file: training.py
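
Note how this example passes several split names to a single call, e.g. get_dataset('voc07_trainval', 'voc12_train', 'voc12_val'), implying a variadic helper that concatenates the requested splits. A toy sketch of that shape follows; the _SPLITS registry and its contents are stand-in assumptions, not blitznet's actual loader.

from itertools import chain

_SPLITS = {
    'voc07_trainval': ['img_a', 'img_b'],
    'voc12_train': ['img_c'],
    'voc12_val': ['img_d'],
}

def get_dataset(*split_names):
    # Concatenate every requested split into one flat dataset.
    return list(chain.from_iterable(_SPLITS[name] for name in split_names))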

Example 7: main

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def main():

    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)

    _, test_loader, _ = build_dataset(dataset=dataset,
                                      batch_size=args.batch_size,
                                      input_dir=args.input_dir)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer()
   
    model = checkpointer.restore_model_from_checkpoint(args.checkpoint_path)
    model = model.to(torch_device)
    model, _ = mixed_precision.initialize(model, None)

    test_stats = AverageMeterSet()
    test(model, test_loader, torch_device, test_stats)
    stat_str = test_stats.pretty_string(ignore=model.tasks)
    print(stat_str) 
Author: Philip-Bachman, Project: amdim-public, Lines of code: 29, Source file: test.py

Example 8: extract_loc_feat

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def extract_loc_feat(config):
    """Extract local features."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'loc'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('loc_model')(config['pretrained']['loc_model'], **(config['loc_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            loc_f = h5py.File(dump_path, 'a')
            if ('loc_info' not in loc_f and 'kpt' not in loc_f) or config['loc_feat']['overwrite']:
                # detect SIFT keypoints and crop image patches.
                loc_feat, kpt_mb, npy_kpts, cv_kpts, _ = model.run_test_data(data['image'])
                loc_info = np.concatenate((npy_kpts, loc_feat, kpt_mb), axis=-1)
                raw_kpts = [np.array((i.pt[0], i.pt[1], i.size, i.angle, i.response))
                            for i in cv_kpts]
                raw_kpts = np.stack(raw_kpts, axis=0)
                loc_info = np.concatenate((raw_kpts, loc_info), axis=-1)
                if 'loc_info' in loc_f or 'kpt' in loc_f:
                    del loc_f['loc_info']
                _ = loc_f.create_dataset('loc_info', data=loc_info)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break
    model.close() 
Author: luigifreda, Project: pyslam, Lines of code: 33, Source file: evaluations.py
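
Examples 8 and 9 both rely on the same h5py overwrite idiom: files are opened in 'a' (append) mode, and because h5py cannot replace a dataset in place, an existing key is deleted with del before create_dataset is called again. A standalone demonstration of that idiom; the file name and array contents here are placeholders.

import h5py
import numpy as np

feat = np.random.rand(10, 128).astype('float32')
with h5py.File('features.h5', 'a') as f:
    if 'loc_info' in f:
        del f['loc_info']  # h5py requires deleting before re-creating a key
    f.create_dataset('loc_info', data=feat)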

Example 9: extract_aug_feat

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def extract_aug_feat(config):
    """Extract augmented features."""
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'aug'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('aug_model')(config['pretrained']['loc_model'], **(config['aug_feat']))
    idx = 0
    while True:
        try:
            data = next(test_set)
            dump_path = data['dump_path'].decode('utf-8')
            aug_f = h5py.File(dump_path, 'a')
            if 'aug_feat' not in aug_f or config['aug_feat']['overwrite']:
                aug_feat, _ = model.run_test_data(data['dump_data'])
                if 'aug_feat' in aug_f:
                    del aug_f['aug_feat']
                if aug_feat.dtype == np.uint8:
                    _ = aug_f.create_dataset('aug_feat', data=aug_feat, dtype='uint8')
                else:
                    _ = aug_f.create_dataset('aug_feat', data=aug_feat)
            prog_bar.update(idx)
            idx += 1
        except dataset.end_set:
            break
    model.close() 
Author: luigifreda, Project: pyslam, Lines of code: 30, Source file: evaluations.py

Example 10: main

# Required import: import datasets [as alias]
# Or alternatively: from datasets import get_dataset [as alias]
def main():
    # create target output dir if it doesn't exist yet
    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)

    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # set the RNG seeds (probably more hidden elsewhere...)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args.batch_size,
                      input_dir=args.input_dir,
                      labeled_only=args.classifiers)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args.output_dir)
    if args.cpt_load_path:
        model = checkpointer.restore_model_from_checkpoint(
                    args.cpt_load_path, 
                    training_classifier=args.classifiers)
    else:
        # create new model with random parameters
        model = Model(ndf=args.ndf, n_classes=num_classes, n_rkhs=args.n_rkhs,
                    tclip=args.tclip, n_depth=args.n_depth, encoder_size=encoder_size,
                    use_bn=(args.use_bn == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)


    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised
    task(model, args.learning_rate, dataset, train_loader,
         test_loader, stat_tracker, checkpointer, args.output_dir, torch_device) 
Author: Philip-Bachman, Project: amdim-public, Lines of code: 51, Source file: train.py


Note: The datasets.get_dataset examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are selected from open-source projects contributed by their respective developers, and copyright remains with the original authors. Consult the corresponding project's license before distributing or using the code; do not reproduce without permission.