本文整理汇总了Python中models.VGG类的典型用法代码示例。如果您正苦于以下问题：Python models.VGG的具体用法？Python models.VGG怎么用？Python models.VGG使用的例子？那么，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块models的用法示例。
在下文中一共展示了models.VGG类的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_model
# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def get_model(device):
    """
    Instantiate the network selected by ``config["model"]`` and move it to *device*.

    The number of output classes follows ``config["dataset"]``: 100 for
    "Cifar100", 10 for "Cifar10", and 2 for any other dataset name.

    :param device: instance of torch.device
    :return: An instance of torch.nn.Module; wrapped in
        ``torch.nn.DataParallel`` when *device* compares equal to
        ``torch.device("cuda")``
    :raises KeyError: if ``config["model"]`` is not one of the registered names
    """
    dataset = config["dataset"]
    num_classes = 100 if dataset == "Cifar100" else (10 if dataset == "Cifar10" else 2)

    # Name -> zero-argument factory. Only the selected factory is invoked,
    # so no other architecture is ever constructed.
    factories = {}
    for depth in ("11", "13", "16", "19"):
        arch = "VGG" + depth
        # ``arch`` is bound as a default argument so every lambda keeps its
        # own value (avoids the late-binding closure pitfall inside a loop).
        factories["vgg" + depth] = (
            lambda arch=arch: models.VGG(arch, num_classes, batch_norm=False))
        factories["vgg" + depth + "_bn"] = (
            lambda arch=arch: models.VGG(arch, num_classes, batch_norm=True))
    for depth in ("10", "18", "34", "50", "101", "152"):
        # The ``models.ResNet<depth>`` attribute is looked up only when the
        # factory is called, not while the table is being built.
        factories["resnet" + depth] = (
            lambda cls_name="ResNet" + depth:
            getattr(models, cls_name)(num_classes=num_classes))
    factories["bert"] = lambda: models.BertImage(config, num_classes=num_classes)

    model = factories[config["model"]]()
    model.to(device)

    # NOTE(review): this is an equality test against the bare
    # ``torch.device("cuda")``; an indexed device such as "cuda:0" may not
    # compare equal and would then skip DataParallel — confirm intent.
    if device == torch.device("cuda"):
        print("Use DataParallel if multi-GPU")
        model = torch.nn.DataParallel(model)
        torch.backends.cudnn.benchmark = True
    return model
示例2: main
# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def main(args):
    """Precompute swapped VGG feature maps for every sample under ``args.dataroot``.

    For each batch from ``SwappingDataset``, the ``img_in``, ``img_ref`` and
    ``img_ref_blur`` tensors are passed through a VGG19 feature extractor
    (layers listed in ``TARGET_LAYERS``), the three feature dicts are fed to
    ``Swapper``, and the results are written to
    ``<dataroot>/map/<filename>.npz``.

    :param args: parsed command-line namespace; this function reads
        ``dataroot``, ``patch_size``, ``stride`` and ``debug``.
    """
    import torch  # local import: needed only for ``torch.no_grad`` below

    dataroot = Path(args.dataroot)
    save_dir = dataroot / 'map'
    save_dir.mkdir(exist_ok=True)

    # NOTE(review): input_size is 40 for datasets whose directory name
    # contains 'CUFED' and 80 otherwise — the reason is not visible here.
    dataset = SwappingDataset(
        dataroot=dataroot, input_size=40 if 'CUFED' in dataroot.name else 80)
    # Default DataLoader batch size is 1, hence ``batch["filename"][0]`` below.
    dataloader = DataLoader(dataset)

    model = VGG(model_type='vgg19').to(device)
    swapper = Swapper(args.patch_size, args.stride).to(device)

    # This is pure inference: disable autograd so the VGG forward passes and
    # the swap do not build computation graphs that are never used.
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader), 1):
            img_in = batch['img_in'].to(device)
            img_ref = batch['img_ref'].to(device)
            img_ref_blur = batch['img_ref_blur'].to(device)

            map_in = model(img_in, TARGET_LAYERS)
            map_ref = model(img_ref, TARGET_LAYERS)
            map_ref_blur = model(img_ref_blur, TARGET_LAYERS)

            maps, weights, correspondences = swapper(map_in, map_ref, map_ref_blur)

            np.savez_compressed(save_dir / f'{batch["filename"][0]}.npz',
                                relu1_1=maps['relu1_1'],
                                relu2_1=maps['relu2_1'],
                                relu3_1=maps['relu3_1'],
                                weights=weights,
                                correspondences=correspondences)

            # Debug mode: stop after the 10th batch (enumerate starts at 1).
            if args.debug and i == 10:
                break
示例3: main
# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def main(args):
    """Run SRNTT inference on one input/reference pair and report PSNR/SSIM.

    Loads the images named by ``args.input`` / ``args.ref``, extracts VGG19
    features, swaps them with ``Swapper``, runs an ``SRNTT`` model loaded from
    ``args.weight``, prints PSNR and SSIM against ``imgs['hr']``, and saves
    the clamped output image to ``./out.png``.

    :param args: parsed command-line namespace; this function reads
        ``input``, ``ref``, ``use_weights`` and ``weight``.
    """
    imgs = load_image(args.input, args.ref)

    vgg = VGG(model_type='vgg19').to(device)
    swapper = Swapper().to(device)

    # Feature extraction is inference-only too; the outputs feed a swap that
    # already runs under no_grad, so building autograd graphs here is wasted.
    with torch.no_grad():
        map_in = vgg(imgs['bic'].to(device), TARGET_LAYERS)
        map_ref = vgg(imgs['ref'].to(device), TARGET_LAYERS)
        map_ref_blur = vgg(imgs['ref_blur'].to(device), TARGET_LAYERS)

    with torch.no_grad(), timer('Feature swapping'):
        maps, weights, _ = swapper(map_in, map_ref, map_ref_blur)

    model = SRNTT(use_weights=args.use_weights).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    model.load_state_dict(torch.load(args.weight, map_location=device))
    # NOTE(review): model.eval() is never called; confirm SRNTT has no
    # dropout/batch-norm layers whose behaviour differs in training mode.

    img_hr = imgs['hr'].to(device)
    img_lr = imgs['lr'].to(device)

    # Swapper outputs are converted with torch.tensor and given a leading
    # batch dimension of 1; weights get two leading singleton dimensions.
    maps = {
        k: torch.tensor(v).unsqueeze(0).to(device) for k, v in maps.items()}
    weights = torch.tensor(weights).reshape(1, 1, *weights.shape).to(device)

    with torch.no_grad(), timer('Inference'):
        _, img_sr = model(img_lr, maps, weights)

    # Clamp once and reuse for both metrics and the saved image.
    sr = img_sr.clamp(0, 1)
    hr = img_hr.clamp(0, 1)
    psnr = PSNR()(sr, hr).item()
    ssim = SSIM()(sr, hr).item()
    print(f'[Result] PSNR:{psnr:.2f}, SSIM:{ssim:.4f}')
    save_image(sr, './out.png')
示例4: __init__
# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def __init__(self, use_weights=False):
    """Initialise the texture loss module.

    :param use_weights: stored as ``self.use_weights``; how it is used lives
        in other methods of the class (not visible here).
    """
    super(TextureLoss, self).__init__()
    self.use_weights = use_weights
    self.model = VGG(model_type='vgg19')
    # Scalar constants registered as buffers rather than parameters: they
    # move with the module on ``.to(device)`` and are included in the state
    # dict, but are not returned by ``parameters()``. Registered in order a, b.
    for buffer_name, buffer_value in (('a', -20.), ('b', .65)):
        self.register_buffer(
            buffer_name, torch.tensor(buffer_value, requires_grad=False))
示例5: __init__
# 需要导入模块: import models [as 别名]
# 或者: from models import VGG [as 别名]
def __init__(self,
             model_type: str = 'vgg19',
             target_layer: str = 'relu5_1',
             norm_type: str = 'fro'):
    """Set up a feature-space (perceptual) loss backed by a VGG network.

    :param model_type: forwarded to ``VGG(model_type=...)``.
    :param target_layer: feature layer name; stored as ``self.target_layer``
        for use by other methods of the class (not visible here).
    :param norm_type: either ``'mse'`` or ``'fro'``; stored as
        ``self.norm_type``.
    :raises ValueError: if ``norm_type`` is not ``'mse'`` or ``'fro'``.
    """
    super(PerceptualLoss, self).__init__()
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently accept an invalid norm_type.
    if norm_type not in ('mse', 'fro'):
        raise ValueError(
            f"norm_type must be 'mse' or 'fro', got {norm_type!r}")
    self.model = VGG(model_type=model_type)
    self.target_layer = target_layer
    self.norm_type = norm_type