當前位置: 首頁>>代碼示例>>Python>>正文


Python datasets.get_dataset方法代碼示例

本文整理匯總了Python中datasets.get_dataset方法的典型用法代碼示例。如果您正苦於以下問題：Python datasets.get_dataset方法的具體用法？Python datasets.get_dataset怎麼用？Python datasets.get_dataset使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊datasets的用法示例。


在下文中一共展示了datasets.get_dataset方法的10個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: test_CRUD_dataset

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def test_CRUD_dataset(capsys, crud_dataset_id):
    """Exercise the create -> get -> list -> delete lifecycle of one dataset.

    Success is judged from what the ``datasets`` helpers wrote to stdout,
    which is captured by pytest's ``capsys`` fixture.
    """
    datasets.create_dataset(project_id, cloud_region, crud_dataset_id)
    datasets.get_dataset(project_id, cloud_region, crud_dataset_id)
    datasets.list_datasets(project_id, cloud_region)
    datasets.delete_dataset(project_id, cloud_region, crud_dataset_id)

    captured_stdout = capsys.readouterr()[0]

    # Every step of the lifecycle must have left its marker in the output.
    for marker in ('Created dataset', 'Time zone', 'Dataset', 'Deleted dataset'):
        assert marker in captured_stdout
開發者ID:GoogleCloudPlatform,項目名稱:python-docs-samples,代碼行數:24,代碼來源:datasets_test.py

示例2: test_CRUD_dataset

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def test_CRUD_dataset(capsys):
    """Run create/get/list/delete for ``dataset_id`` using service-account credentials.

    The final delete doubles as clean-up of the dataset created here.
    Success is judged from the stdout captured by pytest's ``capsys`` fixture.
    """
    datasets.create_dataset(service_account_json, project_id, cloud_region, dataset_id)
    datasets.get_dataset(service_account_json, project_id, cloud_region, dataset_id)
    datasets.list_datasets(service_account_json, project_id, cloud_region)
    # Deleting both exercises the API and removes the test dataset.
    datasets.delete_dataset(service_account_json, project_id, cloud_region, dataset_id)

    captured_stdout = capsys.readouterr()[0]

    # Every step of the lifecycle must have left its marker in the output.
    for marker in ('Created dataset', 'Time zone', 'Dataset', 'Deleted dataset'):
        assert marker in captured_stdout
開發者ID:GoogleCloudPlatform,項目名稱:python-docs-samples,代碼行數:26,代碼來源:datasets_test.py

示例3: extract_reg_feat

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def extract_reg_feat(config):
    """Extract regional features and store them in each item's HDF5 dump file.

    For every item of the dataset's test set, the 'reg_model' is run on
    ``data['image']`` and the result is written to the ``reg_feat`` dataset of
    the HDF5 file at ``data['dump_path']``. Items whose file already holds
    ``reg_feat`` are skipped unless ``config['reg_feat']['overwrite']`` is set.

    Args:
        config (dict): pipeline configuration. Keys read: ``data_name``,
            ``pretrained['reg_model']`` and ``reg_feat`` (passed as model
            kwargs; its ``overwrite`` flag is also consulted). Mutated in
            place: ``config['stage']`` is set to ``'reg'``.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'reg'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('reg_model')(config['pretrained']['reg_model'], **(config['reg_feat']))
    try:
        idx = 0
        while True:
            # Only fetching the next item is expected to signal exhaustion;
            # keeping the try this narrow means an end_set raised by the model
            # or the HDF5 writes is not silently mistaken for "no more data".
            try:
                data = next(test_set)
            except dataset.end_set:
                break
            dump_path = data['dump_path'].decode('utf-8')
            # The context manager flushes and closes the file before the next item.
            with h5py.File(dump_path, 'a') as reg_f:
                if 'reg_feat' not in reg_f or config['reg_feat']['overwrite']:
                    reg_feat = model.run_test_data(data['image'])
                    if 'reg_feat' in reg_f:
                        del reg_f['reg_feat']
                    reg_f.create_dataset('reg_feat', data=reg_feat)
            prog_bar.update(idx)
            idx += 1
    finally:
        # Release the model even if extraction fails part-way.
        model.close()
開發者ID:luigifreda,項目名稱:pyslam,代碼行數:27,代碼來源:evaluations.py

示例4: format_data

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def format_data(config):
    """Run the 'post_format' stage: let the dataset post-process every test item.

    Args:
        config (dict): pipeline configuration; ``data_name`` selects the
            dataset class, and the whole dict is forwarded to its constructor.
            Mutated in place: ``config['stage']`` is set to ``'post_format'``.
    """
    progress = progressbar.ProgressBar()
    config['stage'] = 'post_format'
    dataset = get_dataset(config['data_name'])(**config)
    progress.max_value = dataset.data_length
    samples = dataset.get_test_set()

    count = 0
    while True:
        try:
            item = next(samples)
            dataset.format_data(item)
            progress.update(count)
        except dataset.end_set:
            # The dataset signals exhaustion with its own exception type.
            break
        count += 1
開發者ID:luigifreda,項目名稱:pyslam,代碼行數:19,代碼來源:evaluations.py

示例5: evaluate_network

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def evaluate_network(network, dataset):
    """Train a Keras model built from ``network`` on ``dataset`` and score it.

    Args:
        network: JSON definition of the network, as accepted by
            ``model_from_json``.
        dataset (string): name of the dataset passed to ``get_dataset``.

    Returns:
        dict: ``{'loss': ..., 'accuracy': ...}`` evaluated on the test split.
    """
    # get_dataset yields a 7-tuple; only the batch size and the
    # train/test splits are needed here.
    (_, batch_size, _,
     x_train, x_test, y_train, y_test) = get_dataset(dataset)

    model = model_from_json(network)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # The huge epoch count means "train until early_stopper halts it".
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=10000,
        verbose=1,
        validation_data=(x_test, y_test),
        callbacks=[early_stopper],
    )

    score = model.evaluate(x_test, y_test, verbose=0)
    loss, accuracy = score[0], score[1]
    return {'loss': loss, 'accuracy': accuracy}
開發者ID:harvitronix,項目名稱:super-simple-distributed-keras,代碼行數:28,代碼來源:worker.py

示例6: main

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def main(argv=None):  # pylint: disable=unused-argument
    """Build the selected backbone and dataset, then launch training.

    Settings are read from the module-level ``args`` (``detect``,
    ``segment``, ``trunk``, ``weight_decay``, ``dataset``) and ``net_config``.

    Args:
        argv: unused.

    Raises:
        ValueError: if neither ``args.detect`` nor ``args.segment`` is set,
            or if ``args.trunk`` / ``args.dataset`` is not a known option.
    """
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    if not (args.detect or args.segment):
        raise ValueError("Either detect or segment should be True")

    # Backbone option -> (network class, depth).
    trunks = {
        'resnet50': (ResNet, 50),
        'vgg16': (VGG, 16),
    }
    # Dataset option -> split names handed to get_dataset.
    dataset_splits = {
        'voc07': ('voc07_trainval',),
        'voc12-trainval': ('voc12-train-segmentation', 'voc12-val'),
        'voc12-train': ('voc12-train-segmentation',),
        'voc12-val': ('voc12-val-segmentation',),
        'voc07+12': ('voc07_trainval', 'voc12_train', 'voc12_val'),
        'voc07+12-segfull': ('voc07-trainval-segmentation', 'voc12-train-segmentation',
                             'voc12-val'),
        'voc07+12-segmentation': ('voc07-trainval-segmentation', 'voc12-train-segmentation'),
        # support by default for coco trainval35k split
        'coco': ('coco-train2014-*', 'coco-valminusminival2014-*'),
        # support by default for coco trainval35k split
        'coco-seg': ('coco-seg-train2014-*', 'coco-seg-valminusminival2014-*'),
    }

    # Fail fast with a clear message rather than an UnboundLocalError later on.
    if args.trunk not in trunks:
        raise ValueError("Unknown trunk {!r}; expected one of {}".format(
            args.trunk, sorted(trunks)))
    if args.dataset not in dataset_splits:
        raise ValueError("Unknown dataset {!r}; expected one of {}".format(
            args.dataset, sorted(dataset_splits)))

    net_cls, depth = trunks[args.trunk]
    net = net_cls(config=net_config, depth=depth, training=True, weight_decay=args.weight_decay)

    dataset = get_dataset(*dataset_splits[args.dataset])

    train(dataset, net, net_config)
開發者ID:dvornikita,項目名稱:blitznet,代碼行數:35,代碼來源:training.py

示例7: main

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def main():
    """Evaluate a restored checkpoint on the test split and print the statistics.

    Settings are read from the module-level ``args`` (``amp``, ``seed``,
    ``dataset``, ``batch_size``, ``input_dir``, ``checkpoint_path``).
    """
    # Optionally switch on mixed-precision computation.
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # Seed the CPU and the CUDA RNGs.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dataset = get_dataset(args.dataset)
    _train_loader, test_loader, _num_classes = build_dataset(
        dataset=dataset,
        batch_size=args.batch_size,
        input_dir=args.input_dir,
    )

    device = torch.device('cuda')
    restored = Checkpointer().restore_model_from_checkpoint(args.checkpoint_path)
    model, _ = mixed_precision.initialize(restored.to(device), None)

    meters = AverageMeterSet()
    test(model, test_loader, device, meters)
    print(meters.pretty_string(ignore=model.tasks))
開發者ID:Philip-Bachman,項目名稱:amdim-public,代碼行數:29,代碼來源:test.py

示例8: extract_loc_feat

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def extract_loc_feat(config):
    """Extract local features and store them in each item's HDF5 dump file.

    For every item of the dataset's test set, the 'loc_model' is run on
    ``data['image']``. Per-keypoint attributes and the model outputs
    (``npy_kpts``, ``loc_feat``, ``kpt_mb``) are concatenated column-wise into
    one ``loc_info`` array. That array is written to the HDF5 file at
    ``data['dump_path']``. Items whose file already holds ``loc_info`` or
    ``kpt`` are skipped unless ``config['loc_feat']['overwrite']`` is set.

    Args:
        config (dict): pipeline configuration. Keys read: ``data_name``,
            ``pretrained['loc_model']`` and ``loc_feat`` (passed as model
            kwargs; its ``overwrite`` flag is also consulted). Mutated in
            place: ``config['stage']`` is set to ``'loc'``.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'loc'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    model = get_model('loc_model')(config['pretrained']['loc_model'], **(config['loc_feat']))
    try:
        idx = 0
        while True:
            # Only fetching the next item is expected to signal exhaustion;
            # an end_set raised elsewhere must not be mistaken for "no more data".
            try:
                data = next(test_set)
            except dataset.end_set:
                break
            dump_path = data['dump_path'].decode('utf-8')
            # The context manager flushes and closes the file before the next item.
            with h5py.File(dump_path, 'a') as loc_f:
                has_features = 'loc_info' in loc_f or 'kpt' in loc_f
                if not has_features or config['loc_feat']['overwrite']:
                    # detect SIFT keypoints and crop image patches.
                    loc_feat, kpt_mb, npy_kpts, cv_kpts, _ = model.run_test_data(data['image'])
                    loc_info = np.concatenate((npy_kpts, loc_feat, kpt_mb), axis=-1)
                    # One row per keypoint: (x, y, size, angle, response).
                    # NOTE(review): cv_kpts look like cv2.KeyPoint objects; np.stack
                    # raises ValueError on an empty list -- confirm the model
                    # always returns at least one keypoint.
                    raw_kpts = [np.array((kp.pt[0], kp.pt[1], kp.size, kp.angle, kp.response))
                                for kp in cv_kpts]
                    raw_kpts = np.stack(raw_kpts, axis=0)
                    loc_info = np.concatenate((raw_kpts, loc_info), axis=-1)
                    # 'kpt' can be present without 'loc_info', and deleting a
                    # missing member raises KeyError, so only delete what exists.
                    # ('kpt' itself is only consulted here, never rewritten.)
                    if 'loc_info' in loc_f:
                        del loc_f['loc_info']
                    loc_f.create_dataset('loc_info', data=loc_info)
            prog_bar.update(idx)
            idx += 1
    finally:
        # Release the model even if extraction fails part-way.
        model.close()
開發者ID:luigifreda,項目名稱:pyslam,代碼行數:33,代碼來源:evaluations.py

示例9: extract_aug_feat

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def extract_aug_feat(config):
    """Extract augmented features and store them in each item's HDF5 dump file.

    For every item of the dataset's test set, the 'aug_model' is run on
    ``data['dump_data']``. The first returned value is written to the
    ``aug_feat`` dataset of the HDF5 file at ``data['dump_path']``. Items
    whose file already holds ``aug_feat`` are skipped unless
    ``config['aug_feat']['overwrite']`` is set.

    Args:
        config (dict): pipeline configuration. Keys read: ``data_name``,
            ``pretrained['loc_model']`` and ``aug_feat`` (passed as model
            kwargs; its ``overwrite`` flag is also consulted). Mutated in
            place: ``config['stage']`` is set to ``'aug'``.
    """
    prog_bar = progressbar.ProgressBar()
    config['stage'] = 'aug'
    dataset = get_dataset(config['data_name'])(**config)
    prog_bar.max_value = dataset.data_length
    test_set = dataset.get_test_set()

    # NOTE(review): the aug model is built from the 'loc_model' weights entry;
    # presumably intentional (shared checkpoint) -- confirm.
    model = get_model('aug_model')(config['pretrained']['loc_model'], **(config['aug_feat']))
    try:
        idx = 0
        while True:
            # Only fetching the next item is expected to signal exhaustion;
            # an end_set raised elsewhere must not be mistaken for "no more data".
            try:
                data = next(test_set)
            except dataset.end_set:
                break
            dump_path = data['dump_path'].decode('utf-8')
            # The context manager flushes and closes the file before the next item.
            with h5py.File(dump_path, 'a') as aug_f:
                if 'aug_feat' not in aug_f or config['aug_feat']['overwrite']:
                    aug_feat, _ = model.run_test_data(data['dump_data'])
                    if 'aug_feat' in aug_f:
                        del aug_f['aug_feat']
                    if aug_feat.dtype == np.uint8:
                        aug_f.create_dataset('aug_feat', data=aug_feat, dtype='uint8')
                    else:
                        aug_f.create_dataset('aug_feat', data=aug_feat)
            prog_bar.update(idx)
            idx += 1
    finally:
        # Release the model even if extraction fails part-way.
        model.close()
開發者ID:luigifreda,項目名稱:pyslam,代碼行數:30,代碼來源:evaluations.py

示例10: main

# 需要導入模塊: import datasets [as 別名]
# 或者: from datasets import get_dataset [as 別名]
def main():
    """Train an AMDIM model, either self-supervised or its classifier heads.

    Settings are read from the module-level ``args``. The function:

    - creates ``args.output_dir``;
    - seeds the torch RNGs;
    - builds the train/test loaders;
    - restores the model from ``args.cpt_load_path`` or creates a fresh one;
    - dispatches to ``train_classifiers`` (when ``args.classifiers``) or
      ``train_self_supervised``.
    """
    # Create the target output dir, including missing parents; exist_ok makes
    # this a no-op if the directory already exists (or appears concurrently).
    os.makedirs(args.output_dir, exist_ok=True)

    # enable mixed-precision computation if desired
    if args.amp:
        mixed_precision.enable_mixed_precision()

    # Seed the torch CPU and CUDA RNGs (other RNGs may need seeding elsewhere).
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # get the dataset
    dataset = get_dataset(args.dataset)
    encoder_size = get_encoder_size(dataset)

    # get a helper object for tensorboard logging
    log_dir = os.path.join(args.output_dir, args.run_name)
    stat_tracker = StatTracker(log_dir=log_dir)

    # get dataloaders for training and testing
    train_loader, test_loader, num_classes = \
        build_dataset(dataset=dataset,
                      batch_size=args.batch_size,
                      input_dir=args.input_dir,
                      labeled_only=args.classifiers)

    torch_device = torch.device('cuda')
    checkpointer = Checkpointer(args.output_dir)
    if args.cpt_load_path:
        model = checkpointer.restore_model_from_checkpoint(
            args.cpt_load_path,
            training_classifier=args.classifiers)
    else:
        # create new model with random parameters
        model = Model(ndf=args.ndf, n_classes=num_classes, n_rkhs=args.n_rkhs,
                      tclip=args.tclip, n_depth=args.n_depth, encoder_size=encoder_size,
                      use_bn=(args.use_bn == 1))
        model.init_weights(init_scale=1.0)
        checkpointer.track_new_model(model)

    model = model.to(torch_device)

    # select which type of training to do
    task = train_classifiers if args.classifiers else train_self_supervised
    task(model, args.learning_rate, dataset, train_loader,
         test_loader, stat_tracker, checkpointer, args.output_dir, torch_device)
開發者ID:Philip-Bachman,項目名稱:amdim-public,代碼行數:51,代碼來源:train.py


注：本文中的datasets.get_dataset方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台，相關代碼片段篩選自各路編程大神貢獻的開源項目，源碼版權歸原作者所有，傳播和使用請參考對應項目的License；未經允許，請勿轉載。