当前位置: 首页>>代码示例>>Python>>正文


Python densenet.densenet121方法代码示例

本文整理汇总了Python中torchvision.models.densenet.densenet121方法的典型用法代码示例。如果您正苦于以下问题：Python densenet.densenet121方法的具体用法？Python densenet.densenet121怎么用？Python densenet.densenet121使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torchvision.models.densenet的用法示例。


在下文中一共展示了densenet.densenet121方法的4个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_se_densenet

# 需要导入模块: from torchvision.models import densenet [as 别名]
# 或者: from torchvision.models.densenet import densenet121 [as 别名]
def test_se_densenet(pretrained=False):
    """Smoke-test ``se_densenet121`` on a random batch and print the output shape.

    Args:
        pretrained: forwarded to ``se_densenet121``. When true, the torchvision
            DenseNet-121 checkpoint is additionally downloaded and loaded into
            the model with ``strict=False``.
    """
    # torch.randn yields well-defined values; torch.Tensor(*sizes) returns
    # uninitialized memory that may contain NaN/Inf garbage.
    X = torch.randn(32, 3, 224, 224)

    model = se_densenet121(pretrained=pretrained)
    if pretrained:
        net_state_dict = dict(model_zoo.load_url(
            "https://download.pytorch.org/models/densenet121-a639ec97.pth"))
        # strict=False: keys that exist in only one of the two models are
        # ignored instead of raising.
        # NOTE(review): torchvision remaps legacy key names in this checkpoint
        # (e.g. 'norm.1' -> 'norm1') inside its own loader; loading it raw here
        # may silently skip those entries -- check the missing/unexpected keys
        # returned by load_state_dict.
        model.load_state_dict(net_state_dict, strict=False)

    if torch.cuda.is_available():
        X = X.cuda()
        model = model.cuda()
    model.eval()
    with torch.no_grad():
        output = model(X)
        print(output.shape)
开发者ID:zhouyuangan,项目名称:SE_DenseNet,代码行数:21,代码来源:test_se_densenet.py

示例2: test_densenet

# 需要导入模块: from torchvision.models import densenet [as 别名]
# 或者: from torchvision.models.densenet import densenet121 [as 别名]
def test_densenet():
    """Run torchvision's DenseNet-121 on a random batch and print the output shape.

    The model is built without pretrained weights and moved to the GPU when
    one is available.
    """
    # torch.randn yields well-defined values; torch.Tensor(*sizes) returns
    # uninitialized memory that may contain NaN/Inf garbage.
    X = torch.randn(32, 3, 224, 224)

    model = densenet121(pretrained=False)

    if torch.cuda.is_available():
        model = model.cuda()
        X = X.cuda()

    model.eval()
    with torch.no_grad():
        output = model(X)
        print(output.shape)
开发者ID:zhouyuangan,项目名称:SE_DenseNet,代码行数:16,代码来源:test_se_densenet.py

示例3: __init__

# 需要导入模块: from torchvision.models import densenet [as 别名]
# 或者: from torchvision.models.densenet import densenet121 [as 别名]
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, pretrained=True):
        """Build the DenseNet feature extractor, optionally seeded from ImageNet.

        Args:
            growth_rate: channels each dense layer adds; a block with
                ``num_layers`` layers grows the width by
                ``num_layers * growth_rate``.
            block_config: number of layers in each dense block.
            num_init_features: output channels of the stem convolution.
            bn_size: forwarded to ``_DenseBlock``.
            drop_rate: forwarded to ``_DenseBlock``.
            pretrained: when true, copy weights module-by-module from
                torchvision's ``densenet121(pretrained=True)``. This only lines
                up when the architecture arguments match DenseNet-121 (the
                defaults); otherwise ``load_state_dict`` presumably raises on a
                shape mismatch.
        """
        super(DenseNet, self).__init__()

        # First convolution
        self.start_features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features

        # Only build the reference model when its weights are actually used:
        # densenet121(pretrained=True) fetches a checkpoint, which is wasted
        # work (and a network dependency) when pretrained=False.
        init_weights = (list(densenet121(pretrained=True).features.children())
                        if pretrained else None)

        # `start` indexes init_weights. It advances once per module created
        # below, so each new module pairs with the reference module at the same
        # position.
        start = 0
        for child in self.start_features.children():
            if pretrained:
                child.load_state_dict(init_weights[start].state_dict())
            start += 1
        self.blocks = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            if pretrained:
                block.load_state_dict(init_weights[start].state_dict())
            start += 1
            # The same module object is registered under two names: in
            # self.blocks and as a 'denseblockN' attribute.
            # NOTE(review): this likely duplicates its keys in state_dict().
            self.blocks.append(block)
            setattr(self, 'denseblock%d' % (i + 1), block)

            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                # NOTE(review): only the first transition is built with
                # downsample=True; presumably later ones keep spatial resolution
                # -- confirm against _Transition.
                downsample = i < 1
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2,
                                    downsample=downsample)
                if pretrained:
                    trans.load_state_dict(init_weights[start].state_dict())
                start += 1
                self.blocks.append(trans)
                setattr(self, 'transition%d' % (i + 1), trans)
                num_features = num_features // 2
开发者ID:hyk1996,项目名称:Single-Human-Parsing-LIP,代码行数:45,代码来源:extractors.py

示例4: create_model

# 需要导入模块: from torchvision.models import densenet [as 别名]
# 或者: from torchvision.models.densenet import densenet121 [as 别名]
def create_model(model_name, num_classes=1000, pretrained=False, **kwargs):
    """Instantiate a model by architecture name.

    Args:
        model_name: one of 'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131',
            'dpn107', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152', 'densenet121', 'densenet161', 'densenet169',
            'densenet201' or 'inception_v3'.
        num_classes: forwarded to every factory.
        pretrained: forwarded to every factory.
        **kwargs: ``test_time_pool`` (default True) is consumed here and passed
            only to the DPN factories. All remaining kwargs are forwarded to the
            ResNet, DenseNet and Inception factories (not to the DPN ones).
            For 'inception_v3', ``transform_input`` defaults to False but may
            be overridden.

    Returns:
        The model returned by the selected factory.

    Raises:
        ValueError: if ``model_name`` is not recognised.
    """
    test_time_pool = kwargs.pop('test_time_pool', True)
    if model_name == 'dpn68':
        model = dpn68(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn68b':
        model = dpn68b(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn92':
        model = dpn92(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn98':
        model = dpn98(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn131':
        model = dpn131(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'dpn107':
        model = dpn107(
            pretrained=pretrained, test_time_pool=test_time_pool, num_classes=num_classes)
    elif model_name == 'resnet18':
        model = resnet18(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet34':
        model = resnet34(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet50':
        model = resnet50(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet101':
        model = resnet101(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'resnet152':
        model = resnet152(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet121':
        model = densenet121(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet161':
        model = densenet161(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet169':
        model = densenet169(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'densenet201':
        model = densenet201(pretrained=pretrained, num_classes=num_classes, **kwargs)
    elif model_name == 'inception_v3':
        # Default transform_input to False without clashing with a
        # caller-supplied value (passing it both explicitly and via **kwargs
        # raises TypeError).
        kwargs.setdefault('transform_input', False)
        model = inception_v3(
            pretrained=pretrained, num_classes=num_classes, **kwargs)
    else:
        # An explicit raise survives `python -O`; `assert False` would be
        # stripped, leaving `model` unbound at the return below.
        raise ValueError("Unknown model architecture (%s)" % model_name)
    return model
开发者ID:rwightman,项目名称:pytorch-dpn-pretrained,代码行数:49,代码来源:model_factory.py


注：本文中的torchvision.models.densenet.densenet121方法示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。