當前位置: 首頁>>代碼示例>>Python>>正文


Python models.VGG屬性代碼示例

本文整理匯總了Python中models.VGG屬性的典型用法代碼示例。如果您正苦於以下問題:Python models.VGG屬性的具體用法?Python models.VGG怎麽用?Python models.VGG使用的例子?那麽, 這裏精選的屬性代碼示例或許可以為您提供幫助。您也可以進一步了解該屬性所在models的用法示例。


在下文中一共展示了models.VGG屬性的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: get_model

# 需要導入模塊: import models [as 別名]
# 或者: from models import VGG [as 別名]
def get_model(device):
    """
    :param device: instance of torch.device
    :return: An instance of torch.nn.Module
    """
    num_classes = 2
    if config["dataset"] == "Cifar100":
        num_classes = 100
    elif config["dataset"] == "Cifar10":
        num_classes = 10

    model = {
        "vgg11": lambda: models.VGG("VGG11", num_classes, batch_norm=False),
        "vgg11_bn": lambda: models.VGG("VGG11", num_classes, batch_norm=True),
        "vgg13": lambda: models.VGG("VGG13", num_classes, batch_norm=False),
        "vgg13_bn": lambda: models.VGG("VGG13", num_classes, batch_norm=True),
        "vgg16": lambda: models.VGG("VGG16", num_classes, batch_norm=False),
        "vgg16_bn": lambda: models.VGG("VGG16", num_classes, batch_norm=True),
        "vgg19": lambda: models.VGG("VGG19", num_classes, batch_norm=False),
        "vgg19_bn": lambda: models.VGG("VGG19", num_classes, batch_norm=True),
        "resnet10": lambda: models.ResNet10(num_classes=num_classes),
        "resnet18": lambda: models.ResNet18(num_classes=num_classes),
        "resnet34": lambda: models.ResNet34(num_classes=num_classes),
        "resnet50": lambda: models.ResNet50(num_classes=num_classes),
        "resnet101": lambda: models.ResNet101(num_classes=num_classes),
        "resnet152": lambda: models.ResNet152(num_classes=num_classes),
        "bert": lambda: models.BertImage(config, num_classes=num_classes),
    }[config["model"]]()

    model.to(device)
    if device == torch.device("cuda"):
        print("Use DataParallel if multi-GPU")
        model = torch.nn.DataParallel(model)
        torch.backends.cudnn.benchmark = True

    return model 
開發者ID:epfml,項目名稱:attention-cnn,代碼行數:38,代碼來源:train.py

示例2: main

# 需要導入模塊: import models [as 別名]
# 或者: from models import VGG [as 別名]
def main(args):
    dataroot = Path(args.dataroot)
    save_dir = dataroot / 'map'
    save_dir.mkdir(exist_ok=True)

    dataset = SwappingDataset(
        dataroot=dataroot, input_size=40 if 'CUFED' in dataroot.name else 80)
    dataloader = DataLoader(dataset)
    model = VGG(model_type='vgg19').to(device)
    swapper = Swapper(args.patch_size, args.stride).to(device)

    for i, batch in enumerate(tqdm(dataloader), 1):
        img_in = batch['img_in'].to(device)
        img_ref = batch['img_ref'].to(device)
        img_ref_blur = batch['img_ref_blur'].to(device)

        map_in = model(img_in, TARGET_LAYERS)
        map_ref = model(img_ref, TARGET_LAYERS)
        map_ref_blur = model(img_ref_blur, TARGET_LAYERS)

        maps, weights, correspondences = swapper(map_in, map_ref, map_ref_blur)

        np.savez_compressed(save_dir / f'{batch["filename"][0]}.npz',
                            relu1_1=maps['relu1_1'],
                            relu2_1=maps['relu2_1'],
                            relu3_1=maps['relu3_1'],
                            weights=weights,
                            correspondences=correspondences)

        if args.debug and i == 10:
            break 
開發者ID:S-aiueo32,項目名稱:srntt-pytorch,代碼行數:33,代碼來源:offline_texture_swapping.py

示例3: main

# 需要導入模塊: import models [as 別名]
# 或者: from models import VGG [as 別名]
def main(args):
    imgs = load_image(args.input, args.ref)

    vgg = VGG(model_type='vgg19').to(device)
    swapper = Swapper().to(device)

    map_in = vgg(imgs['bic'].to(device), TARGET_LAYERS)
    map_ref = vgg(imgs['ref'].to(device), TARGET_LAYERS)
    map_ref_blur = vgg(imgs['ref_blur'].to(device), TARGET_LAYERS)

    with torch.no_grad(), timer('Feature swapping'):
        maps, weights, correspondences = swapper(map_in, map_ref, map_ref_blur)

    model = SRNTT(use_weights=args.use_weights).to(device)
    model.load_state_dict(torch.load(args.weight))

    img_hr = imgs['hr'].to(device)
    img_lr = imgs['lr'].to(device)
    maps = {
        k: torch.tensor(v).unsqueeze(0).to(device) for k, v in maps.items()}
    weights = torch.tensor(weights).reshape(1, 1, *weights.shape).to(device)

    with torch.no_grad(), timer('Inference'):
        _, img_sr = model(img_lr, maps, weights)

    psnr = PSNR()(img_sr.clamp(0, 1), img_hr.clamp(0, 1)).item()
    ssim = SSIM()(img_sr.clamp(0, 1), img_hr.clamp(0, 1)).item()
    print(f'[Result] PSNR:{psnr:.2f}, SSIM:{ssim:.4f}')

    save_image(img_sr.clamp(0, 1), './out.png') 
開發者ID:S-aiueo32,項目名稱:srntt-pytorch,代碼行數:32,代碼來源:online_inference.py

示例4: __init__

# 需要導入模塊: import models [as 別名]
# 或者: from models import VGG [as 別名]
def __init__(self, use_weights=False):
        super(TextureLoss, self).__init__()
        self.use_weights = use_weights

        self.model = VGG(model_type='vgg19')
        self.register_buffer('a', torch.tensor(-20., requires_grad=False))
        self.register_buffer('b', torch.tensor(.65, requires_grad=False)) 
開發者ID:S-aiueo32,項目名稱:srntt-pytorch,代碼行數:9,代碼來源:texture_loss.py

示例5: __init__

# 需要導入模塊: import models [as 別名]
# 或者: from models import VGG [as 別名]
def __init__(self,
                 model_type: str = 'vgg19',
                 target_layer: str = 'relu5_1',
                 norm_type: str = 'fro'):
        super(PerceptualLoss, self).__init__()

        assert norm_type in ['mse', 'fro']

        self.model = VGG(model_type=model_type)
        self.target_layer = target_layer
        self.norm_type = norm_type 
開發者ID:S-aiueo32,項目名稱:srntt-pytorch,代碼行數:13,代碼來源:perceptual_loss.py


注:本文中的models.VGG屬性示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。