當前位置: 首頁>>代碼示例>>Python>>正文


Python resnet.model_urls方法代碼示例

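本文整理匯總了Python中torchvision.models.resnet.model_urls的典型用法代碼示例。需要說明的是,model_urls並不是方法,而是一個字典:它把模型名稱(如'resnet18'、'resnet50')映射到對應預訓練權重的下載地址,通常與model_zoo.load_url配合使用來加載預訓練權重。注意:在較新版本的torchvision中(約自0.13起),model_urls已被棄用或移除,新代碼宜改用權重枚舉API(例如ResNet18_Weights)。如果您想了解resnet.model_urls的具體用法和使用示例,以下精選的代碼示例或許可以為您提供幫助。您也可以進一步了解其所在模塊torchvision.models.resnet的用法示例。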


下文共展示了resnet.model_urls的15個代碼示例,這些例子默認按受歡迎程度排序。您可以為喜歡或覺得有用的代碼點讚,您的評價將有助於系統推薦更好的Python代碼示例。

示例1: __init__

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def __init__(self):
    """Build a ResNet-18 backbone followed by a small convolutional head.

    The backbone is created as ``MyResNet(BasicBlock, [2, 2, 2, 2])`` and
    initialised from the ``'resnet18'`` checkpoint URL. Its parameters are
    deliberately left trainable: the original author notes that fine-tuning
    the ResNet improved accuracy significantly. To freeze it instead, set
    ``requires_grad = False`` on every backbone parameter.

    Head layout (all 3x3 convs, padding 1, no bias):
        512 -> 256 (BatchNorm + Dropout2d), 256 -> 128 (BatchNorm + Dropout2d),
        128 -> 1 + 9 output channels.
    """
    super(Model2, self).__init__()

    p_drop = 0.05

    backbone = MyResNet(BasicBlock, [2, 2, 2, 2])
    pretrained_weights = model_zoo.load_url(model_urls['resnet18'])
    backbone.load_state_dict(pretrained_weights)
    self.base_model = backbone
    self.drop0 = nn.Dropout2d(p_drop)

    # conv1 takes 512 input channels, presumably the depth of the backbone's
    # final feature map (confirm against MyResNet.forward).
    self.conv1 = nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(256)
    self.drop1 = nn.Dropout2d(p_drop)

    self.conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(128)
    self.drop2 = nn.Dropout2d(p_drop)

    # NOTE(review): the split "1 + 9" suggests one map plus nine others;
    # the exact meaning is not visible here - TODO confirm with forward().
    out_channels = 1 + 9
    self.conv3 = nn.Conv2d(128, out_channels, kernel_size=3, padding=1, bias=False)
開發者ID:aleju,項目名稱:cat-bbs,代碼行數:23,代碼來源:model.py

示例2: init

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def init(self, model_dir=None, gain=1.):
    """Load pretrained ResNet-18 backbone weights and re-initialise ``fc``.

    Args:
        model_dir: Cache directory passed to ``model_zoo.load_url``. When
            None, the instance's existing ``self.model_dir`` is reused.
        gain: Gain passed to ``torch.nn.init.xavier_uniform_`` when
            re-initialising ``self.fc.weight``.

    Raises:
        KeyError: If the downloaded checkpoint lacks ``fc.weight`` or
            ``fc.bias``.
    """
    self.model_dir = self.model_dir if model_dir is None else model_dir

    checkpoint = model_zoo.load_url(model_urls['resnet18'], model_dir=self.model_dir)
    # The checkpoint's classifier is not reused. Drop it so only backbone
    # weights are copied; strict=False tolerates the now-missing fc.* keys.
    checkpoint.pop('fc.weight')
    checkpoint.pop('fc.bias')
    self.load_state_dict(checkpoint, strict=False)

    # Fresh Xavier-uniform init for the classifier weight (bias left as is).
    torch.nn.init.xavier_uniform_(self.fc.weight, gain=gain)
    LOGGER.debug('initialize classifier weight')
開發者ID:kakaobrain,項目名稱:autoclint,代碼行數:27,代碼來源:resnet.py

示例3: __init__

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def __init__(self, raw_model_dir, use_flow, logger):
    """Build a ResNet-50 feature extractor, optionally with a flow branch.

    Args:
        raw_model_dir: Cache directory for the downloaded ResNet-50 checkpoint.
        use_flow: When truthy, also build ``flow_branch`` (via
            ``self.get_flow_branch``), ``rgb_branch`` and ``fuse_branch`` from
            the same pretrained network.
        logger: Logger used to report that the pretrained weights were loaded.
    """
    super(BackboneModel, self).__init__()
    self.use_flow = use_flow

    resnet = ResNet(Bottleneck, [3, 4, 6, 3])
    weights = model_zoo.load_url(model_urls['resnet50'], model_dir=raw_model_dir)
    resnet.load_state_dict(weights)
    logger.info('Model restored from pretrained resnet50')

    # All children except the last two (presumably the global pool and fc;
    # confirm against the ResNet definition).
    self.feature = nn.Sequential(*list(resnet.children())[:-2])
    self.base = list(self.feature.parameters())

    if self.use_flow:
        self.flow_branch = self.get_flow_branch(resnet)
        # Stem of the network for the RGB stream.
        self.rgb_branch = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        # Children after the four stem modules, again minus the last two.
        self.fuse_branch = nn.Sequential(*list(resnet.children())[4:-2])
    self.fea_dim = resnet.fc.in_features
開發者ID:yolomax,項目名稱:person-reid-lib,代碼行數:19,代碼來源:resnet50.py

示例4: resnet18_ids

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet18_ids(num_attributes, ids_embedding_size, pretrained=False, **kwargs):
    """Construct a ResNet-18 trunk plus two ``ResNetClassifier`` heads.

    Args:
        num_attributes: ``num_classes`` of the attribute classifier head.
        ids_embedding_size: ``num_classes`` of the identity-embedding head.
        pretrained (bool): If True, copy the ImageNet ResNet-18 checkpoint
            into the trunk, keeping only entries whose names also exist in
            the trunk's state dict. The two heads are not touched.
        **kwargs: Forwarded to ``ResNet`` and to both ``ResNetClassifier``s.

    Returns:
        tuple: ``(model, classifier, classifier_ids)``.

    Note:
        ``load_state_dict`` runs with the default ``strict=True``, so any
        trunk key absent from the checkpoint raises an error.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    classifier = ResNetClassifier(BasicBlock, num_classes=num_attributes, **kwargs)
    classifier_ids = ResNetClassifier(BasicBlock, num_classes=ids_embedding_size, **kwargs)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet18'])
        # Build the trunk's key set once, rather than calling
        # model.state_dict() again for every checkpoint entry.
        model_keys = model.state_dict().keys()
        model.load_state_dict(
            {k: v for k, v in state_dict.items() if k in model_keys}
        )

    return model, classifier, classifier_ids
開發者ID:leokarlin,項目名稱:LaSO,代碼行數:18,代碼來源:resnet_backup.py

示例5: resnet18

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet18(output_layers=None, pretrained=False, **kwargs):
    """Construct a ResNet-18 model that can expose intermediate layers.

    Args:
        output_layers: Names of layers to output. Each must be one of
            'conv1', 'layer1', 'layer2', 'layer3', 'layer4' or 'fc'. When
            None, ``['default']`` is passed to ``ResNet`` without validation.
        pretrained (bool): If True, load the torchvision ResNet-18 checkpoint.
        **kwargs: Forwarded to ``ResNet``.

    Raises:
        ValueError: If ``output_layers`` contains an unknown layer name.
    """
    valid_layers = ['conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'fc']

    if output_layers is None:
        output_layers = ['default']
    else:
        for layer_name in output_layers:
            if layer_name not in valid_layers:
                raise ValueError('Unknown layer: {}'.format(layer_name))

    model = ResNet(BasicBlock, [2, 2, 2, 2], output_layers, **kwargs)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet18'])
        model.load_state_dict(checkpoint)
    return model
開發者ID:visionml,項目名稱:pytracking,代碼行數:18,代碼來源:resnet.py

示例6: resnet50

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet50(num_classes=1000, avgpool_size=7, use_dropout=False, pretrained=True):
    """Construct a ResNet-50 whose backbone can be initialised from ImageNet.

    Args:
        num_classes: Output size of the final ``fc`` layer.
        avgpool_size: Forwarded to ``ResNet``.
        use_dropout: Forwarded to ``ResNet``.
        pretrained (bool): If True, copy every entry of the torchvision
            ResNet-50 checkpoint except ``fc.*`` into the model. The model's
            own classifier weights are kept, so ``num_classes`` may differ
            from the checkpoint's.
    """
    model = ResNet(resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes,
                   avgpool_size=avgpool_size, use_dropout=use_dropout)
    if not pretrained:
        return model

    # NOTE(review): relies on torchvision.models.resnet exposing model_zoo and
    # model_urls, which older torchvision releases do - verify pinned version.
    checkpoint = resnet.model_zoo.load_url(resnet.model_urls['resnet50'])
    merged = model.state_dict()
    merged.update({name: tensor for name, tensor in checkpoint.items()
                   if not name.startswith('fc.')})
    model.load_state_dict(merged)
    return model
開發者ID:Britefury,項目名稱:self-ensemble-visual-domain-adapt-photo,代碼行數:22,代碼來源:network_architectures.py

示例7: resnet101

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet101(num_classes=1000, avgpool_size=7, use_dropout=False, pretrained=True):
    """Construct a ResNet-101 whose backbone can be initialised from ImageNet.

    Args:
        num_classes: Output size of the final ``fc`` layer.
        avgpool_size: Forwarded to ``ResNet``.
        use_dropout: Forwarded to ``ResNet``.
        pretrained (bool): If True, copy every entry of the torchvision
            ResNet-101 checkpoint except ``fc.*`` into the model. The model's
            own classifier weights are kept, so ``num_classes`` may differ
            from the checkpoint's.
    """
    model = ResNet(resnet.Bottleneck, [3, 4, 23, 3], num_classes=num_classes,
                   avgpool_size=avgpool_size, use_dropout=use_dropout)
    if not pretrained:
        return model

    # NOTE(review): relies on torchvision.models.resnet exposing model_zoo and
    # model_urls, which older torchvision releases do - verify pinned version.
    checkpoint = resnet.model_zoo.load_url(resnet.model_urls['resnet101'])
    merged = model.state_dict()
    merged.update({name: tensor for name, tensor in checkpoint.items()
                   if not name.startswith('fc.')})
    model.load_state_dict(merged)
    return model
開發者ID:Britefury,項目名稱:self-ensemble-visual-domain-adapt-photo,代碼行數:22,代碼來源:network_architectures.py

示例8: resnet152

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet152(num_classes=1000, avgpool_size=7, use_dropout=False, pretrained=True):
    """Construct a ResNet-152 whose backbone can be initialised from ImageNet.

    Args:
        num_classes: Output size of the final ``fc`` layer.
        avgpool_size: Forwarded to ``ResNet``.
        use_dropout: Forwarded to ``ResNet``.
        pretrained (bool): If True, copy every entry of the torchvision
            ResNet-152 checkpoint except ``fc.*`` into the model. The model's
            own classifier weights are kept, so ``num_classes`` may differ
            from the checkpoint's.
    """
    model = ResNet(resnet.Bottleneck, [3, 8, 36, 3], num_classes=num_classes,
                   avgpool_size=avgpool_size, use_dropout=use_dropout)
    if not pretrained:
        return model

    # NOTE(review): relies on torchvision.models.resnet exposing model_zoo and
    # model_urls, which older torchvision releases do - verify pinned version.
    checkpoint = resnet.model_zoo.load_url(resnet.model_urls['resnet152'])
    merged = model.state_dict()
    merged.update({name: tensor for name, tensor in checkpoint.items()
                   if not name.startswith('fc.')})
    model.load_state_dict(merged)
    return model
開發者ID:Britefury,項目名稱:self-ensemble-visual-domain-adapt-photo,代碼行數:22,代碼來源:network_architectures.py

示例9: trinet

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def trinet(**kwargs):
    """Create a TriNet network and load the pretrained ResNet-50 weights.

    Checkpoint entries whose names start with ``"fc"`` are skipped, so the
    model keeps its own head. Every other entry is merged into the model's
    state dict before loading. Technique described at
    https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/2

    Args:
        **kwargs: Forwarded to ``TriNet``.

    Returns:
        tuple: ``(model, endpoints)`` where ``endpoints`` is ``{"emb": None}``.
    """
    model = TriNet(Bottleneck, [3, 4, 6, 3], **kwargs)

    backbone_weights = {
        name: tensor
        for name, tensor in model_zoo.load_url(model_urls['resnet50']).items()
        if not name.startswith("fc")
    }

    merged_state = model.state_dict()
    merged_state.update(backbone_weights)
    model.load_state_dict(merged_state)
    return model, {"emb": None}
開發者ID:kilsenp,項目名稱:triplet-reid-pytorch,代碼行數:26,代碼來源:trinet.py

示例10: stride_test

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def stride_test(**kwargs):
    """Create a StrideTest network initialised from pretrained ResNet-50 weights.

    Two groups of checkpoint entries are skipped: names starting with
    ``"fc"``, and ``layer4`` entries whose name contains ``"downsample"``.
    The latter is presumably because StrideTest alters layer4's downsample
    path (confirm against its definition). All remaining entries are merged
    into the model's state dict before loading.

    Args:
        **kwargs: Forwarded to ``StrideTest``.

    Returns:
        The initialised model.
    """
    model = StrideTest(Bottleneck, [3, 4, 6, 3], **kwargs)
    pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
    model_dict = model.state_dict()

    # Drop the classifier and layer4's downsample weights in a single pass.
    pretrained_dict = {
        k: v for k, v in pretrained_dict.items()
        if not k.startswith("fc")
        and not (k.startswith("layer4") and "downsample" in k)
    }

    # Overwrite matching entries in the existing state dict, then load it.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
開發者ID:kilsenp,項目名稱:triplet-reid-pytorch,代碼行數:25,代碼來源:dilated.py

示例11: resnet18

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet18(config_channels, anchors, num_cls, **kwargs):
    """Build a ResNet-18 model, optionally seeded with pretrained weights.

    Weights are loaded when ``config_channels.config.getboolean('model',
    'pretrained')`` is true (presumably a configparser-style object). Only
    checkpoint entries whose names already exist in the model's state dict
    are copied; the rest of the checkpoint is ignored.

    Args:
        config_channels: Channel configuration, also forwarded to ``ResNet``.
        anchors: Forwarded to ``ResNet``.
        num_cls: Forwarded to ``ResNet``.
        **kwargs: Forwarded to ``ResNet``.
    """
    model = ResNet(config_channels, anchors, num_cls, BasicBlock, [2, 2, 2, 2], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet18']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    own_state.update({
        name: tensor
        for name, tensor in model_zoo.load_url(url).items()
        if name in own_state
    })
    model.load_state_dict(own_state)
    return model
開發者ID:ruiminshen,項目名稱:yolo2-pytorch,代碼行數:13,代碼來源:resnet.py

示例12: resnet34

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet34(config_channels, anchors, num_cls, **kwargs):
    """Build a ResNet-34 model, optionally seeded with pretrained weights.

    Weights are loaded when ``config_channels.config.getboolean('model',
    'pretrained')`` is true (presumably a configparser-style object). Only
    checkpoint entries whose names already exist in the model's state dict
    are copied; the rest of the checkpoint is ignored.

    Args:
        config_channels: Channel configuration, also forwarded to ``ResNet``.
        anchors: Forwarded to ``ResNet``.
        num_cls: Forwarded to ``ResNet``.
        **kwargs: Forwarded to ``ResNet``.
    """
    model = ResNet(config_channels, anchors, num_cls, BasicBlock, [3, 4, 6, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet34']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    own_state.update({
        name: tensor
        for name, tensor in model_zoo.load_url(url).items()
        if name in own_state
    })
    model.load_state_dict(own_state)
    return model
開發者ID:ruiminshen,項目名稱:yolo2-pytorch,代碼行數:13,代碼來源:resnet.py

示例13: resnet50

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet50(config_channels, anchors, num_cls, **kwargs):
    """Build a ResNet-50 model, optionally seeded with pretrained weights.

    Weights are loaded when ``config_channels.config.getboolean('model',
    'pretrained')`` is true (presumably a configparser-style object). Only
    checkpoint entries whose names already exist in the model's state dict
    are copied; the rest of the checkpoint is ignored.

    Args:
        config_channels: Channel configuration, also forwarded to ``ResNet``.
        anchors: Forwarded to ``ResNet``.
        num_cls: Forwarded to ``ResNet``.
        **kwargs: Forwarded to ``ResNet``.
    """
    model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 6, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet50']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    own_state.update({
        name: tensor
        for name, tensor in model_zoo.load_url(url).items()
        if name in own_state
    })
    model.load_state_dict(own_state)
    return model
開發者ID:ruiminshen,項目名稱:yolo2-pytorch,代碼行數:13,代碼來源:resnet.py

示例14: resnet101

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet101(config_channels, anchors, num_cls, **kwargs):
    """Build a ResNet-101 model, optionally seeded with pretrained weights.

    Weights are loaded when ``config_channels.config.getboolean('model',
    'pretrained')`` is true (presumably a configparser-style object). Only
    checkpoint entries whose names already exist in the model's state dict
    are copied; the rest of the checkpoint is ignored.

    Args:
        config_channels: Channel configuration, also forwarded to ``ResNet``.
        anchors: Forwarded to ``ResNet``.
        num_cls: Forwarded to ``ResNet``.
        **kwargs: Forwarded to ``ResNet``.
    """
    model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 4, 23, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet101']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    own_state.update({
        name: tensor
        for name, tensor in model_zoo.load_url(url).items()
        if name in own_state
    })
    model.load_state_dict(own_state)
    return model
開發者ID:ruiminshen,項目名稱:yolo2-pytorch,代碼行數:13,代碼來源:resnet.py

示例15: resnet152

# 需要導入模塊: from torchvision.models import resnet [as 別名]
# 或者: from torchvision.models.resnet import model_urls [as 別名]
def resnet152(config_channels, anchors, num_cls, **kwargs):
    """Build a ResNet-152 model, optionally seeded with pretrained weights.

    Weights are loaded when ``config_channels.config.getboolean('model',
    'pretrained')`` is true (presumably a configparser-style object). Only
    checkpoint entries whose names already exist in the model's state dict
    are copied; the rest of the checkpoint is ignored.

    Args:
        config_channels: Channel configuration, also forwarded to ``ResNet``.
        anchors: Forwarded to ``ResNet``.
        num_cls: Forwarded to ``ResNet``.
        **kwargs: Forwarded to ``ResNet``.
    """
    model = ResNet(config_channels, anchors, num_cls, Bottleneck, [3, 8, 36, 3], **kwargs)
    if not config_channels.config.getboolean('model', 'pretrained'):
        return model

    url = _model.model_urls['resnet152']
    logging.info('use pretrained model: ' + url)
    own_state = model.state_dict()
    own_state.update({
        name: tensor
        for name, tensor in model_zoo.load_url(url).items()
        if name in own_state
    })
    model.load_state_dict(own_state)
    return model
開發者ID:ruiminshen,項目名稱:yolo2-pytorch,代碼行數:13,代碼來源:resnet.py


注:本文中的torchvision.models.resnet.model_urls示例由純淨天空整理自GitHub/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。