本文整理汇总了Python中torchvision.models.utils.load_state_dict_from_url方法的典型用法代码示例。如果您正苦于以下问题：Python utils.load_state_dict_from_url方法的具体用法？Python utils.load_state_dict_from_url怎么用？Python utils.load_state_dict_from_url使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torchvision.models.utils的用法示例。
注意：较新版本的torchvision已移除torchvision.models.utils模块，此时可直接改用torch.hub.load_state_dict_from_url（二者为同一函数）。
在下文中一共展示了utils.load_state_dict_from_url方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: inception_v3_base
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def inception_v3_base(pretrained=False, progress=True, **kwargs):
    """Build an ``Inception3Base`` model, optionally with pretrained weights.

    Args:
        pretrained: if true, download the ``'inception_v3_google'`` checkpoint
            listed in ``model_urls`` and load it into the model.
        progress: forwarded to ``load_state_dict_from_url`` to toggle the
            download progress bar.
        **kwargs: forwarded to the ``Inception3Base`` constructor.

    Returns:
        The constructed ``Inception3Base`` instance.
    """
    if not pretrained:
        return Inception3Base(**kwargs)

    # When loading pretrained weights, default transform_input to True unless
    # the caller set it explicitly.
    kwargs.setdefault('transform_input', True)

    # Remember whether the caller wanted the auxiliary head, then force it on
    # for construction (presumably so the checkpoint's AuxLogits weights match
    # during the strict load below). The key is only overridden when the caller
    # passed it; otherwise the constructor's own default is used.
    keep_aux_logits = kwargs.get('aux_logits', True)
    if 'aux_logits' in kwargs:
        kwargs['aux_logits'] = True

    net = Inception3Base(**kwargs)
    weights = load_state_dict_from_url(model_urls['inception_v3_google'],
                                       progress=progress)
    net.load_state_dict(weights)

    if not keep_aux_logits:
        # Drop the auxiliary head the caller did not ask for.
        net.aux_logits = False
        del net.AuxLogits
    return net
示例2: _darknet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _darknet(arch, pretrained, progress, **kwargs):
    """Instantiate the Darknet variant registered under ``arch``.

    ``default_cfgs[arch]`` provides the class name (``'arch'``), the layout
    handed to that class's constructor (``'layout'``) and the checkpoint URL
    (``'url'``).

    Args:
        arch: key into ``default_cfgs``.
        pretrained: if true, try to load the checkpoint from the config URL.
            A missing (None) URL only logs a warning.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra constructor arguments.

    Returns:
        The constructed model.
    """
    cfg = default_cfgs[arch]
    # Resolve the model class by name among this module's globals.
    model_cls = vars(sys.modules[__name__])[cfg['arch']]
    model = model_cls(cfg['layout'], **kwargs)

    if not pretrained:
        return model

    url = cfg['url']
    if url is None:
        logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        return model

    model.load_state_dict(load_state_dict_from_url(url, progress=progress))
    return model
示例3: fid_inception_v3
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def fid_inception_v3():
    """Build the pretrained Inception model used for FID computation.

    The FID Inception network uses a different set of weights and a slightly
    different structure than torchvision's Inception. This builds the
    torchvision-style architecture first (1008 classes, no auxiliary head,
    no pretrained weights), swaps in the FID-specific Mixed blocks, and then
    loads the weights from ``FID_WEIGHTS_URL``.

    Returns:
        The patched Inception module with the FID weights loaded.
    """
    net = _inception_v3(num_classes=1008,
                        aux_logits=False,
                        pretrained=False)

    # (attribute name, replacement block), applied in this order.
    replacements = (
        ('Mixed_5b', FIDInceptionA(192, pool_features=32)),
        ('Mixed_5c', FIDInceptionA(256, pool_features=64)),
        ('Mixed_5d', FIDInceptionA(288, pool_features=64)),
        ('Mixed_6b', FIDInceptionC(768, channels_7x7=128)),
        ('Mixed_6c', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6d', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6e', FIDInceptionC(768, channels_7x7=192)),
        ('Mixed_7b', FIDInceptionE_1(1280)),
        ('Mixed_7c', FIDInceptionE_2(2048)),
    )
    for attr_name, new_block in replacements:
        setattr(net, attr_name, new_block)

    net.load_state_dict(load_state_dict_from_url(FID_WEIGHTS_URL, progress=True))
    return net
示例4: _resnet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load the checkpoint for ``arch``.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint.
        block: residual block type passed to ``ResNet``.
        layers: per-stage block counts passed to ``ResNet``.
        pretrained: whether to download and (strictly) load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra ``ResNet`` constructor arguments.

    Returns:
        The ``ResNet`` model.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
示例5: _shufflenetv2
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load the checkpoint for ``arch``.

    Pretrained weights are loaded with ``strict=False``, so checkpoint entries
    that do not match the model are skipped instead of raising. Any missing or
    unexpected keys are reported with a warning, so a partial load is not
    silent.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint.
        pretrained: whether to download and load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: forwarded to the ``ShuffleNetV2`` constructor.

    Returns:
        The ``ShuffleNetV2`` model.

    Raises:
        NotImplementedError: if ``pretrained`` is requested but
            ``model_urls[arch]`` is None.
    """
    import warnings

    model = ShuffleNetV2(*args, **kwargs)
    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        result = model.load_state_dict(state_dict, strict=False)
        # torch returns a (missing_keys, unexpected_keys) named tuple; getattr
        # keeps this working if an override or older torch returns something else.
        missing = getattr(result, 'missing_keys', [])
        unexpected = getattr(result, 'unexpected_keys', [])
        if missing or unexpected:
            warnings.warn(
                'Partial load of pretrained weights for {}: missing keys {}, '
                'unexpected keys {}'.format(arch, list(missing), list(unexpected)),
                stacklevel=2,
            )
    return model
示例6: _resnet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load ImageNet weights for ``arch``.

    Weights are loaded with ``strict=False``, so checkpoint entries that do not
    match the model are skipped rather than raising. Missing or unexpected keys
    are reported with a warning, so a partial load is visible instead of being
    hidden behind the success message.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint.
        block: residual block type passed to ``ResNet``.
        layers: per-stage block counts passed to ``ResNet``.
        pretrained: whether to download and load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra ``ResNet`` constructor arguments.

    Returns:
        The ``ResNet`` model.
    """
    import warnings

    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        result = model.load_state_dict(state_dict, strict=False)
        # torch returns a (missing_keys, unexpected_keys) named tuple; getattr
        # keeps this working if an override or older torch returns something else.
        missing = getattr(result, 'missing_keys', [])
        unexpected = getattr(result, 'unexpected_keys', [])
        if missing or unexpected:
            warnings.warn(
                'Partial load of pretrained weights for {}: missing keys {}, '
                'unexpected keys {}'.format(arch, list(missing), list(unexpected)),
                stacklevel=2,
            )
        print('load pretrained models from imagenet')
    return model
示例7: _resnet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, pretrained=None, **kwargs):
    """Build a ``TResNet`` for ``arch`` and optionally load pretrained weights.

    Configuration:
        ``CFGS[arch]["default"]`` holds the model settings, and its ``"params"``
        entry holds constructor arguments. When ``pretrained`` names another
        entry of ``CFGS[arch]``, that entry's settings and ``"params"`` are
        layered on top. The merged params override same-named ``kwargs``, and
        a warning lists the overlapping names.

    When weights are loaded:
        * if ``num_classes`` differs from the checkpoint's, the freshly built
          ``last_linear`` weights replace the checkpoint's;
        * if ``in_channels`` is not 3, ``conv1.1.weight`` is adapted with
          ``repeat_channels``;
        * ``patch_bn`` is applied to the model afterwards.

    Args:
        arch: key into ``CFGS``.
        pretrained: name of a weights entry in ``CFGS[arch]``, or falsy for
            no pretrained weights.
        **kwargs: ``TResNet`` constructor arguments.

    Returns:
        The model, with the merged settings attached as ``pretrained_settings``.
    """
    # Private copy, so the pops/updates below never touch the global CFGS.
    arch_cfgs = deepcopy(CFGS[arch])
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    overridden = set(params).intersection(set(kwargs))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = TResNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"], check_hash=True)
        n_classes = kwargs.get("num_classes")
        if n_classes and n_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], n_classes
                )
            )
            # Keep the new model's head instead of the checkpoint's.
            fresh = model.state_dict()
            for key in ("last_linear.weight", "last_linear.bias"):
                state_dict[key] = fresh[key]
        in_ch = kwargs.get("in_channels", 3)
        if in_ch != 3:
            # Adapt the first conv to a non-RGB channel count.
            state_dict["conv1.1.weight"] = repeat_channels(
                state_dict["conv1.1.weight"], in_ch * 16, 3 * 16
            )
        model.load_state_dict(state_dict)
        patch_bn(model)
    model.pretrained_settings = settings
    return model
示例8: _vgg
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _vgg(arch, pretrained=None, **kwargs):
    """Build a ``VGG`` for ``arch`` and optionally load pretrained weights.

    Configuration:
        ``CFGS[arch]["default"]`` holds the model settings, and its ``"params"``
        entry holds constructor arguments. When ``pretrained`` names another
        entry of ``CFGS[arch]``, that entry's settings and ``"params"`` are
        layered on top. The merged params override same-named ``kwargs``, and
        a warning lists the overlapping names.

    When weights are loaded and ``num_classes`` differs from the checkpoint's,
    the freshly built ``classifier.6`` weights replace the checkpoint's.

    Args:
        arch: key into ``CFGS``.
        pretrained: name of a weights entry in ``CFGS[arch]``, or falsy for
            no pretrained weights.
        **kwargs: ``VGG`` constructor arguments.

    Returns:
        The model, with the merged settings attached as ``pretrained_settings``.
    """
    # Private copy, so the pops/updates below never touch the global CFGS.
    arch_cfgs = deepcopy(CFGS[arch])
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    overridden = set(params).intersection(set(kwargs))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = VGG(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        n_classes = kwargs.get("num_classes")
        if n_classes and n_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], n_classes
                )
            )
            # Keep the new model's final classifier layer (classifier.6)
            # instead of the checkpoint's.
            fresh = model.state_dict()
            for key in ("classifier.6.weight", "classifier.6.bias"):
                state_dict[key] = fresh[key]
        model.load_state_dict(state_dict)
    model.pretrained_settings = settings
    return model
示例9: _efficientnet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _efficientnet(arch, pretrained=None, **kwargs):
    """Build an ``EfficientNet`` for ``arch`` and optionally load weights.

    Configuration:
        ``CFGS[arch]["default"]`` holds the model settings, and its ``"params"``
        entry holds constructor arguments. Its ``"blocks_args"`` are passed
        through ``decode_block_args``. When ``pretrained`` names another entry
        of ``CFGS[arch]``, that entry's settings and ``"params"`` are layered
        on top. The merged params override same-named ``kwargs``, and a
        warning lists the overlapping names.

    When weights are loaded:
        * if ``num_classes`` differs from the checkpoint's, the freshly built
          ``classifier`` weights replace the checkpoint's;
        * if ``in_channels`` is not 3, ``conv_stem.weight`` is adapted with
          ``repeat_channels``.

    Args:
        arch: key into ``CFGS``.
        pretrained: name of a weights entry in ``CFGS[arch]``, or falsy for
            no pretrained weights.
        **kwargs: ``EfficientNet`` constructor arguments.

    Returns:
        The model, with the merged settings attached as ``pretrained_settings``.
    """
    # Private copy, so the pops/updates below never touch the global CFGS.
    arch_cfgs = deepcopy(CFGS[arch])
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    params["blocks_args"] = decode_block_args(params["blocks_args"])
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    overridden = set(params).intersection(set(kwargs))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = EfficientNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        n_classes = kwargs.get("num_classes")
        if n_classes and n_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], n_classes
                )
            )
            # Keep the new model's classifier instead of the checkpoint's.
            fresh = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                state_dict[key] = fresh[key]
        in_ch = kwargs.get("in_channels", 3)
        if in_ch != 3:
            # Adapt the stem conv to a non-RGB channel count.
            state_dict["conv_stem.weight"] = repeat_channels(state_dict["conv_stem.weight"], in_ch)
        model.load_state_dict(state_dict)
    model.pretrained_settings = settings
    return model
示例10: _densenet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _densenet(arch, pretrained=None, **kwargs):
    """Build a ``DenseNet`` for ``arch`` and optionally load pretrained weights.

    Configuration:
        ``CFGS[arch]["default"]`` holds the model settings, and its ``"params"``
        entry holds constructor arguments. When ``pretrained`` names another
        entry of ``CFGS[arch]``, that entry's settings and ``"params"`` are
        layered on top. The merged params override same-named ``kwargs``, and
        a warning lists the overlapping names.

    When weights are loaded and ``num_classes`` differs from the checkpoint's,
    the freshly built ``classifier`` weights replace the checkpoint's.

    Args:
        arch: key into ``CFGS``.
        pretrained: name of a weights entry in ``CFGS[arch]``, or falsy for
            no pretrained weights.
        **kwargs: ``DenseNet`` constructor arguments.

    Returns:
        The model, with the merged settings attached as ``pretrained_settings``.
    """
    # Private copy, so the pops/updates below never touch the global CFGS.
    arch_cfgs = deepcopy(CFGS[arch])
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    overridden = set(params).intersection(set(kwargs))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = DenseNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        n_classes = kwargs.get("num_classes")
        if n_classes and n_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], n_classes
                )
            )
            # Keep the new model's classifier instead of the checkpoint's.
            fresh = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                state_dict[key] = fresh[key]
        model.load_state_dict(state_dict)
    model.pretrained_settings = settings
    return model
示例11: _iresnet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build an ``IResNet`` and optionally load the checkpoint for ``arch``.

    Args:
        arch: key into ``model_urls`` selecting the checkpoint.
        block: block type passed to ``IResNet``.
        layers: per-stage block counts passed to ``IResNet``.
        pretrained: whether to download and (strictly) load the checkpoint.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra ``IResNet`` constructor arguments.

    Returns:
        The ``IResNet`` model.
    """
    net = IResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
示例12: _resnet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, block, planes, pretrained, progress, deconv, delinear, channel_deconv, **kwargs):
    """Build a ``ResNet`` with the given deconvolution options.

    Pretrained weights are NOT loaded for this variant: the loading code was
    disabled (previously left behind as a dead string literal). ``arch`` and
    ``progress`` therefore have no effect. Passing a truthy ``pretrained``
    emits a warning instead of being silently ignored.

    Args:
        arch: architecture name, used only in the warning message.
        block: residual block type passed to ``ResNet``.
        planes: per-stage configuration passed to ``ResNet``.
        pretrained: if true, a warning is issued; no weights are loaded.
        progress: unused (kept for interface compatibility).
        deconv, delinear, channel_deconv: forwarded to ``ResNet``.
        **kwargs: extra ``ResNet`` constructor arguments.

    Returns:
        The ``ResNet`` model with its default initialisation.
    """
    import warnings

    model = ResNet(block, planes, deconv=deconv, delinear=delinear,
                   channel_deconv=channel_deconv, **kwargs)
    if pretrained:
        warnings.warn(
            'Pretrained weights are not loaded for {}; returning a model '
            'with default initialization.'.format(arch),
            stacklevel=2,
        )
    return model
示例13: _load_model
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load its COCO checkpoint.

    Args:
        arch_type: segmentation head type; combined with ``backbone`` to form
            the ``model_urls`` key ``'<arch_type>_<backbone>_coco'``.
        backbone: backbone name passed to ``_segm_resnet``.
        pretrained: whether to download and load the checkpoint. When true,
            the auxiliary head is always built (``aux_loss`` is forced on).
        progress: forwarded to ``load_state_dict_from_url``.
        num_classes: number of output classes passed to ``_segm_resnet``.
        aux_loss: whether to build the auxiliary head (ignored if pretrained).
        **kwargs: extra ``_segm_resnet`` arguments.

    Returns:
        The segmentation model.

    Raises:
        NotImplementedError: if pretrained weights are requested but the URL
            registered for this combination is None.
    """
    # Presumably forced on so the checkpoint's aux-head weights match the
    # strict load below.
    aux_loss = True if pretrained else aux_loss
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    if not pretrained:
        return model

    arch = '_'.join((arch_type, backbone, 'coco'))
    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    model.load_state_dict(load_state_dict_from_url(model_url, progress=progress))
    return model
示例14: _unet
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _unet(arch, pretrained, progress, **kwargs):
    """Instantiate the U-Net variant registered under ``arch``.

    ``default_cfgs[arch]`` provides the class name (``'arch'``), the layout
    handed to that class's constructor (``'layout'``) and the checkpoint URL
    (``'url'``).

    Args:
        arch: key into ``default_cfgs``.
        pretrained: if true, try to load the checkpoint from the config URL.
            A missing (None) URL only logs a warning.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra constructor arguments.

    Returns:
        The constructed model.
    """
    cfg = default_cfgs[arch]
    # Resolve the U-Net class by name among this module's globals.
    model_cls = vars(sys.modules[__name__])[cfg['arch']]
    model = model_cls(cfg['layout'], **kwargs)

    if not pretrained:
        return model

    url = cfg['url']
    if url is None:
        logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        return model

    model.load_state_dict(load_state_dict_from_url(url, progress=progress))
    return model
示例15: _yolo
# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _yolo(arch, pretrained, progress, pretrained_backbone, **kwargs):
    """Instantiate the YOLO variant registered under ``arch``.

    ``default_cfgs[arch]`` provides the class name (``'arch'``), the full-model
    checkpoint URL (``'url'``) and a ``'backbone'`` sub-config with the
    backbone ``'layout'`` and its own checkpoint ``'url'``.

    Args:
        arch: key into ``default_cfgs``.
        pretrained: if true, load the full-model checkpoint. This also disables
            the backbone-only load, since the full load below covers the whole
            model.
        progress: forwarded to ``load_state_dict_from_url``.
        pretrained_backbone: if true (and ``pretrained`` is false), load only
            the ``features.*`` entries of the backbone checkpoint into
            ``model.backbone``.
        **kwargs: extra constructor arguments.

    Returns:
        The constructed model. A missing (None) URL only logs a warning.
    """
    if pretrained:
        pretrained_backbone = False

    cfg = default_cfgs[arch]
    # Resolve the model class by name among this module's globals.
    yolo_type = sys.modules[__name__].__dict__[cfg['arch']]
    model = yolo_type(cfg['backbone']['layout'], **kwargs)

    # Load backbone pretrained parameters
    if pretrained_backbone:
        backbone_url = cfg['backbone']['url']
        if backbone_url is None:
            logging.warning(f"Invalid model URL for {arch}'s backbone, using default initialization.")
        else:
            state_dict = load_state_dict_from_url(backbone_url, progress=progress)
            # Keep only the feature-extractor entries and strip exactly the
            # leading 'features.' prefix. (str.replace would also rewrite any
            # later occurrence inside a key, and a bare startswith('features')
            # would admit unrelated keys such as 'features_x.*'.)
            prefix = 'features.'
            state_dict = {k[len(prefix):]: v
                          for k, v in state_dict.items() if k.startswith(prefix)}
            model.backbone.load_state_dict(state_dict)

    # Load pretrained parameters
    if pretrained:
        if cfg['url'] is None:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        else:
            state_dict = load_state_dict_from_url(cfg['url'], progress=progress)
            model.load_state_dict(state_dict)
    return model