当前位置: 首页>>代码示例>>Python>>正文


Python models.create方法代码示例

本文整理汇总了Python中models.create方法的典型用法代码示例。如果您正苦于以下问题：Python models.create方法的具体用法？Python models.create怎么用？Python models.create使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块models的用法示例。


在下文中一共展示了models.create方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: setup_dataset（注：该示例调用的是 datasets.create 与 transforms.create，并未直接调用 models.create）

# 需要导入模块: import models [as 别名]
# 或者: from models import create [as 别名]
def setup_dataset(mode, crop_dir, mask_dir=None, mean_mask_dir=None,
                  mean_grid_dir=None, trimap_dir=None, alpha_dir=None,
                  alpha_weight_dir=None):
    """Build the ``(train, test)`` dataset pair for ``mode``.

    The dataset returned by ``datasets.create`` is split with
    ``datasets.split_dataset``. The training half is wrapped first with
    ``transforms.transform_random`` and then with the mode-specific
    transform from ``transforms.create(mode)``. The test half only gets the
    mode-specific transform.

    Args:
        mode: Mode name forwarded to ``datasets.create`` and
            ``transforms.create``.
        crop_dir, mask_dir, mean_mask_dir, mean_grid_dir, trimap_dir,
        alpha_dir, alpha_weight_dir: Directories forwarded positionally to
            ``datasets.create``.

    Returns:
        tuple: ``(train, test)``, each a ``chainer.datasets.TransformDataset``.
    """
    source = datasets.create(mode, crop_dir, mask_dir, mean_mask_dir,
                             mean_grid_dir, trimap_dir, alpha_dir,
                             alpha_weight_dir)

    to_net_input = transforms.create(mode)
    augment = transforms.transform_random

    train_part, test_part = datasets.split_dataset(source)

    # The extra "data variety" transform is applied to training samples only,
    # and it runs before the network-input transform.
    augmented_train = chainer.datasets.TransformDataset(train_part, augment)

    return (chainer.datasets.TransformDataset(augmented_train, to_net_input),
            chainer.datasets.TransformDataset(test_part, to_net_input))
开发者ID:takiyu,项目名称:portrait_matting,代码行数:25,代码来源:train.py

示例2: setup_model

# 需要导入模块: import models [as 别名]
# 或者: from models import create [as 别名]
def setup_model(mode, pretrained_path=None, pretrained_n_input_ch=2,
                pretrained_n_output_ch=21, mat_scale=4):
    """Create the model for ``mode`` and optionally initialise it from FCN8s.

    Args:
        mode: Mode name passed to ``models.create``.
        pretrained_path: Path to an ``.npz`` snapshot (loaded with
            ``chainer.serializers.load_npz``) of a ``models.FCN8s``. When
            ``None``, the freshly created model is returned unchanged.
        pretrained_n_input_ch: ``n_input_ch`` of the FCN8s the snapshot is
            loaded into.
        pretrained_n_output_ch: ``n_output_ch`` of the FCN8s the snapshot is
            loaded into.
        mat_scale: Forwarded to ``models.create``.

    Returns:
        The model returned by ``models.create``. If ``pretrained_path`` is
        given but ``mode`` has no pretrained initialiser, an error is logged
        and the model is returned without pretrained weights (best effort,
        no exception).
    """
    # Create empty model
    model = models.create(mode, mat_scale=mat_scale)

    if pretrained_path is None:
        return model

    if mode in ('seg', 'seg+', 'seg_tri', 'mat'):
        # All of these modes are initialised from an FCN8s snapshot.
        logger.info('Load pretrained FCN8s model (%s)', pretrained_path)
        pretrained = models.FCN8s(n_input_ch=pretrained_n_input_ch,
                                  n_output_ch=pretrained_n_output_ch)
        chainer.serializers.load_npz(pretrained_path, pretrained)
        model.init_from_fcn8s(pretrained)
    else:
        # Keep the best-effort behaviour (log instead of raise), but report
        # which mode has no pretrained initialiser.
        logger.error('Unknown mode for pretrained model loading: %s', mode)

    return model
开发者ID:takiyu,项目名称:portrait_matting,代码行数:22,代码来源:train.py

示例3: Model2Feature

# 需要导入模块: import models [as 别名]
# 或者: from models import create [as 别名]
def Model2Feature(data, net, checkpoint, dim=512, width=224, root=None, nThreads=16, batch_size=100, pool_feature=False, **kargs):
    """Load a checkpointed network and extract gallery/query features.

    Args:
        data: Dataset name passed to ``DataSet.create``. For ``'shop'`` and
            ``'jd_test'`` the gallery and query splits are featurised
            separately. For any other name, the gallery split is used for both.
        net: Architecture name passed to ``models.create``.
        checkpoint: Mapping whose ``'state_dict'`` entry is loaded into the
            model.
        dim: ``dim`` argument for ``models.create``.
        width: ``width`` argument for ``DataSet.create``.
        root: ``root`` argument for ``DataSet.create``.
        nThreads: ``num_workers`` for every ``DataLoader``.
        batch_size: Batch size for every ``DataLoader``.
        pool_feature: Forwarded to ``extract_features``.
        **kargs: Accepted and ignored.

    Returns:
        tuple: ``(gallery_feature, gallery_labels, query_feature,
        query_labels)``.
    """
    network = models.create(net, dim=dim, pretrained=False)
    network.load_state_dict(checkpoint['state_dict'])
    network = torch.nn.DataParallel(network).cuda()
    dataset = DataSet.create(data, width=width, root=root)

    def make_loader(split):
        # Ordered, complete pass: no shuffling and no dropped tail batch.
        return torch.utils.data.DataLoader(
            split, batch_size=batch_size, shuffle=False, drop_last=False,
            pin_memory=True, num_workers=nThreads)

    def featurize(loader):
        return extract_features(network, loader, print_freq=1e5, metric=None,
                                pool_feature=pool_feature)

    if data not in ['shop', 'jd_test']:
        # Single-split datasets: the gallery doubles as the query set.
        features, labels = featurize(make_loader(dataset.gallery))
        return features, labels, features, labels

    gallery_loader = make_loader(dataset.gallery)
    query_loader = make_loader(dataset.query)
    gallery_feature, gallery_labels = featurize(gallery_loader)
    query_feature, query_labels = featurize(query_loader)
    return gallery_feature, gallery_labels, query_feature, query_labels
开发者ID:bnu-wangxun,项目名称:Deep_Metric,代码行数:32,代码来源:Model2Feature.py

示例4: main

# 需要导入模块: import models [as 别名]
# 或者: from models import create [as 别名]
def main():
    """Train a model with cross-entropy and center-loss criteria.

    All configuration is read from the module-level ``args``. The function:

    - redirects ``sys.stdout`` to a ``Logger`` writing
      ``<save_dir>/log_<dataset>.txt``;
    - builds the dataset, model, both criteria and both SGD optimizers;
    - calls ``train`` once per epoch;
    - calls ``test`` every ``eval_freq`` epochs (when ``eval_freq > 0``) and
      always after the final epoch.
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    sys.stdout = Logger(osp.join(args.save_dir, 'log_' + args.dataset + '.txt'))

    if use_gpu:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU")

    print("Creating dataset: {}".format(args.dataset))
    dataset = datasets.create(
        name=args.dataset, batch_size=args.batch_size, use_gpu=use_gpu,
        num_workers=args.workers,
    )

    trainloader, testloader = dataset.trainloader, dataset.testloader

    print("Creating model: {}".format(args.model))
    model = models.create(name=args.model, num_classes=dataset.num_classes)

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    criterion_xent = nn.CrossEntropyLoss()
    # NOTE(review): feat_dim=2 presumably must match the model's feature
    # size -- confirm against the model definition.
    criterion_cent = CenterLoss(num_classes=dataset.num_classes, feat_dim=2, use_gpu=use_gpu)
    optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model, weight_decay=5e-04, momentum=0.9)
    optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)

    # Always bind ``scheduler``. Then the loop does not rely on re-checking
    # the same condition to avoid touching an unbound name.
    scheduler = None
    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(optimizer_model, step_size=args.stepsize, gamma=args.gamma)

    start_time = time.time()

    for epoch in range(args.max_epoch):
        print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
        train(model, criterion_xent, criterion_cent,
              optimizer_model, optimizer_centloss,
              trainloader, use_gpu, dataset.num_classes, epoch)

        if scheduler is not None:
            scheduler.step()

        # Evaluate periodically when eval_freq > 0, and always after the last
        # epoch. The modulo is only reached when eval_freq > 0.
        is_periodic_eval = args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0
        is_last_epoch = (epoch + 1) == args.max_epoch
        if is_periodic_eval or is_last_epoch:
            print("==> Test")
            acc, err = test(model, testloader, use_gpu, dataset.num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
开发者ID:KaiyangZhou,项目名称:pytorch-center-loss,代码行数:57,代码来源:main.py


注：本文中的models.create方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。