當前位置: 首頁>>代碼示例>>Python>>正文


Python utils.load_state_dict_from_url方法代碼示例

本文整理匯總了Python中torchvision.models.utils.load_state_dict_from_url方法的典型用法代碼示例。如果您正苦於以下問題:Python utils.load_state_dict_from_url方法的具體用法?Python utils.load_state_dict_from_url怎麼用?Python utils.load_state_dict_from_url使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torchvision.models.utils的用法示例。


在下文中一共展示了utils.load_state_dict_from_url方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: inception_v3_base

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def inception_v3_base(pretrained=False, progress=True, **kwargs):
    """Build an ``Inception3Base`` model, optionally with pretrained weights.

    Args:
        pretrained: When truthy, weights are fetched from
            ``model_urls['inception_v3_google']`` and loaded into the model.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Forwarded to the ``Inception3Base`` constructor. On the
            pretrained path ``transform_input`` defaults to ``True`` and an
            explicitly passed ``aux_logits`` is overridden with ``True`` for
            construction.

    Returns:
        The constructed ``Inception3Base`` instance.
    """
    if not pretrained:
        return Inception3Base(**kwargs)

    kwargs.setdefault('transform_input', True)

    # Remember what the caller asked for. The model is built with
    # aux_logits=True (presumably so the checkpoint's auxiliary-head weights
    # have a destination) and the head is removed afterwards if unwanted.
    keep_aux_head = True
    if 'aux_logits' in kwargs:
        keep_aux_head = kwargs['aux_logits']
        kwargs['aux_logits'] = True

    net = Inception3Base(**kwargs)
    weights = load_state_dict_from_url(model_urls['inception_v3_google'],
                                       progress=progress)
    net.load_state_dict(weights)

    if not keep_aux_head:
        net.aux_logits = False
        del net.AuxLogits
    return net
開發者ID:krasserm,項目名稱:fairseq-image-captioning,代碼行數:21,代碼來源:inception.py

示例2: _darknet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _darknet(arch, pretrained, progress, **kwargs):
    """Instantiate the model described by ``default_cfgs[arch]``.

    Args:
        arch: Key into ``default_cfgs``.
        pretrained: When truthy, try to load weights from the configured URL.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra arguments forwarded to the model constructor.

    Returns:
        The constructed model. If ``pretrained`` is requested but the
        configured URL is ``None``, a warning is logged and the model is
        returned without loading any weights.
    """
    arch_cfg = default_cfgs[arch]
    # The config stores the model class by name; resolve it in this module.
    model_cls = vars(sys.modules[__name__])[arch_cfg['arch']]
    net = model_cls(arch_cfg['layout'], **kwargs)

    if not pretrained:
        return net

    weights_url = arch_cfg['url']
    if weights_url is None:
        logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        return net

    weights = load_state_dict_from_url(weights_url, progress=progress)
    net.load_state_dict(weights)
    return net
開發者ID:frgfm,項目名稱:Holocron,代碼行數:18,代碼來源:darknet.py

示例3: fid_inception_v3

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def fid_inception_v3():
    """Return the pretrained Inception network used for FID computation.

    The FID variant relies on its own weights and differs slightly in
    structure from torchvision's Inception. The model is therefore created
    through ``_inception_v3`` first, the blocks that differ are swapped for
    their FID-specific counterparts, and finally the weights from
    ``FID_WEIGHTS_URL`` are loaded.
    """
    net = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    # (attribute name, replacement block class, input channels, extra kwargs)
    replacements = (
        ('Mixed_5b', FIDInceptionA, 192, {'pool_features': 32}),
        ('Mixed_5c', FIDInceptionA, 256, {'pool_features': 64}),
        ('Mixed_5d', FIDInceptionA, 288, {'pool_features': 64}),
        ('Mixed_6b', FIDInceptionC, 768, {'channels_7x7': 128}),
        ('Mixed_6c', FIDInceptionC, 768, {'channels_7x7': 160}),
        ('Mixed_6d', FIDInceptionC, 768, {'channels_7x7': 160}),
        ('Mixed_6e', FIDInceptionC, 768, {'channels_7x7': 192}),
        ('Mixed_7b', FIDInceptionE_1, 1280, {}),
        ('Mixed_7c', FIDInceptionE_2, 2048, {}),
    )
    for attr_name, block_cls, in_channels, extra in replacements:
        setattr(net, attr_name, block_cls(in_channels, **extra))

    weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    net.load_state_dict(weights)
    return net
開發者ID:mseitzer,項目名稱:pytorch-fid,代碼行數:27,代碼來源:inception.py

示例4: _resnet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load the weights registered for ``arch``.

    Args:
        arch: Key into ``model_urls`` used when ``pretrained`` is truthy.
        block: Forwarded to ``ResNet`` as its first positional argument.
        layers: Forwarded to ``ResNet`` as its second positional argument.
        pretrained: When truthy, download and load the weights.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra arguments forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
開發者ID:legolas123,項目名稱:cv-tricks.com,代碼行數:9,代碼來源:resnet_preact_bin.py

示例5: _shufflenetv2

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """Build a ``ShuffleNetV2`` and optionally load weights for ``arch``.

    Args:
        arch: Key into ``model_urls`` used when ``pretrained`` is truthy.
        pretrained: When truthy, download and load the weights.
        progress: Forwarded to ``load_state_dict_from_url``.
        *args: Positional arguments forwarded to ``ShuffleNetV2``.
        **kwargs: Keyword arguments forwarded to ``ShuffleNetV2``.

    Returns:
        The constructed ``ShuffleNetV2`` instance.

    Raises:
        NotImplementedError: If ``pretrained`` is truthy but the URL
            registered for ``arch`` is ``None``.
    """
    net = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return net

    weights_url = model_urls[arch]
    if weights_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    weights = load_state_dict_from_url(weights_url, progress=progress)
    # Loaded in non-strict mode, as in the original implementation.
    net.load_state_dict(weights, strict=False)
    return net
開發者ID:SURFZJY,項目名稱:Real-time-Text-Detection,代碼行數:14,代碼來源:shufflenetv2.py

示例6: _resnet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load weights for ``arch`` non-strictly.

    Args:
        arch: Key into ``model_urls`` used when ``pretrained`` is truthy.
        block: Forwarded to ``ResNet`` as its first positional argument.
        layers: Forwarded to ``ResNet`` as its second positional argument.
        pretrained: When truthy, download and load the weights, then print a
            confirmation message.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra arguments forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    # Loaded in non-strict mode, as in the original implementation.
    net.load_state_dict(weights, strict=False)
    print('load pretrained models from imagenet')
    return net
開發者ID:SURFZJY,項目名稱:Real-time-Text-Detection,代碼行數:10,代碼來源:resnet.py

示例7: _resnet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _resnet(arch, pretrained=None, **kwargs):
    """Build a ``TResNet`` for ``arch``, optionally initialised from pretrained weights.

    Args:
        arch: Key into ``CFGS`` selecting the architecture entry.
        pretrained: Key of a weight set inside ``CFGS[arch]``; a falsy value
            skips weight loading.
        **kwargs: Constructor arguments for ``TResNet``. Keys that also appear
            in the configured params are overwritten, and a warning is logged.

    Returns:
        The model, with the merged settings dict attached as
        ``pretrained_settings``.
    """
    # Work on a copy because the settings dicts are popped from and updated.
    arch_cfgs = deepcopy(CFGS)
    settings = arch_cfgs[arch]["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[arch][pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    clashing = set(params) & set(kwargs)
    if clashing:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = TResNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[arch][pretrained]["url"], check_hash=True)
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # Replace (or add) the head entries with the freshly built model's
            # own values so the checkpoint head is not used.
            fresh_weights = model.state_dict()
            for key in ("last_linear.weight", "last_linear.bias"):
                state_dict[key] = fresh_weights[key]
        if kwargs.get("in_channels", 3) != 3:
            # Adapt the first conv weight for a custom number of input channels.
            state_dict["conv1.1.weight"] = repeat_channels(
                state_dict["conv1.1.weight"], kwargs["in_channels"] * 16, 3 * 16
            )
        model.load_state_dict(state_dict)
        patch_bn(model)

    model.pretrained_settings = settings
    return model
開發者ID:bonlime,項目名稱:pytorch-tools,代碼行數:38,代碼來源:tresnet.py

示例8: _vgg

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _vgg(arch, pretrained=None, **kwargs):
    """Build a ``VGG`` for ``arch``, optionally initialised from pretrained weights.

    Args:
        arch: Key into ``CFGS`` selecting the architecture entry.
        pretrained: Key of a weight set inside ``CFGS[arch]``; a falsy value
            skips weight loading.
        **kwargs: Constructor arguments for ``VGG``. Keys that also appear in
            the configured params are overwritten, and a warning is logged.

    Returns:
        The model, with the merged settings dict attached as
        ``pretrained_settings``.
    """
    # Work on a copy because the settings dicts are popped from and updated.
    arch_cfgs = deepcopy(CFGS)
    settings = arch_cfgs[arch]["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[arch][pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    clashing = set(params) & set(kwargs)
    if clashing:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = VGG(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[arch][pretrained]["url"])
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # Replace the final classifier entries with the freshly built
            # model's own values so the checkpoint head is not used.
            fresh_weights = model.state_dict()
            for key in ("classifier.6.weight", "classifier.6.bias"):
                state_dict[key] = fresh_weights[key]
        model.load_state_dict(state_dict)

    model.pretrained_settings = settings
    return model
開發者ID:bonlime,項目名稱:pytorch-tools,代碼行數:33,代碼來源:vgg.py

示例9: _efficientnet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _efficientnet(arch, pretrained=None, **kwargs):
    """Build an ``EfficientNet`` for ``arch``, optionally from pretrained weights.

    Args:
        arch: Key into ``CFGS`` selecting the architecture entry.
        pretrained: Key of a weight set inside ``CFGS[arch]``; a falsy value
            skips weight loading.
        **kwargs: Constructor arguments for ``EfficientNet``. Keys that also
            appear in the configured params are overwritten, and a warning is
            logged.

    Returns:
        The model, with the merged settings dict attached as
        ``pretrained_settings``.
    """
    # Work on a copy because the settings dicts are popped from and updated.
    arch_cfgs = deepcopy(CFGS)
    settings = arch_cfgs[arch]["default"]
    params = settings.pop("params")
    # The configured block arguments are decoded before any override is merged.
    params["blocks_args"] = decode_block_args(params["blocks_args"])
    if pretrained:
        weights_cfg = arch_cfgs[arch][pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    clashing = set(params) & set(kwargs)
    if clashing:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = EfficientNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[arch][pretrained]["url"])
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # Replace the classifier entries with the freshly built model's
            # own values so the checkpoint head is not used.
            fresh_weights = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                state_dict[key] = fresh_weights[key]
        if kwargs.get("in_channels", 3) != 3:
            # Adapt the stem conv weight for a custom number of input channels.
            state_dict["conv_stem.weight"] = repeat_channels(
                state_dict["conv_stem.weight"], kwargs["in_channels"]
            )
        model.load_state_dict(state_dict)

    model.pretrained_settings = settings
    return model
開發者ID:bonlime,項目名稱:pytorch-tools,代碼行數:37,代碼來源:efficientnet.py

示例10: _densenet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _densenet(arch, pretrained=None, **kwargs):
    """Build a ``DenseNet`` for ``arch``, optionally from pretrained weights.

    Args:
        arch: Key into ``CFGS`` selecting the architecture entry.
        pretrained: Key of a weight set inside ``CFGS[arch]``; a falsy value
            skips weight loading.
        **kwargs: Constructor arguments for ``DenseNet``. Keys that also
            appear in the configured params are overwritten, and a warning is
            logged.

    Returns:
        The model, with the merged settings dict attached as
        ``pretrained_settings``.
    """
    # Work on a copy because the settings dicts are popped from and updated.
    arch_cfgs = deepcopy(CFGS)
    settings = arch_cfgs[arch]["default"]
    params = settings.pop("params")
    if pretrained:
        weights_cfg = arch_cfgs[arch][pretrained]
        weights_params = weights_cfg.pop("params", {})
        settings.update(weights_cfg)
        params.update(weights_params)

    clashing = set(params) & set(kwargs)
    if clashing:
        logging.warning(
            f"Args {clashing} are going to be overwritten by default params for {pretrained} weights"
        )
    kwargs.update(params)
    model = DenseNet(**kwargs)

    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfgs[arch][pretrained]["url"])
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != settings["num_classes"]:
            logging.warning(
                "Using model pretrained for {} classes with {} classes. Last layer is initialized randomly".format(
                    settings["num_classes"], requested_classes
                )
            )
            # Replace the classifier entries with the freshly built model's
            # own values so the checkpoint head is not used.
            fresh_weights = model.state_dict()
            for key in ("classifier.weight", "classifier.bias"):
                state_dict[key] = fresh_weights[key]
        model.load_state_dict(state_dict)

    model.pretrained_settings = settings
    return model
開發者ID:bonlime,項目名稱:pytorch-tools,代碼行數:36,代碼來源:densenet.py

示例11: _iresnet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build an ``IResNet`` and optionally load the weights registered for ``arch``.

    Args:
        arch: Key into ``model_urls`` used when ``pretrained`` is truthy.
        block: Forwarded to ``IResNet`` as its first positional argument.
        layers: Forwarded to ``IResNet`` as its second positional argument.
        pretrained: When truthy, download and load the weights.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra arguments forwarded to ``IResNet``.

    Returns:
        The constructed ``IResNet`` instance.
    """
    net = IResNet(block, layers, **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net
開發者ID:nizhib,項目名稱:pytorch-insightface,代碼行數:9,代碼來源:iresnet.py

示例12: _resnet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _resnet(arch, block, planes, pretrained, progress, deconv, delinear, channel_deconv, **kwargs):
    """Build a ``ResNet`` configured with the given deconvolution options.

    Pretrained-weight loading is intentionally disabled for this variant (the
    original code had the loading logic commented out), so no weights are
    downloaded. ``arch`` and ``progress`` are kept only for signature
    compatibility with callers.

    Args:
        arch: Architecture name; only used in the warning message.
        block: Forwarded to ``ResNet`` as its first positional argument.
        planes: Forwarded to ``ResNet`` as its second positional argument.
        pretrained: If truthy, a warning is emitted because the request
            cannot be honoured; the model keeps its default initialization.
        progress: Unused.
        deconv: Forwarded to ``ResNet`` as the ``deconv`` keyword.
        delinear: Forwarded to ``ResNet`` as the ``delinear`` keyword.
        channel_deconv: Forwarded to ``ResNet`` as ``channel_deconv``.
        **kwargs: Extra arguments forwarded to ``ResNet``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    model = ResNet(block, planes, deconv=deconv, delinear=delinear,
                   channel_deconv=channel_deconv, **kwargs)
    if pretrained:
        # Previously the flag was ignored silently; tell the caller instead of
        # letting them assume pretrained weights were loaded.
        import warnings

        warnings.warn(
            "pretrained weights are not loaded for {}; "
            "using default initialization".format(arch),
            stacklevel=2,
        )
    return model
開發者ID:yechengxi,項目名稱:deconvolution,代碼行數:11,代碼來源:resnet_imagenet.py

示例13: _load_model

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load its pretrained weights.

    Args:
        arch_type: Architecture family, forwarded to ``_segm_resnet`` and used
            as the first part of the ``model_urls`` key.
        backbone: Backbone name, forwarded to ``_segm_resnet`` and used as the
            second part of the ``model_urls`` key.
        pretrained: When truthy, ``aux_loss`` is forced to ``True`` and the
            weights registered under ``<arch_type>_<backbone>_coco`` are loaded.
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: Forwarded to ``_segm_resnet``.
        aux_loss: Forwarded to ``_segm_resnet`` unless overridden (see above).
        **kwargs: Extra arguments forwarded to ``_segm_resnet``.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If ``pretrained`` is truthy but the registered
            URL is ``None``.
    """
    # The pretrained path always builds the model with aux_loss enabled.
    aux_loss = True if pretrained else aux_loss
    net = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    if not pretrained:
        return net

    arch = arch_type + '_' + backbone + '_coco'
    weights_url = model_urls[arch]
    if weights_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    weights = load_state_dict_from_url(weights_url, progress=progress)
    net.load_state_dict(weights)
    return net
開發者ID:yechengxi,項目名稱:deconvolution,代碼行數:15,代碼來源:segmentation.py

示例14: _unet

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _unet(arch, pretrained, progress, **kwargs):
    """Instantiate the model described by ``default_cfgs[arch]``.

    Args:
        arch: Key into ``default_cfgs``.
        pretrained: When truthy, try to load weights from the configured URL.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Extra arguments forwarded to the model constructor.

    Returns:
        The constructed model. If ``pretrained`` is requested but the
        configured URL is ``None``, a warning is logged and the model is
        returned without loading any weights.
    """
    arch_cfg = default_cfgs[arch]
    # The config stores the model class by name; resolve it in this module.
    model_cls = vars(sys.modules[__name__])[arch_cfg['arch']]
    net = model_cls(arch_cfg['layout'], **kwargs)

    if not pretrained:
        return net

    weights_url = arch_cfg['url']
    if weights_url is None:
        logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        return net

    weights = load_state_dict_from_url(weights_url, progress=progress)
    net.load_state_dict(weights)
    return net
開發者ID:frgfm,項目名稱:Holocron,代碼行數:17,代碼來源:unet.py

示例15: _yolo

# 需要導入模塊: from torchvision.models import utils [as 別名]
# 或者: from torchvision.models.utils import load_state_dict_from_url [as 別名]
def _yolo(arch, pretrained, progress, pretrained_backbone, **kwargs):
    """Instantiate the detection model described by ``default_cfgs[arch]``.

    Args:
        arch: Key into ``default_cfgs``.
        pretrained: When truthy, try to load full-model weights from the
            configured URL; this also disables backbone-only loading.
        progress: Forwarded to ``load_state_dict_from_url``.
        pretrained_backbone: When truthy (and ``pretrained`` is falsy), try to
            load weights into ``model.backbone`` from the backbone URL.
        **kwargs: Extra arguments forwarded to the model constructor.

    Returns:
        The constructed model. A missing (``None``) URL only produces a
        logged warning; the corresponding weights are then left untouched.
    """
    # Full-model weights take precedence over backbone-only weights.
    if pretrained:
        pretrained_backbone = False

    arch_cfg = default_cfgs[arch]
    # The config stores the model class by name; resolve it in this module.
    model_cls = vars(sys.modules[__name__])[arch_cfg['arch']]
    backbone_cfg = arch_cfg['backbone']
    net = model_cls(backbone_cfg['layout'], **kwargs)

    if pretrained_backbone:
        backbone_url = backbone_cfg['url']
        if backbone_url is None:
            logging.warning(f"Invalid model URL for {arch}'s backbone, using default initialization.")
        else:
            checkpoint = load_state_dict_from_url(backbone_url, progress=progress)
            # Keep only entries whose key starts with 'features' and drop the
            # 'features.' fragment so the keys match the backbone module.
            backbone_weights = {}
            for key, tensor in checkpoint.items():
                if key.startswith('features'):
                    backbone_weights[key.replace('features.', '')] = tensor
            net.backbone.load_state_dict(backbone_weights)

    if pretrained:
        model_url = arch_cfg['url']
        if model_url is None:
            logging.warning(f"Invalid model URL for {arch}, using default initialization.")
        else:
            weights = load_state_dict_from_url(model_url, progress=progress)
            net.load_state_dict(weights)

    return net
開發者ID:frgfm,項目名稱:Holocron,代碼行數:31,代碼來源:yolo.py


注:本文中的torchvision.models.utils.load_state_dict_from_url方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。