当前位置: 首页>>代码示例>>Python>>正文


Python resnet.model_urls属性代码示例

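本文整理汇总了Python中torchvision.models.resnet.model_urls的典型用法代码示例。model_urls并不是方法，而是该模块中的一个字典，它把模型名称（如'resnet18'、'resnet50'）映射到对应的ImageNet预训练权重下载地址，通常与model_zoo.load_url配合使用来加载预训练权重。注意：较新版本的torchvision（自0.13起）已弃用model_urls，并在之后的版本中将其移除，新代码应改用torchvision.models的weights参数（例如ResNet18_Weights）。如果您正苦于以下问题：Python resnet.model_urls的具体用法？Python resnet.model_urls怎么用？Python resnet.model_urls使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解torchvision.models.resnet模块的用法示例。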


在下文中一共展示了resnet.model_urls的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def __init__(self):
        """Build a ResNet-18 backbone followed by a small convolutional head.

        The backbone is ``MyResNet`` with the ResNet-18 layout
        (``BasicBlock``, ``[2, 2, 2, 2]``), initialised from the torchvision
        checkpoint at ``model_urls['resnet18']``. Its parameters are left
        trainable.

        Head modules registered on ``self``:
        * ``drop0``: ``Dropout2d(0.05)``.
        * ``conv1``/``bn1``/``drop1``: 3x3 conv 512 -> 256 channels (no bias),
          ``BatchNorm2d(256)``, ``Dropout2d(0.05)``.
        * ``conv2``/``bn2``/``drop2``: 3x3 conv 256 -> 128 channels (no bias),
          ``BatchNorm2d(128)``, ``Dropout2d(0.05)``.
        * ``conv3``: 3x3 conv 128 -> ``1 + 9`` channels (no bias). The meaning
          of the 1 + 9 split is defined by the forward pass, not shown here.
        """
        super(Model2, self).__init__()

        # Fine-tuning the ResNet helped accuracy significantly, so the backbone
        # is NOT frozen. To freeze it instead, set ``requires_grad = False`` on
        # every parameter of ``backbone`` before assigning it.
        backbone = MyResNet(BasicBlock, [2, 2, 2, 2])
        imagenet_weights = model_zoo.load_url(model_urls['resnet18'])
        backbone.load_state_dict(imagenet_weights)
        self.base_model = backbone
        self.drop0 = nn.Dropout2d(0.05)

        # Head stage 1: 512 -> 256 channels.
        self.conv1 = nn.Conv2d(512, 256, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.drop1 = nn.Dropout2d(0.05)

        # Head stage 2: 256 -> 128 channels.
        self.conv2 = nn.Conv2d(256, 128, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.drop2 = nn.Dropout2d(0.05)

        # Output layer: 128 -> 1 + 9 channels.
        self.conv3 = nn.Conv2d(128, 1 + 9, 3, padding=1, bias=False)
开发者ID:aleju,项目名称:cat-bbs,代码行数:23,代码来源:model.py

示例2: init

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def init(self, model_dir=None, gain=1.):
        """Load ImageNet ResNet-18 weights and re-initialise the classifier.

        Args:
            model_dir: directory passed to ``model_zoo.load_url`` for the
                checkpoint download/cache. When ``None``, the current
                ``self.model_dir`` is reused. The chosen value is stored back
                on ``self.model_dir``.
            gain: gain for ``torch.nn.init.xavier_uniform_`` applied to
                ``self.fc.weight``.

        The checkpoint's ``fc.weight`` and ``fc.bias`` entries are deleted
        before loading with ``strict=False``, so the model's own ``fc`` is
        not overwritten. Only ``fc.weight`` is then re-initialised.
        """
        if model_dir is None:
            model_dir = self.model_dir
        self.model_dir = model_dir

        checkpoint = model_zoo.load_url(model_urls['resnet18'], model_dir=self.model_dir)
        # Drop the checkpoint's classifier; this model's fc presumably has a
        # different number of outputs (TODO confirm against the class).
        for head_key in ('fc.weight', 'fc.bias'):
            del checkpoint[head_key]
        self.load_state_dict(checkpoint, strict=False)

        torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)
        LOGGER.debug('initialize classifier weight')
开发者ID:kakaobrain,项目名称:autoclint,代码行数:27,代码来源:resnet.py

示例3: __init__

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def __init__(self, raw_model_dir, use_flow, logger):
        """Build a feature extractor from a pretrained ResNet-50.

        Args:
            raw_model_dir: directory passed to ``model_zoo.load_url`` for the
                ResNet-50 checkpoint download/cache.
            use_flow: when truthy, also build ``flow_branch`` (via
                ``self.get_flow_branch``), ``rgb_branch`` and ``fuse_branch``.
            logger: logger used to report that the weights were restored.

        Sets ``self.feature`` to every child of the ResNet except the last two
        (presumably avgpool and fc). ``self.base`` holds that module's
        parameter list. ``self.fea_dim`` is ``fc.in_features`` of the ResNet.
        """
        super(BackboneModel, self).__init__()
        self.use_flow = use_flow

        backbone = ResNet(Bottleneck, [3, 4, 6, 3])
        weights = model_zoo.load_url(model_urls['resnet50'], model_dir=raw_model_dir)
        backbone.load_state_dict(weights)
        logger.info('Model restored from pretrained resnet50')

        # All children except the final two modules.
        self.feature = nn.Sequential(*list(backbone.children())[:-2])
        self.base = list(self.feature.parameters())

        if self.use_flow:
            self.flow_branch = self.get_flow_branch(backbone)
            stem = (backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
            self.rgb_branch = nn.Sequential(*stem)
            # Children are re-listed here (after get_flow_branch ran) so any
            # change it makes to the backbone is reflected, as in the original.
            self.fuse_branch = nn.Sequential(*list(backbone.children())[4:-2])
        self.fea_dim = backbone.fc.in_features
开发者ID:yolomax,项目名称:person-reid-lib,代码行数:19,代码来源:resnet50.py

示例4: resnet18_ids

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet18_ids(num_attributes, ids_embedding_size, pretrained=False, **kwargs):
    """Construct a ResNet-18 trunk plus two ``ResNetClassifier`` heads.

    Args:
        num_attributes (int): ``num_classes`` of the attribute classifier.
        ids_embedding_size (int): ``num_classes`` of the second (ids) head.
        pretrained (bool): If True, load the ImageNet ResNet-18 checkpoint
            into the trunk, keeping only the keys the trunk actually has.
        **kwargs: forwarded to ``ResNet`` and to both ``ResNetClassifier``s.

    Returns:
        tuple: ``(model, classifier, classifier_ids)``.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    classifier = ResNetClassifier(BasicBlock, num_classes=num_attributes, **kwargs)
    classifier_ids = ResNetClassifier(BasicBlock, num_classes=ids_embedding_size, **kwargs)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet18'])
        # Build the trunk's key set once. The original called
        # model.state_dict() (a fresh dict of every tensor) per checkpoint key.
        model_keys = set(model.state_dict())
        model.load_state_dict(
            {k: v for k, v in state_dict.items() if k in model_keys}
        )

    return model, classifier, classifier_ids
开发者ID:leokarlin,项目名称:LaSO,代码行数:18,代码来源:resnet_backup.py

示例5: resnet18

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet18(output_layers=None, pretrained=False, **kwargs):
    """Construct a ResNet-18 model that exposes the requested layers.

    Args:
        output_layers (list[str] | None): layer names to expose. Each must be
            one of ``'conv1'``, ``'layer1'``-``'layer4'`` or ``'fc'``.
            ``None`` means ``['default']``.
        pretrained (bool): if True, load the ImageNet ResNet-18 checkpoint.
        **kwargs: forwarded to ``ResNet``.

    Raises:
        ValueError: for the first unrecognised name in ``output_layers``.
    """
    known_layers = ('conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc')
    if output_layers is not None:
        for layer_name in output_layers:
            if layer_name not in known_layers:
                raise ValueError('Unknown layer: {}'.format(layer_name))
    else:
        output_layers = ['default']

    model = ResNet(BasicBlock, [2, 2, 2, 2], output_layers, **kwargs)
    if not pretrained:
        return model

    model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model
开发者ID:visionml,项目名称:pytracking,代码行数:18,代码来源:resnet.py

示例6: resnet50

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet50(num_classes=1000, avgpool_size=7, use_dropout=False, pretrained=True):
    """Construct a ResNet-50 model.

    Args:
        num_classes (int): forwarded to ``ResNet`` as ``num_classes``.
        avgpool_size (int): forwarded to ``ResNet`` as ``avgpool_size``.
        use_dropout (bool): forwarded to ``ResNet`` as ``use_dropout``.
        pretrained (bool): If True, copy every ImageNet checkpoint entry whose
            key does not start with ``'fc.'`` into the model. The classifier
            keeps the model's own initial values.

    Returns:
        ResNet: the constructed model.
    """
    model = ResNet(resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes,
                   avgpool_size=avgpool_size, use_dropout=use_dropout)
    if not pretrained:
        return model

    checkpoint = resnet.model_zoo.load_url(resnet.model_urls['resnet50'])
    merged = model.state_dict()
    # Overwrite everything except the classifier head.
    merged.update(
        (name, tensor) for name, tensor in checkpoint.items()
        if not name.startswith('fc.')
    )
    model.load_state_dict(merged)
    return model
开发者ID:Britefury,项目名称:self-ensemble-visual-domain-adapt-photo,代码行数:22,代码来源:network_architectures.py

示例7: resnet101

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet101(num_classes=1000, avgpool_size=7, use_dropout=False, pretrained=True):
    """Construct a ResNet-101 model.

    Args:
        num_classes (int): forwarded to ``ResNet`` as ``num_classes``.
        avgpool_size (int): forwarded to ``ResNet`` as ``avgpool_size``.
        use_dropout (bool): forwarded to ``ResNet`` as ``use_dropout``.
        pretrained (bool): If True, copy every ImageNet checkpoint entry whose
            key does not start with ``'fc.'`` into the model. The classifier
            keeps the model's own initial values.

    Returns:
        ResNet: the constructed model.
    """
    model = ResNet(resnet.Bottleneck, [3, 4, 23, 3], num_classes=num_classes,
                   avgpool_size=avgpool_size, use_dropout=use_dropout)
    if not pretrained:
        return model

    checkpoint = resnet.model_zoo.load_url(resnet.model_urls['resnet101'])
    merged = model.state_dict()
    # Overwrite everything except the classifier head.
    merged.update(
        (name, tensor) for name, tensor in checkpoint.items()
        if not name.startswith('fc.')
    )
    model.load_state_dict(merged)
    return model
开发者ID:Britefury,项目名称:self-ensemble-visual-domain-adapt-photo,代码行数:22,代码来源:network_architectures.py

示例8: resnet152

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet152(num_classes=1000, avgpool_size=7, use_dropout=False, pretrained=True):
    """Construct a ResNet-152 model.

    Args:
        num_classes (int): forwarded to ``ResNet`` as ``num_classes``.
        avgpool_size (int): forwarded to ``ResNet`` as ``avgpool_size``.
        use_dropout (bool): forwarded to ``ResNet`` as ``use_dropout``.
        pretrained (bool): If True, copy every ImageNet checkpoint entry whose
            key does not start with ``'fc.'`` into the model. The classifier
            keeps the model's own initial values.

    Returns:
        ResNet: the constructed model.
    """
    model = ResNet(resnet.Bottleneck, [3, 8, 36, 3], num_classes=num_classes,
                   avgpool_size=avgpool_size, use_dropout=use_dropout)
    if not pretrained:
        return model

    checkpoint = resnet.model_zoo.load_url(resnet.model_urls['resnet152'])
    merged = model.state_dict()
    # Overwrite everything except the classifier head.
    merged.update(
        (name, tensor) for name, tensor in checkpoint.items()
        if not name.startswith('fc.')
    )
    model.load_state_dict(merged)
    return model
开发者ID:Britefury,项目名称:self-ensemble-visual-domain-adapt-photo,代码行数:22,代码来源:network_architectures.py

示例9: trinet

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def trinet(**kwargs):
    """Create a TriNet network and load the pretrained ResNet-50 weights.

    Checkpoint entries whose key starts with ``"fc"`` are skipped. All other
    entries overwrite the freshly built model's state dict before loading.
    Approach follows
    https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/2

    Args:
        **kwargs: forwarded to ``TriNet``.

    Returns:
        tuple: ``(model, endpoints)`` where ``endpoints`` is ``{"emb": None}``.
    """
    model = TriNet(Bottleneck, [3, 4, 6, 3], **kwargs)

    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    backbone_weights = {}
    for key, value in checkpoint.items():
        if key.startswith("fc"):
            continue  # skip the fully connected (classifier) entries
        backbone_weights[key] = value

    merged = model.state_dict()
    merged.update(backbone_weights)
    model.load_state_dict(merged)

    return model, {"emb": None}
开发者ID:kilsenp,项目名称:triplet-reid-pytorch,代码行数:26,代码来源:trinet.py

示例10: stride_test

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def stride_test(**kwargs):
    """Create a ``StrideTest`` network initialised from ResNet-50 weights.

    A checkpoint entry is skipped when its key starts with ``"fc"``. It is
    also skipped when its key both starts with ``"layer4"`` and contains
    ``"downsample"``. All remaining entries overwrite the freshly built
    model's state dict before loading. Why layer4's downsample weights are
    excluded is not visible here; presumably ``StrideTest`` changes that
    block (TODO confirm).

    Args:
        **kwargs: forwarded to ``StrideTest``.

    Returns:
        StrideTest: the initialised model. Unlike ``trinet``, no endpoints
        dict is returned.
    """
    model = StrideTest(Bottleneck, [3, 4, 6, 3], **kwargs)
    checkpoint = model_zoo.load_url(model_urls['resnet50'])

    def _keep(key):
        # Skip classifier entries and layer4's downsample projections.
        if key.startswith("fc"):
            return False
        return not (key.startswith("layer4") and "downsample" in key)

    # Overwrite entries in the existing state dict, then load it back.
    model_dict = model.state_dict()
    model_dict.update({k: v for k, v in checkpoint.items() if _keep(k)})
    model.load_state_dict(model_dict)
    return model
开发者ID:kilsenp,项目名称:triplet-reid-pytorch,代码行数:25,代码来源:dilated.py

示例11: resnet18

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet18(config_channels, anchors, num_cls, **kwargs):
    """Build the ResNet-18 variant of this project's ``ResNet``.

    When ``config_channels.config.getboolean('model', 'pretrained')`` is true,
    the ImageNet ResNet-18 checkpoint is downloaded. Only entries whose keys
    already exist in the model's state dict are copied in; the rest are
    ignored.
    """
    model = ResNet(config_channels, anchors, num_cls, BasicBlock, [2, 2, 2, 2], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet18']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    checkpoint = model_zoo.load_url(url)
    own_state.update({k: v for k, v in checkpoint.items() if k in own_state})
    model.load_state_dict(own_state)
    return model
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:13,代码来源:resnet.py

示例12: resnet34

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet34(config_channels, anchors, num_cls, **kwargs):
    """Build the ResNet-34 variant of this project's ``ResNet``.

    When ``config_channels.config.getboolean('model', 'pretrained')`` is true,
    the ImageNet ResNet-34 checkpoint is downloaded. Only entries whose keys
    already exist in the model's state dict are copied in; the rest are
    ignored.
    """
    model = ResNet(config_channels, anchors, num_cls, BasicBlock, [3, 4, 6, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet34']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    checkpoint = model_zoo.load_url(url)
    own_state.update({k: v for k, v in checkpoint.items() if k in own_state})
    model.load_state_dict(own_state)
    return model
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:13,代码来源:resnet.py

示例13: resnet50

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet50(config_channels, anchors, num_cls, **kwargs):
    """Build the ResNet-50 variant of this project's ``ResNet``.

    When ``config_channels.config.getboolean('model', 'pretrained')`` is true,
    the ImageNet ResNet-50 checkpoint is downloaded. Only entries whose keys
    already exist in the model's state dict are copied in; the rest are
    ignored.
    """
    model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 6, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet50']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    checkpoint = model_zoo.load_url(url)
    own_state.update({k: v for k, v in checkpoint.items() if k in own_state})
    model.load_state_dict(own_state)
    return model
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:13,代码来源:resnet.py

示例14: resnet101

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet101(config_channels, anchors, num_cls, **kwargs):
    """Build the ResNet-101 variant of this project's ``ResNet``.

    When ``config_channels.config.getboolean('model', 'pretrained')`` is true,
    the ImageNet ResNet-101 checkpoint is downloaded. Only entries whose keys
    already exist in the model's state dict are copied in; the rest are
    ignored.
    """
    model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 23, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet101']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    checkpoint = model_zoo.load_url(url)
    own_state.update({k: v for k, v in checkpoint.items() if k in own_state})
    model.load_state_dict(own_state)
    return model
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:13,代码来源:resnet.py

示例15: resnet152

# 需要导入模块: from torchvision.models import resnet [as 别名]
# 或者: from torchvision.models.resnet import model_urls [as 别名]
def resnet152(config_channels, anchors, num_cls, **kwargs):
    """Build the ResNet-152 variant of this project's ``ResNet``.

    When ``config_channels.config.getboolean('model', 'pretrained')`` is true,
    the ImageNet ResNet-152 checkpoint is downloaded. Only entries whose keys
    already exist in the model's state dict are copied in; the rest are
    ignored.
    """
    model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 8, 36, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet152']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    checkpoint = model_zoo.load_url(url)
    own_state.update({k: v for k, v in checkpoint.items() if k in own_state})
    model.load_state_dict(own_state)
    return model
开发者ID:ruiminshen,项目名称:yolo2-pytorch,代码行数:13,代码来源:resnet.py


注：本文中的torchvision.models.resnet.model_urls代码示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。