当前位置: 首页>>代码示例>>Python>>正文


Python utils.load_state_dict_from_url方法代码示例

本文整理汇总了 Python 中 torchvision.models.utils.load_state_dict_from_url 方法的典型用法代码示例。如果您想了解 utils.load_state_dict_from_url 方法的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块 torchvision.models.utils 的其他用法示例。注意：较新版本的 torchvision 已移除 torchvision.models.utils 模块，可改用功能相同的 torch.hub.load_state_dict_from_url（下文示例中的导入语句需相应调整）。


在下文中一共展示了 utils.load_state_dict_from_url 方法的 15 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: inception_v3_base

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def inception_v3_base(pretrained=False, progress=True, **kwargs):
    """Build an ``Inception3Base`` model, optionally loading pretrained weights.

    Args:
        pretrained: If true, load the ``'inception_v3_google'`` checkpoint
            from ``model_urls``.
        progress: Forwarded to ``load_state_dict_from_url`` (download
            progress bar).
        **kwargs: Constructor arguments for ``Inception3Base``.

    Returns:
        The constructed ``Inception3Base`` instance.
    """
    if not pretrained:
        return Inception3Base(**kwargs)

    # Default to transformed input for the pretrained checkpoint unless the
    # caller chose explicitly.
    kwargs.setdefault('transform_input', True)

    # Remember whether the caller wanted the auxiliary head, but build with it
    # enabled (when the option was given) so the checkpoint can be loaded; the
    # head is removed afterwards if it was not wanted.
    keep_aux = kwargs.get('aux_logits', True)
    if 'aux_logits' in kwargs:
        kwargs['aux_logits'] = True

    model = Inception3Base(**kwargs)
    weights = load_state_dict_from_url(model_urls['inception_v3_google'],
                                       progress=progress)
    model.load_state_dict(weights)

    if not keep_aux:
        model.aux_logits = False
        del model.AuxLogits
    return model
开发者ID:krasserm,项目名称:fairseq-image-captioning,代码行数:21,代码来源:inception.py

示例2: _darknet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _darknet(arch, pretrained, progress, **kwargs):
    """Instantiate the Darknet variant registered under ``arch``.

    Args:
        arch: Key into ``default_cfgs``. Its ``'arch'`` entry names a class
            defined in this module; its ``'layout'`` entry is passed to it.
        pretrained: If true, load weights from ``default_cfgs[arch]['url']``.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra constructor arguments for the Darknet class.

    Returns:
        The model. If weights were requested but no URL is configured, a
        warning is logged and the default initialization is kept.
    """
    arch_cfg = default_cfgs[arch]
    # Look the class up by name among this module's globals.
    darknet_cls = vars(sys.modules[__name__])[arch_cfg['arch']]
    model = darknet_cls(arch_cfg['layout'], **kwargs)

    if not pretrained:
        return model

    weights_url = arch_cfg['url']
    if weights_url is None:
        logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        return model

    model.load_state_dict(load_state_dict_from_url(weights_url, progress=progress))
    return model
开发者ID:frgfm,项目名称:Holocron,代码行数:18,代码来源:darknet.py

示例3: fid_inception_v3

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def fid_inception_v3():
    """Build the pretrained Inception network used to compute FID.

    The FID reference network shares torchvision's Inception v3 skeleton but
    uses different weights and several modified blocks. This function starts
    from ``_inception_v3`` (1008 classes, no auxiliary head, no weights), swaps
    in the FID-specific ``Mixed_*`` blocks, and then loads the checkpoint at
    ``FID_WEIGHTS_URL``.

    Returns:
        The patched Inception model with FID weights loaded.
    """
    net = _inception_v3(num_classes=1008,
                        aux_logits=False,
                        pretrained=False)

    # (attribute name, replacement block), installed in this order.
    patched_blocks = (
        ('Mixed_5b', FIDInceptionA(192, pool_features=32)),
        ('Mixed_5c', FIDInceptionA(256, pool_features=64)),
        ('Mixed_5d', FIDInceptionA(288, pool_features=64)),
        ('Mixed_6b', FIDInceptionC(768, channels_7x7=128)),
        ('Mixed_6c', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6d', FIDInceptionC(768, channels_7x7=160)),
        ('Mixed_6e', FIDInceptionC(768, channels_7x7=192)),
        ('Mixed_7b', FIDInceptionE_1(1280)),
        ('Mixed_7c', FIDInceptionE_2(2048)),
    )
    for attr_name, module in patched_blocks:
        setattr(net, attr_name, module)

    net.load_state_dict(load_state_dict_from_url(FID_WEIGHTS_URL, progress=True))
    return net
开发者ID:mseitzer,项目名称:pytorch-fid,代码行数:27,代码来源:inception.py

示例4: _resnet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load the checkpoint for ``arch``.

    Args:
        arch: Key into ``model_urls`` (only read when ``pretrained`` is true).
        block: Residual block class passed to ``ResNet``.
        layers: Per-stage block counts passed to ``ResNet``.
        pretrained: If true, load weights from ``model_urls[arch]``.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra ``ResNet`` constructor arguments.

    Returns:
        The ``ResNet`` instance.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(checkpoint)
    return net
开发者ID:legolas123,项目名称:cv-tricks.com,代码行数:9,代码来源:resnet_preact_bin.py

示例5: _shufflenetv2

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load the checkpoint for ``arch``.

    Args:
        arch: Key into ``model_urls`` (only read when ``pretrained`` is true).
        pretrained: If true, load weights from ``model_urls[arch]``.
        progress: Forwarded to ``load_state_dict_from_url``.
        *args, **kwargs: ``ShuffleNetV2`` constructor arguments.

    Returns:
        The ``ShuffleNetV2`` instance.

    Raises:
        NotImplementedError: ``pretrained`` is true but no URL is registered
            for ``arch``.
    """
    net = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return net

    url = model_urls[arch]
    if url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    # Non-strict load: keys missing from, or unexpected in, the checkpoint
    # are tolerated rather than raising.
    net.load_state_dict(load_state_dict_from_url(url, progress=progress), strict=False)
    return net
开发者ID:SURFZJY,项目名称:Real-time-Text-Detection,代码行数:14,代码来源:shufflenetv2.py

示例6: _resnet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load the checkpoint for ``arch``.

    Args:
        arch: Key into ``model_urls`` (only read when ``pretrained`` is true).
        block: Residual block class passed to ``ResNet``.
        layers: Per-stage block counts passed to ``ResNet``.
        pretrained: If true, load weights from ``model_urls[arch]``
            non-strictly and print a confirmation line.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra ``ResNet`` constructor arguments.

    Returns:
        The ``ResNet`` instance.
    """
    net = ResNet(block, layers, **kwargs)
    if pretrained:
        url = model_urls[arch]
        weights = load_state_dict_from_url(url, progress=progress)
        # Non-strict: missing/unexpected keys (e.g. extra heads) are tolerated.
        net.load_state_dict(weights, strict=False)
        print('load pretrained models from imagenet')
    return net
开发者ID:SURFZJY,项目名称:Real-time-Text-Detection,代码行数:10,代码来源:resnet.py

示例7: _resnet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, pretrained=None, **kwargs):
    """Build a ``TResNet`` for ``arch``, optionally loading a pretrained checkpoint.

    Settings come from a deep copy of ``CFGS``. The ``"default"`` entry of
    ``CFGS[arch]`` is always used. When ``pretrained`` names another entry,
    that entry's settings and ``"params"`` are merged on top. The merged
    ``"params"`` then override same-named keyword arguments, and a warning
    lists the overridden names.

    Args:
        arch: Key into ``CFGS``.
        pretrained: Optional key of a pretrained entry under ``CFGS[arch]``.
            A falsy value means no weights are loaded.
        **kwargs: ``TResNet`` constructor arguments. ``num_classes`` and
            ``in_channels`` also control how the checkpoint is adapted.

    Returns:
        The model, with the merged settings attached as
        ``model.pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        extra = arch_cfgs[pretrained]
        extra_params = extra.pop("params", {})
        settings.update(extra)
        params.update(extra_params)

    overridden = set(params.keys()).intersection(set(kwargs.keys()))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = TResNet(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"], check_hash=True)

        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                f"Using model pretrained for {settings['num_classes']} classes "
                f"with {requested_classes} classes. Last layer is initialized randomly"
            )
            # The checkpoint head has a different output size; keep the
            # model's own freshly initialized head parameters instead.
            own_state = model.state_dict()
            for key in ("last_linear.weight", "last_linear.bias"):
                weights[key] = own_state[key]

        in_channels = kwargs.get("in_channels", 3)
        if in_channels != 3:
            # Adapt the 3-channel stem weights to the requested channel count.
            # The factor 16 presumably reflects a space-to-depth stem — TODO confirm.
            weights["conv1.1.weight"] = repeat_channels(
                weights["conv1.1.weight"], in_channels * 16, 3 * 16
            )

        model.load_state_dict(weights)
        patch_bn(model)

    model.pretrained_settings = settings
    return model
开发者ID:bonlime,项目名称:pytorch-tools,代码行数:38,代码来源:tresnet.py

示例8: _vgg

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _vgg(arch, pretrained=None, **kwargs):
    """Build a ``VGG`` for ``arch``, optionally loading a pretrained checkpoint.

    Settings come from a deep copy of ``CFGS``. The ``"default"`` entry of
    ``CFGS[arch]`` is always used. When ``pretrained`` names another entry,
    that entry's settings and ``"params"`` are merged on top. The merged
    ``"params"`` then override same-named keyword arguments, and a warning
    lists the overridden names.

    Args:
        arch: Key into ``CFGS``.
        pretrained: Optional key of a pretrained entry under ``CFGS[arch]``.
            A falsy value means no weights are loaded.
        **kwargs: ``VGG`` constructor arguments. ``num_classes`` also decides
            whether the checkpoint's final classifier layer is kept.

    Returns:
        The model, with the merged settings attached as
        ``model.pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        extra = arch_cfgs[pretrained]
        extra_params = extra.pop("params", {})
        settings.update(extra)
        params.update(extra_params)

    overridden = set(params.keys()).intersection(set(kwargs.keys()))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = VGG(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"])

        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                f"Using model pretrained for {settings['num_classes']} classes "
                f"with {requested_classes} classes. Last layer is initialized randomly"
            )
            # The checkpoint's final classifier layer ("classifier.6") has a
            # different output size; keep the model's own fresh parameters.
            own_state = model.state_dict()
            for key in ("classifier.6.weight", "classifier.6.bias"):
                weights[key] = own_state[key]

        model.load_state_dict(weights)

    model.pretrained_settings = settings
    return model
开发者ID:bonlime,项目名称:pytorch-tools,代码行数:33,代码来源:vgg.py

示例9: _efficientnet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _efficientnet(arch, pretrained=None, **kwargs):
    """Build an ``EfficientNet`` for ``arch``, optionally loading a pretrained checkpoint.

    Settings come from a deep copy of ``CFGS``. The ``"default"`` entry of
    ``CFGS[arch]`` is always used, and its ``"blocks_args"`` are decoded with
    ``decode_block_args``. When ``pretrained`` names another entry, that
    entry's settings and ``"params"`` are merged on top. The merged
    ``"params"`` then override same-named keyword arguments, and a warning
    lists the overridden names.

    Args:
        arch: Key into ``CFGS``.
        pretrained: Optional key of a pretrained entry under ``CFGS[arch]``.
            A falsy value means no weights are loaded.
        **kwargs: ``EfficientNet`` constructor arguments. ``num_classes`` and
            ``in_channels`` also control how the checkpoint is adapted.

    Returns:
        The model, with the merged settings attached as
        ``model.pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    params["blocks_args"] = decode_block_args(params["blocks_args"])
    if pretrained:
        extra = arch_cfgs[pretrained]
        extra_params = extra.pop("params", {})
        settings.update(extra)
        params.update(extra_params)

    overridden = set(params.keys()).intersection(set(kwargs.keys()))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = EfficientNet(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"])

        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                f"Using model pretrained for {settings['num_classes']} classes "
                f"with {requested_classes} classes. Last layer is initialized randomly"
            )
            # The checkpoint's classifier has a different output size; keep
            # the model's own freshly initialized classifier parameters.
            own_state = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                weights[key] = own_state[key]

        in_channels = kwargs.get("in_channels", 3)
        if in_channels != 3:
            # Adapt the 3-channel stem convolution to the requested channel count.
            weights["conv_stem.weight"] = repeat_channels(
                weights["conv_stem.weight"], in_channels
            )

        model.load_state_dict(weights)

    model.pretrained_settings = settings
    return model
开发者ID:bonlime,项目名称:pytorch-tools,代码行数:37,代码来源:efficientnet.py

示例10: _densenet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _densenet(arch, pretrained=None, **kwargs):
    """Build a ``DenseNet`` for ``arch``, optionally loading a pretrained checkpoint.

    Settings come from a deep copy of ``CFGS``. The ``"default"`` entry of
    ``CFGS[arch]`` is always used. When ``pretrained`` names another entry,
    that entry's settings and ``"params"`` are merged on top. The merged
    ``"params"`` then override same-named keyword arguments, and a warning
    lists the overridden names.

    Args:
        arch: Key into ``CFGS``.
        pretrained: Optional key of a pretrained entry under ``CFGS[arch]``.
            A falsy value means no weights are loaded.
        **kwargs: ``DenseNet`` constructor arguments. ``num_classes`` also
            decides whether the checkpoint's classifier is kept.

    Returns:
        The model, with the merged settings attached as
        ``model.pretrained_settings``.
    """
    arch_cfgs = deepcopy(CFGS)[arch]
    settings = arch_cfgs["default"]
    params = settings.pop("params")
    if pretrained:
        extra = arch_cfgs[pretrained]
        extra_params = extra.pop("params", {})
        settings.update(extra)
        params.update(extra_params)

    overridden = set(params.keys()).intersection(set(kwargs.keys()))
    if overridden:
        logging.warning(
            f"Args {overridden} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = DenseNet(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"])

        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                f"Using model pretrained for {settings['num_classes']} classes "
                f"with {requested_classes} classes. Last layer is initialized randomly"
            )
            # The checkpoint's classifier has a different output size; keep
            # the model's own freshly initialized classifier parameters.
            own_state = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                weights[key] = own_state[key]

        model.load_state_dict(weights)

    model.pretrained_settings = settings
    return model
开发者ID:bonlime,项目名称:pytorch-tools,代码行数:36,代码来源:densenet.py

示例11: _iresnet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build an ``IResNet`` and optionally load the checkpoint for ``arch``.

    Args:
        arch: Key into ``model_urls`` (only read when ``pretrained`` is true).
        block: Residual block class passed to ``IResNet``.
        layers: Per-stage block counts passed to ``IResNet``.
        pretrained: If true, load weights from ``model_urls[arch]``.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra ``IResNet`` constructor arguments.

    Returns:
        The ``IResNet`` instance.
    """
    net = IResNet(block, layers, **kwargs)
    if pretrained:
        checkpoint_url = model_urls[arch]
        net.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
    return net
开发者ID:nizhib,项目名称:pytorch-insightface,代码行数:9,代码来源:iresnet.py

示例12: _resnet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _resnet(arch, block, planes, pretrained, progress, deconv, delinear, channel_deconv, **kwargs):
    """Build a ``ResNet`` with the deconvolution options forwarded.

    This builder does not load pretrained weights. The weight-loading code was
    disabled upstream, where it was left behind as an unused string literal.
    ``arch`` and ``progress`` are accepted only so the signature matches the
    other ``_resnet`` builders.

    Args:
        arch: Architecture name; only used in the warning message.
        block: Residual block class passed to ``ResNet``.
        planes: Per-stage configuration passed to ``ResNet``.
        pretrained: If true, a ``UserWarning`` is emitted, because no weights
            are loaded and callers should not silently get an untrained model.
        progress: Unused.
        deconv, delinear, channel_deconv: Forwarded to ``ResNet`` as keywords.
        **kwargs: Extra ``ResNet`` constructor arguments.

    Returns:
        A freshly constructed ``ResNet``; no checkpoint is loaded.
    """
    import warnings

    model = ResNet(block, planes, deconv=deconv, delinear=delinear,
                   channel_deconv=channel_deconv, **kwargs)
    if pretrained:
        warnings.warn(
            "pretrained weights are not supported for {}; "
            "returning a model without loaded weights".format(arch),
            UserWarning,
            stacklevel=2,
        )
    return model
开发者ID:yechengxi,项目名称:deconvolution,代码行数:11,代码来源:resnet_imagenet.py

示例13: _load_model

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load its COCO checkpoint.

    Args:
        arch_type: Segmentation head type; combined with ``backbone`` into the
            ``model_urls`` key ``'<arch_type>_<backbone>_coco'``.
        backbone: Backbone name passed to ``_segm_resnet``.
        pretrained: If true, the auxiliary loss head is forced on and weights
            are loaded from ``model_urls``.
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: Number of output classes.
        aux_loss: Whether to build the auxiliary head. It is overridden to
            ``True`` when ``pretrained`` is set (presumably because the
            checkpoints contain that head — TODO confirm).
        **kwargs: Extra arguments for ``_segm_resnet``.

    Returns:
        The segmentation model.

    Raises:
        NotImplementedError: ``pretrained`` is true but no URL is registered.
    """
    use_aux = True if pretrained else aux_loss
    model = _segm_resnet(arch_type, backbone, num_classes, use_aux, **kwargs)
    if not pretrained:
        return model

    arch = '_'.join((arch_type, backbone, 'coco'))
    url = model_urls[arch]
    if url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    model.load_state_dict(load_state_dict_from_url(url, progress=progress))
    return model
开发者ID:yechengxi,项目名称:deconvolution,代码行数:15,代码来源:segmentation.py

示例14: _unet

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _unet(arch, pretrained, progress, **kwargs):
    """Instantiate the U-Net variant registered under ``arch``.

    Args:
        arch: Key into ``default_cfgs``. Its ``'arch'`` entry names a class
            defined in this module; its ``'layout'`` entry is passed to it.
        pretrained: If true, load weights from ``default_cfgs[arch]['url']``.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra constructor arguments for the U-Net class.

    Returns:
        The model. If weights were requested but no URL is configured, a
        warning is logged and the default initialization is kept.
    """
    arch_cfg = default_cfgs[arch]
    # The 'arch' entry names a U-Net class defined in this module.
    unet_cls = sys.modules[__name__].__dict__[arch_cfg['arch']]
    model = unet_cls(arch_cfg['layout'], **kwargs)

    if pretrained:
        weights_url = arch_cfg['url']
        if weights_url is not None:
            model.load_state_dict(load_state_dict_from_url(weights_url, progress=progress))
        else:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")

    return model
开发者ID:frgfm,项目名称:Holocron,代码行数:17,代码来源:unet.py

示例15: _yolo

# 需要导入模块: from torchvision.models import utils [as 别名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 别名]
def _yolo(arch, pretrained, progress, pretrained_backbone, **kwargs):
    """Instantiate the YOLO variant registered under ``arch``.

    Args:
        arch: Key into ``default_cfgs``. ``'arch'`` names a class defined in
            this module, ``['backbone']['layout']`` is passed to it,
            ``['backbone']['url']`` points to backbone-only weights and
            ``'url'`` to full-model weights.
        pretrained: If true, load full-model weights. Backbone-only loading is
            then skipped, since a strict full-model load replaces those
            parameters anyway.
        progress: Forwarded to ``load_state_dict_from_url``.
        pretrained_backbone: If true (and ``pretrained`` is false), load only
            the ``'features.'`` entries of the backbone checkpoint into
            ``model.backbone``.
        **kwargs: Extra constructor arguments for the YOLO class.

    Returns:
        The model. A missing URL logs a warning and keeps the default
        initialization for the corresponding part.
    """
    if pretrained:
        pretrained_backbone = False

    arch_cfg = default_cfgs[arch]
    # The 'arch' entry names a YOLO class defined in this module.
    yolo_type = sys.modules[__name__].__dict__[arch_cfg['arch']]
    model = yolo_type(arch_cfg['backbone']['layout'], **kwargs)

    if pretrained_backbone:
        backbone_url = arch_cfg['backbone']['url']
        if backbone_url is None:
            logging.warning(f"Invalid model URL for {arch}'s backbone, using default initialization.")
        else:
            state_dict = load_state_dict_from_url(backbone_url, progress=progress)
            # Keep only entries under the exact 'features.' prefix and strip
            # that prefix once. The previous startswith('features') +
            # replace('features.', '') also matched keys such as 'featuresX.*'
            # and removed every inner 'features.' occurrence, not only the
            # leading one.
            prefix = 'features.'
            state_dict = {k[len(prefix):]: v
                          for k, v in state_dict.items() if k.startswith(prefix)}
            model.backbone.load_state_dict(state_dict)

    if pretrained:
        model_url = arch_cfg['url']
        if model_url is None:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        else:
            state_dict = load_state_dict_from_url(model_url, progress=progress)
            model.load_state_dict(state_dict)

    return model
开发者ID:frgfm,项目名称:Holocron,代码行数:31,代码来源:yolo.py


注:本文中的torchvision.models.utils.load_state_dict_from_url方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。