当前位置: 首页>>代码示例>>Python>>正文


Python hub.load_state_dict_from_url方法代码示例

本文整理汇总了Python中torch.hub.load_state_dict_from_url方法的典型用法代码示例。如果您正苦于以下问题：Python hub.load_state_dict_from_url方法的具体用法是什么？Python hub.load_state_dict_from_url怎么用？Python hub.load_state_dict_from_url有哪些使用例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.hub的用法示例。


在下文中一共展示了hub.load_state_dict_from_url方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def __init__(self, arch, replace_stride_with_dilation=None, multi_grid=None, pretrain=True,
                 norm_cfg=None, act_cfg=None):
        """Build the ResNet variant selected by ``arch`` and initialise its weights.

        Args:
            arch: Key into ``MODEL_CFGS``. The selected entry must provide
                ``'block'`` and ``'layer'`` (and ``'weights_url'`` when
                ``pretrain`` is truthy).
            replace_stride_with_dilation: Forwarded unchanged to the parent
                constructor.
            multi_grid: Forwarded unchanged to the parent constructor.
            pretrain: If truthy, download the state dict from
                ``cfg['weights_url']`` and load it with ``strict=False``;
                otherwise call ``init_weights`` on all sub-modules.
            norm_cfg: Forwarded unchanged to the parent constructor.
            act_cfg: Forwarded unchanged to the parent constructor.

        Raises:
            KeyError: If ``arch`` is not a key of ``MODEL_CFGS``.
        """
        cfg = MODEL_CFGS[arch]
        super().__init__(
            cfg['block'],
            cfg['layer'],
            replace_stride_with_dilation=replace_stride_with_dilation,
            multi_grid=multi_grid,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if pretrain:
            logger.info('ResNet init weights from pretrain')
            state_dict = load_state_dict_from_url(cfg['weights_url'])
            # strict=False tolerates key mismatches between the checkpoint
            # and this module instead of raising.
            self.load_state_dict(state_dict, strict=False)
        else:
            logger.info('ResNet init weights')
            init_weights(self.modules())

        # The fc and avgpool attributes are removed only after the weights
        # have been loaded / initialised.
        del self.fc, self.avgpool
开发者ID:Media-Smart,项目名称:vedaseg,代码行数:22,代码来源:resnet.py

示例2: gsc_super_sparse_cnn

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def gsc_super_sparse_cnn(pretrained=False, progress=True):
    """
    Create the Super Sparse CNN used to classify the `Google Speech Commands`
    dataset, as described in the `How Can We Be So Dense?`_ paper.
    It is a sparser version of :class:`GSCSparseCNN`.

    :param pretrained: If True, the returned model has weights downloaded from
                       ``MODEL_URLS["gsc_super_sparse_cnn"]`` loaded into it
    :param progress: If True, displays a progress bar of the download to stderr
    :return: a :class:`GSCSuperSparseCNN` instance
    """
    net = GSCSuperSparseCNN()
    if not pretrained:
        return net

    weights = load_state_dict_from_url(
        MODEL_URLS["gsc_super_sparse_cnn"], progress=progress
    )
    net.load_state_dict(weights)
    return net
开发者ID:numenta,项目名称:nupic.torch,代码行数:18,代码来源:sparse_cnn.py

示例3: dpn68

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def dpn68(pretrained=False, test_time_pool=False, **kwargs):
    """Build a DPN-68 network.

    Args:
        pretrained (bool): If True, load the weights published under
            ``model_urls['dpn68']`` (pre-trained on ImageNet-1K) into the model
        test_time_pool (bool): If True, pools features for input resolution beyond
            standard 224x224 input with avg+max at inference/validation time

        **kwargs : Extra keyword arguments forwarded to ``DPN.__init__``
            num_classes (int): Number of classes for classifier linear layer, default=1000

    Returns:
        The constructed ``DPN`` instance.
    """
    net = DPN(
        small=True,
        num_init_features=10,
        k_r=128,
        groups=32,
        k_sec=(3, 4, 12, 3),
        inc_sec=(16, 32, 32, 64),
        test_time_pool=test_time_pool,
        **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['dpn68'])
    net.load_state_dict(weights)
    return net
开发者ID:rwightman,项目名称:pytorch-dpn-pretrained,代码行数:20,代码来源:dpn.py

示例4: dpn68b

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def dpn68b(pretrained=False, test_time_pool=False, **kwargs):
    """Build a DPN-68b network (same layout as DPN-68, created with ``b=True``).

    Args:
        pretrained (bool): If True, load the weights published under
            ``model_urls['dpn68b-extra']`` into the model
        test_time_pool (bool): If True, pools features for input resolution beyond
            standard 224x224 input with avg+max at inference/validation time

        **kwargs : Extra keyword arguments forwarded to ``DPN.__init__``
            num_classes (int): Number of classes for classifier linear layer, default=1000

    Returns:
        The constructed ``DPN`` instance.
    """
    net = DPN(
        small=True,
        num_init_features=10,
        k_r=128,
        groups=32,
        b=True,
        k_sec=(3, 4, 12, 3),
        inc_sec=(16, 32, 32, 64),
        test_time_pool=test_time_pool,
        **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['dpn68b-extra'])
    net.load_state_dict(weights)
    return net
开发者ID:rwightman,项目名称:pytorch-dpn-pretrained,代码行数:20,代码来源:dpn.py

示例5: dpn92

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def dpn92(pretrained=False, test_time_pool=False, **kwargs):
    """Build a DPN-92 network.

    Args:
        pretrained (bool): If True, load the weights published under
            ``model_urls['dpn92-extra']`` into the model
        test_time_pool (bool): If True, pools features for input resolution beyond
            standard 224x224 input with avg+max at inference/validation time

        **kwargs : Extra keyword arguments forwarded to ``DPN.__init__``
            num_classes (int): Number of classes for classifier linear layer, default=1000

    Returns:
        The constructed ``DPN`` instance.
    """
    net = DPN(
        num_init_features=64,
        k_r=96,
        groups=32,
        k_sec=(3, 4, 20, 3),
        inc_sec=(16, 32, 24, 128),
        test_time_pool=test_time_pool,
        **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['dpn92-extra'])
    net.load_state_dict(weights)
    return net
开发者ID:rwightman,项目名称:pytorch-dpn-pretrained,代码行数:20,代码来源:dpn.py

示例6: dpn98

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def dpn98(pretrained=False, test_time_pool=False, **kwargs):
    """Build a DPN-98 network.

    Args:
        pretrained (bool): If True, load the weights published under
            ``model_urls['dpn98']`` (pre-trained on ImageNet-1K) into the model
        test_time_pool (bool): If True, pools features for input resolution beyond
            standard 224x224 input with avg+max at inference/validation time

        **kwargs : Extra keyword arguments forwarded to ``DPN.__init__``
            num_classes (int): Number of classes for classifier linear layer, default=1000

    Returns:
        The constructed ``DPN`` instance.
    """
    net = DPN(
        num_init_features=96,
        k_r=160,
        groups=40,
        k_sec=(3, 6, 20, 3),
        inc_sec=(16, 32, 32, 128),
        test_time_pool=test_time_pool,
        **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['dpn98'])
    net.load_state_dict(weights)
    return net
开发者ID:rwightman,项目名称:pytorch-dpn-pretrained,代码行数:20,代码来源:dpn.py

示例7: dpn131

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def dpn131(pretrained=False, test_time_pool=False, **kwargs):
    """Build a DPN-131 network.

    Args:
        pretrained (bool): If True, load the weights published under
            ``model_urls['dpn131']`` (pre-trained on ImageNet-1K) into the model
        test_time_pool (bool): If True, pools features for input resolution beyond
            standard 224x224 input with avg+max at inference/validation time

        **kwargs : Extra keyword arguments forwarded to ``DPN.__init__``
            num_classes (int): Number of classes for classifier linear layer, default=1000

    Returns:
        The constructed ``DPN`` instance.
    """
    net = DPN(
        num_init_features=128,
        k_r=160,
        groups=40,
        k_sec=(4, 8, 28, 3),
        inc_sec=(16, 32, 32, 128),
        test_time_pool=test_time_pool,
        **kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(model_urls['dpn131'])
    net.load_state_dict(weights)
    return net
开发者ID:rwightman,项目名称:pytorch-dpn-pretrained,代码行数:20,代码来源:dpn.py

示例8: _efficientdet

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def _efficientdet(arch, pretrained=None, **kwargs):
    """Construct an ``EfficientDet`` for ``arch``, optionally with pretrained weights.

    The ``"params"`` entry of the architecture's ``"default"`` settings is
    merged into ``kwargs`` (config values win on conflicts) before the model
    is built.

    Args:
        arch: Key into ``CFGS`` selecting the architecture.
        pretrained: Falsy for no weights; otherwise a key of ``CFGS[arch]``
            whose ``"url"`` entry is downloaded and loaded.
        **kwargs: Keyword arguments for ``EfficientDet``.

    Returns:
        The model, with the default settings (minus ``"params"``) attached
        as ``model.pretrained_settings``.
    """
    # Work on a private copy because "params" is popped from the settings below.
    arch_cfgs = deepcopy(CFGS)[arch]
    defaults = arch_cfgs["default"]
    kwargs.update(defaults.pop("params"))
    model = EfficientDet(**kwargs)

    if pretrained:
        weights = load_state_dict_from_url(arch_cfgs[pretrained]["url"])
        requested_classes = kwargs.get("num_classes", None)
        if requested_classes and requested_classes != defaults["num_classes"]:
            logging.warning(
                f"Using model pretrained for {defaults['num_classes']} classes with {requested_classes} classes. Last layer is initialized randomly"
            )
            # Overwrite the checkpoint's final classification conv with the
            # freshly initialised tensors so the shapes match the new model.
            head_prefix = f"cls_head_convs.{kwargs['num_head_repeats']}.1"
            for suffix in ("weight", "bias"):
                key = f"{head_prefix}.{suffix}"
                weights[key] = model.state_dict()[key]
        model.load_state_dict(weights)

    model.pretrained_settings = defaults
    return model
开发者ID:bonlime,项目名称:pytorch-tools,代码行数:21,代码来源:efficientdet.py

示例9: __init__

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, progress=True):
        """Build the network from ``make_layers()`` and optionally load weights.

        Args:
            urls: Mapping with a ``'vggish'`` entry (network weights) and a
                ``'pca'`` entry (post-processor parameters).
            pretrained: If truthy, download and load the weights.
            preprocess: Stored as ``self.preprocess``.
            postprocess: Stored as ``self.postprocess``; when truthy a
                ``Postprocessor`` is created as ``self.pproc``.
            progress: Forwarded to ``hub.load_state_dict_from_url``.
        """
        super().__init__(make_layers())
        if pretrained:
            backbone_weights = hub.load_state_dict_from_url(urls['vggish'], progress=progress)
            super().load_state_dict(backbone_weights)

        self.preprocess = preprocess
        self.postprocess = postprocess
        if not self.postprocess:
            return

        self.pproc = Postprocessor()
        if not pretrained:
            return

        pca_state = hub.load_state_dict_from_url(urls['pca'], progress=progress)
        # TODO: Convert the state_dict to torch
        eigen_key = vggish_params.PCA_EIGEN_VECTORS_NAME
        pca_state[eigen_key] = torch.as_tensor(pca_state[eigen_key], dtype=torch.float)
        means_key = vggish_params.PCA_MEANS_NAME
        # The means are reshaped into a single column before conversion.
        pca_state[means_key] = torch.as_tensor(
            pca_state[means_key].reshape(-1, 1), dtype=torch.float
        )

        self.pproc.load_state_dict(pca_state)
开发者ID:harritaylor,项目名称:torchvggish,代码行数:23,代码来源:vggish.py

示例10: _get_model_by_name

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def _get_model_by_name(model_name, classes=1000, pretrained=False):
    """Build the EfficientNet called ``model_name``, optionally with pretrained weights.

    Args:
        model_name: Name passed to ``get_efficientnet_params`` and used as
            the key into ``IMAGENET_WEIGHTS``.
        classes: Number of output classes. When it differs from 1000 the
            checkpoint's ``_fc`` weight and bias are replaced by the model's
            own randomly initialised tensors before loading.
        pretrained: If truthy, try to download and load the weights.

    Returns:
        The ``EfficientNet`` instance. If ``pretrained`` is requested but
        ``IMAGENET_WEIGHTS`` has no entry for ``model_name``, a note is
        printed and the randomly initialised model is returned.
    """
    block_args_list, global_params = get_efficientnet_params(model_name, override_params={'num_classes': classes})
    model = EfficientNet(block_args_list, global_params)
    if not pretrained:
        return model

    # Guard only the URL lookup: a KeyError raised anywhere else (e.g. a
    # missing '_fc.weight' entry) is a real error and must not be reported
    # as "no pretrained weights".
    try:
        weights_url = IMAGENET_WEIGHTS[model_name]
    except KeyError as e:
        print(f"NOTE: Currently model {e} doesn't have pretrained weights, therefore a model with randomly initialized"
              " weights is returned.")
        return model

    pretrained_state_dict = load_state_dict_from_url(weights_url)

    if classes != 1000:
        # The checkpoint's classifier has the wrong shape, so keep the
        # model's own randomly initialised classifier tensors.
        random_state_dict = model.state_dict()
        for key in ('_fc.weight', '_fc.bias'):
            pretrained_state_dict[key] = random_state_dict[key]

    model.load_state_dict(pretrained_state_dict)
    return model
开发者ID:zhoudaxia233,项目名称:EfficientUnet-PyTorch,代码行数:21,代码来源:efficientnet.py

示例11: mobilenetv3

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def mobilenetv3(input_size=224, num_classes=1000, scale=1., in_channels=3, drop_prob=0.0, num_steps=3e5, start_step=0,
                small=False, get_weights=True, progress=True):
    """Build a ``MobileNetV3`` and optionally load published weights into it.

    Args:
        input_size: Only used to build the weight name to look up.
        num_classes, scale, in_channels, drop_prob, num_steps, start_step,
            small: Forwarded to the ``MobileNetV3`` constructor. ``small``
            and ``scale`` are also part of the weight name.
        get_weights: If truthy, download the weights registered in
            ``model_urls`` under
            ``'mobilenetv3_{small|large}_{scale}_{input_size}'`` and load them.
        progress: Forwarded to ``load_state_dict_from_url``.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``get_weights`` is truthy and no weights are
            registered for the requested configuration.
    """
    model = MobileNetV3(num_classes=num_classes, scale=scale, in_channels=in_channels, drop_prob=drop_prob,
                        num_steps=num_steps, start_step=start_step, small=small)
    if not get_weights:
        return model

    name = 'mobilenetv3_{}_{}_{}'.format('small' if small else 'large', scale, input_size)
    if name not in model_urls:
        # Name the missing configuration so the caller can see what to change.
        raise ValueError(
            "No pretrained weights registered for '{}'; available: {}. "
            "Pass get_weights=False to build the model without weights.".format(name, sorted(model_urls)))

    # Weights are mapped to CPU on load.
    state_dict = load_state_dict_from_url(model_urls[name], progress=progress, map_location='cpu')
    model.load_state_dict(state_dict)
    return model
开发者ID:Randl,项目名称:MobileNetV3-pytorch,代码行数:14,代码来源:MobileNetV3.py

示例12: mobilenet_v2

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def mobilenet_v2(pretrained=True):
    """Build a ``MobileNetV2`` with ``width_mult=1``, optionally with weights.

    Args:
        pretrained: If truthy, download the checkpoint from the hard-coded
            URL below and load it into the model.

    Returns:
        The constructed model.
    """
    net = MobileNetV2(width_mult=1)
    if not pretrained:
        return net

    # Fall back to the model_zoo loader when torch.hub does not provide one.
    try:
        from torch.hub import load_state_dict_from_url as fetch_state_dict
    except ImportError:
        from torch.utils.model_zoo import load_url as fetch_state_dict

    weights = fetch_state_dict(
        'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True)
    net.load_state_dict(weights)
    return net
开发者ID:CMU-CREATE-Lab,项目名称:deep-smoke-machine,代码行数:14,代码来源:mobilenet_v2.py

示例13: __init__

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def __init__(
            self, *args, **kwargs):
        """Initialise the parent, then build the SSD network with downloaded weights.

        All positional and keyword arguments are forwarded to the parent
        constructor. The network is put into eval mode and moved to
        ``self.device`` (presumably set by the parent constructor - confirm).
        """
        super().__init__(*args, **kwargs)
        weights = load_state_dict_from_url(
            model_url,
            map_location=torch_utils.get_device(),
            progress=True)
        detector = SSD(resnet152_model_config)
        detector.load_state_dict(weights)
        detector.eval()
        self.net = detector.to(self.device)
开发者ID:hukkelas,项目名称:DSFD-Pytorch-Inference,代码行数:13,代码来源:detect.py

示例14: __init__

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def __init__(
            self,
            model: str,
            *args,
            **kwargs):
        """Initialise the parent, then build a RetinaFace network with downloaded weights.

        Args:
            model: ``"mobilenet"`` or ``"resnet50"``; selects both the config
                and the checkpoint URL. Any other value fails the assertion.
            *args, **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        uses_mobilenet = model == "mobilenet"
        if uses_mobilenet:
            cfg = cfg_mnet
            weights_url = "https://folk.ntnu.no/haakohu/RetinaFace_mobilenet025.pth"
        else:
            assert model == "resnet50"
            cfg = cfg_re50
            weights_url = "https://folk.ntnu.no/haakohu/RetinaFace_ResNet50.pth"

        weights = load_state_dict_from_url(
            weights_url,
            map_location=torch_utils.get_device()
        )
        if not uses_mobilenet:
            # The ResNet50 checkpoint keys have every "module." occurrence
            # removed before loading.
            weights = {key.replace("module.", ""): tensor for key, tensor in weights.items()}

        retina = RetinaFace(cfg=cfg)
        retina.eval()
        retina.load_state_dict(weights)
        self.cfg = cfg
        self.net = retina.to(self.device)
        # NOTE(review): looks like per-channel mean values for input
        # normalisation - confirm against the code that reads self.mean.
        self.mean = np.array([104, 117, 123], dtype=np.float32)
开发者ID:hukkelas,项目名称:DSFD-Pytorch-Inference,代码行数:28,代码来源:detect.py

示例15: load_pretrained

# 需要导入模块: from torch import hub [as 别名]
# 或者: from torch.hub import load_state_dict_from_url [as 别名]
def load_pretrained(model, url, filter_fn=None, strict=True):
    """Download a checkpoint from ``url`` and load it into ``model``.

    The checkpoint is adapted to the model before loading:

    * If the model's ``conv_stem`` input channel count differs from the
      checkpoint's, the stem weight is summed over the channel dimension
      (when the model expects 1 channel) or dropped.
    * If the model's ``classifier`` output size differs from the
      checkpoint's, the classifier weight and bias are dropped.

    Dropping any entry switches loading to non-strict mode.

    Args:
        model: Module exposing ``conv_stem`` and ``classifier`` attributes.
        url: Checkpoint URL. If falsy, a warning is printed and nothing is
            loaded.
        filter_fn: Optional callable applied to the (adapted) state dict; its
            return value is what gets loaded.
        strict: Passed to ``model.load_state_dict`` unless overridden as
            described above.

    Returns:
        None.
    """
    if not url:
        print("=> Warning: Pretrained model URL is empty, using random initialization.")
        return

    weights = load_state_dict_from_url(url, progress=False, map_location='cpu')

    stem_name, head_name = 'conv_stem', 'classifier'
    model_in_chans = getattr(model, stem_name).weight.shape[1]
    model_num_classes = getattr(model, head_name).weight.shape[0]

    # --- input convolution -------------------------------------------------
    stem_key = stem_name + '.weight'
    ckpt_in_chans = weights[stem_key].shape[1]
    if model_in_chans != ckpt_in_chans:
        if model_in_chans != 1:
            print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
                stem_key, ckpt_in_chans))
            del weights[stem_key]
            strict = False
        else:
            print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
                stem_key, ckpt_in_chans))
            # Collapse the input-channel dimension into a single channel.
            weights[stem_key] = weights[stem_key].sum(dim=1, keepdim=True)

    # --- classifier --------------------------------------------------------
    head_key = head_name + '.weight'
    ckpt_num_classes = weights[head_key].shape[0]
    if model_num_classes != ckpt_num_classes:
        print('=> Discarding pretrained classifier since num_classes != {}'.format(ckpt_num_classes))
        for suffix in ('.weight', '.bias'):
            del weights[head_name + suffix]
        strict = False

    if filter_fn is not None:
        weights = filter_fn(weights)

    model.load_state_dict(weights, strict=strict)
开发者ID:rwightman,项目名称:gen-efficientnet-pytorch,代码行数:40,代码来源:helpers.py


注:本文中的torch.hub.load_state_dict_from_url方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。