當前位置: 首頁>>代碼示例>>Python>>正文


Python resnet.ResNet50方法代碼示例

本文整理匯總了Python中resnet.ResNet50方法的典型用法代碼示例。如果您正苦於以下問題:Python resnet.ResNet50方法的具體用法?Python resnet.ResNet50怎麽用?Python resnet.ResNet50使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在resnet的用法示例。


在下文中一共展示了resnet.ResNet50方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: run

# 需要導入模塊: import resnet [as 別名]
# 或者: from resnet import ResNet50 [as 別名]
def run():
    t = time.time()
    print('net_cache : ', args.net_cache)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model = ResNet50()
    model = nn.DataParallel(model.cuda())

    if os.path.exists(args.net_cache):
        print('loading checkpoint {} ..........'.format(args.net_cache))
        checkpoint = torch.load(args.net_cache)
        best_top1_acc = checkpoint['best_top1_acc']
        model.load_state_dict(checkpoint['state_dict'])
        #print("loaded checkpoint {} epoch = {}" .format(args.net_cache, checkpoint['epoch']))

    else:
        print('can not find {} '.format(args.net_cache))
        return

    num_states = len(stage_repeat) + sum(stage_repeat)
    search(model, criterion, num_states)

    total_searching_time = time.time() - t
    print('total searching time = {:.2f} hours'.format(total_searching_time/3600), flush=True) 
開發者ID:liuzechun,項目名稱:MetaPruning,代碼行數:27,代碼來源:search.py

示例2: __init__

# 需要導入模塊: import resnet [as 別名]
# 或者: from resnet import ResNet50 [as 別名]
def __init__(self, dropout_rate, feat_length = 512, archi_type='resnet18'):
        super(CIFAR10FeatureLayer, self).__init__()
        self.archi_type = archi_type
        self.feat_length = feat_length
        if self.archi_type == 'default':
            self.add_module('conv1', nn.Conv2d(3, 32, kernel_size=3, padding=1))
            self.add_module('bn1', nn.BatchNorm2d(32))
            self.add_module('relu1', nn.ReLU())
            self.add_module('pool1', nn.MaxPool2d(kernel_size=2))
            #self.add_module('drop1', nn.Dropout(dropout_rate))
            self.add_module('conv2', nn.Conv2d(32, 32, kernel_size=3, padding=1))
            self.add_module('bn2', nn.BatchNorm2d(32))
            self.add_module('relu2', nn.ReLU())
            self.add_module('pool2', nn.MaxPool2d(kernel_size=2))
            #self.add_module('drop2', nn.Dropout(dropout_rate))
            self.add_module('conv3', nn.Conv2d(32, 64, kernel_size=3, padding=1))
            self.add_module('bn3', nn.BatchNorm2d(64))
            self.add_module('relu3', nn.ReLU())
            self.add_module('pool3', nn.MaxPool2d(kernel_size=2))
            #self.add_module('drop3', nn.Dropout(dropout_rate))
        elif self.archi_type == 'resnet18':
            self.add_module('resnet18', resnet.ResNet18(feat_length))
        elif self.archi_type == 'resnet50':
            self.add_module('resnet50', resnet.ResNet50(feat_length))            
        elif self.archi_type == 'resnet152':
            self.add_module('resnet152', resnet.ResNet152(feat_length))  
        else:
            raise NotImplementedError 
開發者ID:Nicholasli1995,項目名稱:VisualizingNDF,代碼行數:30,代碼來源:ndf.py

示例3: model_factory

# 需要導入模塊: import resnet [as 別名]
# 或者: from resnet import ResNet50 [as 別名]
def model_factory(model_name, **params):
    model_dict = {
        'densenet121': DenseNet121,
        'densenet169': DenseNet169,
        'densenet201': DenseNet201,
        'densenet161': DenseNet161,
        'densenet-cifar': densenet_cifar,
        'dual-path-net-26': DPN26,
        'dual-path-net-92': DPN92,
        'googlenet': GoogLeNet,
        'lenet': LeNet,
        'mobilenet': MobileNet,
        'mobilenetv2': MobileNetV2,
        'pnasneta': PNASNetA,
        'pnasnetb': PNASNetB,
        'preact-resnet18': PreActResNet18,
        'preact-resnet34': PreActResNet34,
        'preact-resnet50': PreActResNet50,
        'preact-resnet101': PreActResNet101,
        'preact-resnet152': PreActResNet152,
        'resnet18': ResNet18,
        'resnet34': ResNet34,
        'resnet50': ResNet50,
        'resnet101': ResNet101,
        'resnet152': ResNet152,
        'resnext29_2x64d': ResNeXt29_2x64d,
        'resnext29_4x64d': ResNeXt29_4x64d,
        'resnext29_8x64d': ResNeXt29_8x64d,
        'resnext29_32x64d': ResNeXt29_32x4d,
        'senet18': SENet18,
        'shufflenetg2': ShuffleNetG2,
        'shufflenetg3': ShuffleNetG3,
        'shufflenetv2_0.5': ShuffleNetV2,
        'shufflenetv2_1.0': ShuffleNetV2,
        'shufflenetv2_1.5': ShuffleNetV2,
        'shufflenetv2_2.0': ShuffleNetV2,
        'vgg11': VGG,
        'vgg13': VGG,
        'vgg16': VGG,
        'vgg19': VGG,
    }

    if 'vgg' in model_name:
        return model_dict[model_name](model_name)
    elif 'shufflenetv2' in model_name:
        return model_dict[model_name](float(model_name[-3:]))
    elif model_name in model_dict.keys():
        return model_dict[model_name]()
    else:
        raise AttributeError('Model doesn\'t exist') 
開發者ID:suvojit-0x55aa,項目名稱:mixed-precision-pytorch,代碼行數:52,代碼來源:model_factory_dict.py


注:本文中的resnet.ResNet50方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。