当前位置: 首页>>代码示例>>Python>>正文


Python models.VGG属性代码示例

本文整理汇总了Python中models.VGG属性的典型用法代码示例。如果您正苦于以下问题:Python models.VGG属性的具体用法?Python models.VGG怎么用?Python models.VGG使用的例子?那么, 这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在models的用法示例。


在下文中一共展示了models.VGG属性的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_model

# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def get_model(device):
    """
    :param device: instance of torch.device
    :return: An instance of torch.nn.Module
    """
    num_classes = 2
    if config["dataset"] == "Cifar100":
        num_classes = 100
    elif config["dataset"] == "Cifar10":
        num_classes = 10

    model = {
        "vgg11": lambda: models.VGG("VGG11", num_classes, batch_norm=False),
        "vgg11_bn": lambda: models.VGG("VGG11", num_classes, batch_norm=True),
        "vgg13": lambda: models.VGG("VGG13", num_classes, batch_norm=False),
        "vgg13_bn": lambda: models.VGG("VGG13", num_classes, batch_norm=True),
        "vgg16": lambda: models.VGG("VGG16", num_classes, batch_norm=False),
        "vgg16_bn": lambda: models.VGG("VGG16", num_classes, batch_norm=True),
        "vgg19": lambda: models.VGG("VGG19", num_classes, batch_norm=False),
        "vgg19_bn": lambda: models.VGG("VGG19", num_classes, batch_norm=True),
        "resnet10": lambda: models.ResNet10(num_classes=num_classes),
        "resnet18": lambda: models.ResNet18(num_classes=num_classes),
        "resnet34": lambda: models.ResNet34(num_classes=num_classes),
        "resnet50": lambda: models.ResNet50(num_classes=num_classes),
        "resnet101": lambda: models.ResNet101(num_classes=num_classes),
        "resnet152": lambda: models.ResNet152(num_classes=num_classes),
        "bert": lambda: models.BertImage(config, num_classes=num_classes),
    }[config["model"]]()

    model.to(device)
    if device == torch.device("cuda"):
        print("Use DataParallel if multi-GPU")
        model = torch.nn.DataParallel(model)
        torch.backends.cudnn.benchmark = True

    return model 
开发者ID:epfml,项目名称:attention-cnn,代码行数:38,代码来源:train.py

示例2: main

# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def main(args):
    dataroot = Path(args.dataroot)
    save_dir = dataroot / 'map'
    save_dir.mkdir(exist_ok=True)

    dataset = SwappingDataset(
        dataroot=dataroot, input_size=40 if 'CUFED' in dataroot.name else 80)
    dataloader = DataLoader(dataset)
    model = VGG(model_type='vgg19').to(device)
    swapper = Swapper(args.patch_size, args.stride).to(device)

    for i, batch in enumerate(tqdm(dataloader), 1):
        img_in = batch['img_in'].to(device)
        img_ref = batch['img_ref'].to(device)
        img_ref_blur = batch['img_ref_blur'].to(device)

        map_in = model(img_in, TARGET_LAYERS)
        map_ref = model(img_ref, TARGET_LAYERS)
        map_ref_blur = model(img_ref_blur, TARGET_LAYERS)

        maps, weights, correspondences = swapper(map_in, map_ref, map_ref_blur)

        np.savez_compressed(save_dir / f'{batch["filename"][0]}.npz',
                            relu1_1=maps['relu1_1'],
                            relu2_1=maps['relu2_1'],
                            relu3_1=maps['relu3_1'],
                            weights=weights,
                            correspondences=correspondences)

        if args.debug and i == 10:
            break 
开发者ID:S-aiueo32,项目名称:srntt-pytorch,代码行数:33,代码来源:offline_texture_swapping.py

示例3: main

# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def main(args):
    imgs = load_image(args.input, args.ref)

    vgg = VGG(model_type='vgg19').to(device)
    swapper = Swapper().to(device)

    map_in = vgg(imgs['bic'].to(device), TARGET_LAYERS)
    map_ref = vgg(imgs['ref'].to(device), TARGET_LAYERS)
    map_ref_blur = vgg(imgs['ref_blur'].to(device), TARGET_LAYERS)

    with torch.no_grad(), timer('Feature swapping'):
        maps, weights, correspondences = swapper(map_in, map_ref, map_ref_blur)

    model = SRNTT(use_weights=args.use_weights).to(device)
    model.load_state_dict(torch.load(args.weight))

    img_hr = imgs['hr'].to(device)
    img_lr = imgs['lr'].to(device)
    maps = {
        k: torch.tensor(v).unsqueeze(0).to(device) for k, v in maps.items()}
    weights = torch.tensor(weights).reshape(1, 1, *weights.shape).to(device)

    with torch.no_grad(), timer('Inference'):
        _, img_sr = model(img_lr, maps, weights)

    psnr = PSNR()(img_sr.clamp(0, 1), img_hr.clamp(0, 1)).item()
    ssim = SSIM()(img_sr.clamp(0, 1), img_hr.clamp(0, 1)).item()
    print(f'[Result] PSNR:{psnr:.2f}, SSIM:{ssim:.4f}')

    save_image(img_sr.clamp(0, 1), './out.png') 
开发者ID:S-aiueo32,项目名称:srntt-pytorch,代码行数:32,代码来源:online_inference.py

示例4: __init__

# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def __init__(self, use_weights=False):
        super(TextureLoss, self).__init__()
        self.use_weights = use_weights

        self.model = VGG(model_type='vgg19')
        self.register_buffer('a', torch.tensor(-20., requires_grad=False))
        self.register_buffer('b', torch.tensor(.65, requires_grad=False)) 
开发者ID:S-aiueo32,项目名称:srntt-pytorch,代码行数:9,代码来源:texture_loss.py

示例5: __init__

# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def __init__(self,
                 model_type: str = 'vgg19',
                 target_layer: str = 'relu5_1',
                 norm_type: str = 'fro'):
        super(PerceptualLoss, self).__init__()

        assert norm_type in ['mse', 'fro']

        self.model = VGG(model_type=model_type)
        self.target_layer = target_layer
        self.norm_type = norm_type 
开发者ID:S-aiueo32,项目名称:srntt-pytorch,代码行数:13,代码来源:perceptual_loss.py


注:本文中的models.VGG属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。