當前位置: 首頁>>代碼示例>>Python>>正文


Python models.create方法代碼示例

本文整理匯總了Python中models.create方法的典型用法代碼示例。如果您正苦於以下問題:Python models.create方法的具體用法?Python models.create怎麼用?Python models.create使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在models的用法示例。


在下文中一共展示了models.create方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: setup_dataset

# 需要導入模塊: import models [as 別名]
# 或者: from models import create [as 別名]
def setup_dataset(mode, crop_dir, mask_dir=None, mean_mask_dir=None,
                  mean_grid_dir=None, trimap_dir=None, alpha_dir=None,
                  alpha_weight_dir=None):
    """Build the (train, test) datasets used for training.

    The raw dataset comes from ``datasets.create`` (all directory arguments
    are forwarded unchanged) and is split with ``datasets.split_dataset``.
    Training samples are first passed through
    ``transforms.transform_random`` and then through the mode-specific
    transform from ``transforms.create(mode)``. Test samples get only the
    mode-specific transform.

    Returns:
        A ``(train, test)`` pair of ``chainer.datasets.TransformDataset``.
    """
    full_dataset = datasets.create(mode, crop_dir, mask_dir, mean_mask_dir,
                                   mean_grid_dir, trimap_dir, alpha_dir,
                                   alpha_weight_dir)

    to_net_input = transforms.create(mode)
    augment = transforms.transform_random

    train_part, test_part = datasets.split_dataset(full_dataset)

    wrap = chainer.datasets.TransformDataset
    # Only the training split is augmented; the augmentation is applied
    # before the conversion to network inputs.
    augmented_train = wrap(train_part, augment)
    train = wrap(augmented_train, to_net_input)
    test = wrap(test_part, to_net_input)

    return train, test
開發者ID:takiyu,項目名稱:portrait_matting,代碼行數:25,代碼來源:train.py

示例2: setup_model

# 需要導入模塊: import models [as 別名]
# 或者: from models import create [as 別名]
def setup_model(mode, pretrained_path=None, pretrained_n_input_ch=2,
                pretrained_n_output_ch=21, mat_scale=4):
    """Create the model for ``mode``, optionally initialised from FCN8s.

    Args:
        mode: Model mode, passed to ``models.create``. Only ``'seg'``,
            ``'seg+'``, ``'seg_tri'`` and ``'mat'`` can be initialised from
            a pretrained FCN8s model.
        pretrained_path: Path to an ``.npz`` FCN8s snapshot, loaded with
            ``chainer.serializers.load_npz``. ``None`` skips initialisation.
        pretrained_n_input_ch: ``n_input_ch`` of the pretrained FCN8s.
        pretrained_n_output_ch: ``n_output_ch`` of the pretrained FCN8s.
        mat_scale: Forwarded to ``models.create``.

    Returns:
        The model created by ``models.create``. If ``pretrained_path`` is
        given for a mode that cannot use FCN8s weights, an error is logged
        and the model is returned without pretrained initialisation.
    """
    # Create empty model
    model = models.create(mode, mat_scale=mat_scale)

    if pretrained_path is None:
        return model

    if mode in ('seg', 'seg+', 'seg_tri', 'mat'):
        # Copy weights from a pretrained FCN8s model
        logger.info('Load pretrained FCN8s model (%s)', pretrained_path)
        pretrained = models.FCN8s(n_input_ch=pretrained_n_input_ch,
                                  n_output_ch=pretrained_n_output_ch)
        chainer.serializers.load_npz(pretrained_path, pretrained)
        model.init_from_fcn8s(pretrained)
    else:
        # Best-effort behaviour is kept (log, don't raise). The message names
        # the rejected mode and states that the weights were NOT applied, so
        # the silent fallback to an uninitialised model is visible in logs.
        logger.error('Unknown mode (%s): pretrained model (%s) is not loaded',
                     mode, pretrained_path)

    return model
開發者ID:takiyu,項目名稱:portrait_matting,代碼行數:22,代碼來源:train.py

示例3: Model2Feature

# 需要導入模塊: import models [as 別名]
# 或者: from models import create [as 別名]
def Model2Feature(data, net, checkpoint, dim=512, width=224, root=None, nThreads=16, batch_size=100, pool_feature=False, **kargs):
    """Extract gallery/query features with a model restored from a checkpoint.

    Args:
        data: Dataset name, passed to ``DataSet.create``. For ``'shop'`` and
            ``'jd_test'`` features are extracted separately from the
            dataset's ``gallery`` and ``query`` splits. Otherwise only
            ``gallery`` is used.
        net: Network name, passed to ``models.create``.
        checkpoint: Mapping whose ``'state_dict'`` entry is loaded into the
            model.
        dim: Forwarded to ``models.create``.
        width: Forwarded to ``DataSet.create``.
        root: Forwarded to ``DataSet.create``.
        nThreads: ``num_workers`` for each ``DataLoader``.
        batch_size: Batch size for each ``DataLoader``.
        pool_feature: Forwarded to ``extract_features``.
        **kargs: Accepted and ignored.

    Returns:
        ``(gallery_feature, gallery_labels, query_feature, query_labels)``.
        When there is no separate query split, the query entries are the
        same objects as the gallery entries.
    """
    net_model = models.create(net, dim=dim, pretrained=False)
    net_model.load_state_dict(checkpoint['state_dict'])
    net_model = torch.nn.DataParallel(net_model).cuda()
    splits = DataSet.create(data, width=width, root=root)

    def loader_for(subset):
        # Deterministic pass over every sample: no shuffling, keep the
        # last partial batch.
        return torch.utils.data.DataLoader(
            subset, batch_size=batch_size, shuffle=False, drop_last=False,
            pin_memory=True, num_workers=nThreads)

    def features_of(loader):
        return extract_features(net_model, loader, print_freq=1e5,
                                metric=None, pool_feature=pool_feature)

    if data not in ['shop', 'jd_test']:
        # Single split: gallery doubles as query.
        feats, labels = features_of(loader_for(splits.gallery))
        return feats, labels, feats, labels

    gallery_loader = loader_for(splits.gallery)
    query_loader = loader_for(splits.query)
    gallery_feature, gallery_labels = features_of(gallery_loader)
    query_feature, query_labels = features_of(query_loader)
    return gallery_feature, gallery_labels, query_feature, query_labels
開發者ID:bnu-wangxun,項目名稱:Deep_Metric,代碼行數:32,代碼來源:Model2Feature.py

示例4: main

# 需要導入模塊: import models [as 別名]
# 或者: from models import create [as 別名]
def main():
    """Train a classifier with cross-entropy plus center loss, then report.

    All settings are read from the module-level ``args``. Console output is
    redirected to ``<save_dir>/log_<dataset>.txt`` through ``Logger``. The
    model is tested every ``args.eval_freq`` epochs (when that is positive)
    and always after the final epoch.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available() and not args.use_cpu

    sys.stdout = Logger(osp.join(args.save_dir, 'log_' + args.dataset + '.txt'))

    if not use_gpu:
        print("Currently using CPU")
    else:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)

    print("Creating dataset: {}".format(args.dataset))
    dataset = datasets.create(name=args.dataset, batch_size=args.batch_size,
                              use_gpu=use_gpu, num_workers=args.workers)
    trainloader, testloader = dataset.trainloader, dataset.testloader
    num_classes = dataset.num_classes

    print("Creating model: {}".format(args.model))
    model = models.create(name=args.model, num_classes=num_classes)
    if use_gpu:
        model = nn.DataParallel(model).cuda()

    criterion_xent = nn.CrossEntropyLoss()
    # The center-loss module owns its own parameters, optimised separately.
    criterion_cent = CenterLoss(num_classes=num_classes, feat_dim=2, use_gpu=use_gpu)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model,
                                      weight_decay=5e-04, momentum=0.9)
    optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)

    scheduler = None
    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(optimizer_model, step_size=args.stepsize,
                                        gamma=args.gamma)

    start_time = time.time()

    for epoch in range(args.max_epoch):
        epoch_num = epoch + 1
        print("==> Epoch {}/{}".format(epoch_num, args.max_epoch))
        train(model, criterion_xent, criterion_cent,
              optimizer_model, optimizer_centloss,
              trainloader, use_gpu, num_classes, epoch)

        if scheduler is not None:
            scheduler.step()

        # eval_freq <= 0 disables periodic testing (and avoids modulo by 0);
        # the final epoch is always tested.
        periodic = args.eval_freq > 0 and epoch_num % args.eval_freq == 0
        if periodic or epoch_num == args.max_epoch:
            print("==> Test")
            acc, err = test(model, testloader, use_gpu, num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

    elapsed = str(datetime.timedelta(seconds=round(time.time() - start_time)))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
開發者ID:KaiyangZhou,項目名稱:pytorch-center-loss,代碼行數:57,代碼來源:main.py


注:本文中的models.create方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。