当前位置: 首页>>代码示例>>Python>>正文


Python models.__dict__方法代码示例

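本文整理汇总了Python中torchvision.models.__dict__的典型用法代码示例。严格来说,`__dict__` 并不是方法,而是 `torchvision.models` 模块的属性字典(从名称到对象的映射),因此 `models.__dict__[arch]` 可以按字符串名称取出对应的模型构造函数,作用基本等同于 `getattr(models, arch)`。如果您正苦于以下问题:models.__dict__ 的具体用法是什么?models.__dict__ 怎么用?有没有 models.__dict__ 的使用例子?那么,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解 torchvision.models 模块的其他用法示例。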


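下文一共展示了 models.__dict__ 的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。注意:示例14 中的 `self.__dict__` 是对象实例自身的属性字典,与 torchvision.models 无关;另外,示例中常见的 `pretrained=True` 参数自 torchvision 0.13 起已被弃用,新版本应改用 `weights=` 参数。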

示例1: get_cnn

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_cnn(self, arch, pretrained):
        """Build a torchvision CNN by name, wrap it for multi-GPU use, move it to CUDA.

        Args:
            arch: name of a constructor found in ``models.__dict__``.
            pretrained: when truthy the constructor is called with
                ``pretrained=True``; otherwise it is called with no arguments.

        Returns:
            The CUDA model. For ``alexnet*`` / ``vgg*`` architectures only the
            ``features`` sub-module is wrapped in ``nn.DataParallel`` (the model
            object itself is returned); every other architecture is wrapped as
            a whole and the ``nn.DataParallel`` wrapper is returned.
        """
        constructor_kwargs = {}
        if pretrained:
            print("=> using pre-trained model '{}'".format(arch))
            constructor_kwargs['pretrained'] = True
        else:
            print("=> creating model '{}'".format(arch))
        model = models.__dict__[arch](**constructor_kwargs)

        # AlexNet/VGG: parallelize only the `features` trunk; the remaining
        # layers are not wrapped and just move to the GPU with the model.
        if arch.startswith(('alexnet', 'vgg')):
            model.features = nn.DataParallel(model.features)
            model.cuda()
        else:
            model = nn.DataParallel(model).cuda()
        return model
开发者ID:ExplorerFreda,项目名称:VSE-C,代码行数:19,代码来源:model.py

示例2: get_cnn

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_cnn(self, arch, pretrained):
        """Instantiate ``arch`` from ``models.__dict__`` and place it on GPU(s).

        Args:
            arch: torchvision architecture name.
            pretrained: truthy -> ``constructor(pretrained=True)``;
                falsy -> ``constructor()``.

        Returns:
            For ``alexnet*`` / ``vgg*``: the model itself, with ``model.features``
            replaced by an ``nn.DataParallel`` wrapper, moved to CUDA.
            Otherwise: ``nn.DataParallel(model)`` moved to CUDA.
        """
        use_pretrained = bool(pretrained)
        if use_pretrained:
            announcement = "=> using pre-trained model '{}'"
        else:
            announcement = "=> creating model '{}'"
        print(announcement.format(arch))

        factory = models.__dict__[arch]
        model = factory(pretrained=True) if use_pretrained else factory()

        if not (arch.startswith('alexnet') or arch.startswith('vgg')):
            return nn.DataParallel(model).cuda()

        # Only the convolutional `features` sub-module becomes data-parallel.
        model.features = nn.DataParallel(model.features)
        model.cuda()
        return model
开发者ID:fartashf,项目名称:vsepp,代码行数:19,代码来源:model.py

示例3: initialize_model

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device_id: int
):
    """Create ``arch`` on one GPU, wrap it in DDP, and build its loss and optimizer.

    Side effect: sets the global ``cudnn.benchmark`` flag to ``True``.

    Args:
        arch: name of a constructor in ``models.__dict__`` (built untrained).
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        device_id: CUDA device index used for the model and the loss.

    Returns:
        ``(model, criterion, optimizer)``: the ``DistributedDataParallel``
        model, a ``CrossEntropyLoss`` on ``device_id``, and an ``SGD``
        optimizer over the wrapped model's parameters.
    """
    print(f"=> creating model: {arch}")
    net = models.__dict__[arch]()

    # With multiprocessing distributed training the module must be pinned to a
    # single device before wrapping; otherwise DistributedDataParallel would
    # use every available device.
    net.cuda(device_id)
    cudnn.benchmark = True
    ddp_model = DistributedDataParallel(net, device_ids=[device_id])

    loss_fn = nn.CrossEntropyLoss().cuda(device_id)
    sgd = SGD(ddp_model.parameters(), lr, momentum=momentum, weight_decay=weight_decay)
    return ddp_model, loss_fn, sgd
开发者ID:pytorch,项目名称:elastic,代码行数:19,代码来源:main.py

示例4: build_model

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def build_model(self):
    """Construct ``self._arch`` and restore its weights from ``self._ckp``.

    The network is built with ``num_classes=len(self._labels)`` and then
    loaded from the checkpoint's ``'state_dict'`` entry. When ``self._cuda``
    is falsy, tensors are loaded onto CPU storage and the model stays on the
    CPU; otherwise the model is moved to CUDA.
    """
    # Build the architecture (untrained); weights come from the checkpoint below.
    print("=> using model '{}'".format(self._arch))
    self._model = models.__dict__[self._arch](num_classes=len(self._labels))
    print("=> loading checkpoint '{}'".format(self._ckp))

    use_gpu = self._cuda
    # A GPU-saved checkpoint needs a map_location to be loadable on a CPU host.
    load_options = {} if use_gpu else {'map_location': lambda storage, loc: storage}
    checkpoint = torch.load(self._ckp, **load_options)
    self._model.load_state_dict(checkpoint['state_dict'])

    self._model.cuda() if use_gpu else self._model.cpu()


	# Preprocess Images to be ImageNet-compliant 
开发者ID:floydhub,项目名称:imagenet,代码行数:22,代码来源:imagenet_models.py

示例5: get_model

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_model(args):
    """Build the network selected by ``args`` and record its name in ``args.model_name``.

    ``parse_model(args)`` is called first. Then:

    * ``args.dataset == 'imagenet'``: the untrained torchvision model
      ``args.model``;
    * ``args.basic_model``: ``cifar_models.BasicConvNet``;
    * otherwise: ``cifar_models.DenseNet3``.

    The total number of parameter elements is printed.

    Returns:
        The constructed model.
    """
    parse_model(args)

    if args.dataset == 'imagenet':
        model, name = torch_models.__dict__[args.model](), args.model
    elif args.basic_model:
        model = cifar_models.BasicConvNet(args.dataset, args.planes)
        name = 'convnet_{}'.format(args.planes)
    else:
        model = cifar_models.DenseNet3(args.depth, args.num_classes, args.growth)
        name = 'densenet_{}_{}'.format(args.depth, args.growth)
    args.model_name = name

    nparams = sum(param.data.nelement() for param in model.parameters())
    print('Number of model parameters: \t {}'.format(nparams))
    return model
开发者ID:oval-group,项目名称:smooth-topk,代码行数:20,代码来源:main.py

示例6: test_coordConvNet

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def test_coordConvNet(input_image):
    """Smoke-test ``CoordConvNet`` by wrapping an untrained VGG16.

    The wrapped network is moved to ``input_image``'s device and run once on
    ``input_image``; the input size and the size of each element of the
    output are printed.
    """
    print('- CoordConvNet')

    target_device = input_image.device

    import torchvision.models as models

    base_net = models.__dict__['vgg16'](pretrained=False)
    print('VGG16 :\n', base_net)

    coord_net = CoordConvNet(base_net, with_r=True)
    print('CoordVGG16 :\n', coord_net)

    # `.to()` returns the module, so it can be called directly.
    outputs = coord_net.to(target_device)(input_image)

    print('Input Size  : ', input_image.size())
    print('Output Size : ', [out.size() for out in outputs])

    print('- CoordConvNet: OK!')
开发者ID:Wizaron,项目名称:coord-conv-pytorch,代码行数:26,代码来源:test.py

示例7: download

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def download(cls, architecture, path="./"):
        """Save an ImageNet-pretrained torchvision model as ``imagenet_<arch>.pth``.

        The file is placed in directory ``path``. The path is built with
        ``os.path.join``, so ``path`` works with or without a trailing
        separator. If the file already exists it is not downloaded again.

        Args:
            architecture: torchvision constructor name. ``cls.sanity_check`` is
                consulted first (presumably it validates that the name is
                supported -- confirm against the class).
            path: destination directory (default ``"./"``).

        Returns:
            The path of the saved (or already existing) file, or ``None`` when
            ``cls.sanity_check(architecture)`` is falsy.
        """
        if not cls.sanity_check(architecture):
            return None

        architecture_file = os.path.join(path, "imagenet_{}.pth".format(architecture))
        if os.path.exists(architecture_file):
            print("File [{}] existed!".format(architecture_file))
            return architecture_file

        kwargs = {}
        if architecture == 'inception_v3':
            # Store inception_v3 with transform_input disabled.
            kwargs['transform_input'] = False
        model = models.__dict__[architecture](pretrained=True, **kwargs)
        # Saves the whole pickled module, not just its state_dict.
        torch.save(model, architecture_file)
        print("PyTorch pretrained model is saved as [{}].".format(architecture_file))
        return architecture_file
开发者ID:microsoft,项目名称:MMdnn,代码行数:19,代码来源:extractor.py

示例8: __init__

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def __init__(self,
                 arch,
                 pretrained,
                 lr: float,
                 momentum: float,
                 weight_decay: float,
                 data_path: str,
                 batch_size: int, **kwargs):
        """Store training hyper-parameters and build the torchvision model.

        Args:
            arch: name of a constructor in ``models.__dict__``.
            pretrained: forwarded as ``pretrained=`` to that constructor.
            lr: stored as ``self.lr`` (presumably the optimizer learning rate).
            momentum: stored as ``self.momentum``.
            weight_decay: stored as ``self.weight_decay``. Annotated ``float``
                like ``lr``/``momentum``, because decay factors are normally
                fractional (e.g. ``1e-4``).
            data_path: stored as ``self.data_path``.
            batch_size: stored as ``self.batch_size``.
            **kwargs: accepted for compatibility; not used here.
        """
        super().__init__()
        self.arch = arch
        self.pretrained = pretrained
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.data_path = data_path
        self.batch_size = batch_size
        self.model = models.__dict__[self.arch](pretrained=self.pretrained)
开发者ID:PyTorchLightning,项目名称:pytorch-lightning,代码行数:22,代码来源:imagenet.py

示例9: generic_load

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def generic_load(arch, pretrained, weights, args):
    """Build model ``arch`` and optionally load extra weights into it.

    Names present in ``tmodels.__dict__`` are built with torchvision. Any
    other name is imported as module ``models.<arch>``, and that module's
    ``<arch>`` attribute is called with ``args``.

    Args:
        arch: architecture name.
        pretrained: torchvision models only -- request pretrained weights
            (that branch also moves the model to CUDA).
        weights: checkpoint path passed to ``load_partial_state``. Any falsy
            value (``''`` or ``None``) skips loading.
        args: forwarded to custom (non-torchvision) model factories.

    Returns:
        The constructed model.
    """
    if arch in tmodels.__dict__:  # torchvision models
        if pretrained:
            print("=> using pre-trained model '{}'".format(arch))
            model = tmodels.__dict__[arch](pretrained=True)
            # NOTE(review): only this branch moves the model to CUDA; the
            # from-scratch branch below leaves it where it was created.
            model = model.cuda()
        else:
            print("=> creating model '{}'".format(arch))
            model = tmodels.__dict__[arch]()
    else:  # defined as a script in the local `models` package
        module = importlib.import_module('.' + arch, package='models')
        model = module.__dict__[arch](args)

    # Truthiness check: '' and None both mean "no extra weights"
    # (a `weights == ''` comparison would try torch.load(None)).
    if weights:
        print('loading pretrained-weights from {}'.format(weights))
        chkpoint = torch.load(weights)
        # Accept both bare state dicts and {'state_dict': ...} checkpoints.
        if isinstance(chkpoint, dict) and 'state_dict' in chkpoint:
            chkpoint = chkpoint['state_dict']
        load_partial_state(model, chkpoint)
    return model
开发者ID:gsig,项目名称:actor-observer,代码行数:22,代码来源:utils.py

示例10: load_model

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def load_model(modelID, categories):
    """Load the ResNet-50 checkpoint from moments.csail.mit.edu and hook it for CAM.

    The checkpoint file is fetched with ``wget`` into the current directory
    if it is not already present, and it is always loaded onto CPU storage.

    Args:
        modelID: model selector; only ``1`` is supported.
        categories: class labels; ``len(categories)`` sets the number of
            output classes.

    Returns:
        The model in eval mode, with ``hook_feature`` registered as a forward
        hook on its ``layer4`` and ``avgpool`` sub-modules.

    Raises:
        ValueError: if ``modelID`` is not ``1``.
    """
    # Fail fast: without this, an unknown id leaves `model` unbound below.
    if modelID != 1:
        raise ValueError(
            'unsupported modelID {!r}; only modelID=1 is available'.format(modelID))

    weight_file = 'moments_RGB_resnet50_imagenetpretrained.pth.tar'
    # Download only when the file is missing. A write-permission check would
    # needlessly re-download a checkpoint that is present but read-only.
    if not os.path.isfile(weight_file):
        weight_url = 'http://moments.csail.mit.edu/moments_models/' + weight_file
        os.system('wget ' + weight_url)

    model = models.__dict__['resnet50'](num_classes=len(categories))

    # Map all tensors to CPU storage so a GPU-saved checkpoint loads anywhere.
    checkpoint = torch.load(weight_file, map_location=lambda storage, loc: storage)
    # Drop 'module.' from key names (typically present when the model was saved
    # from an nn.DataParallel wrapper).
    state_dict = {str.replace(k, 'module.', ''): v
                  for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)

    model.eval()
    # Hook the feature extractor; 'layer4' is the last conv stage of the ResNet.
    features_names = ['layer4', 'avgpool']
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
开发者ID:zhoubolei,项目名称:moments_models,代码行数:27,代码来源:test_model_CAM.py

示例11: load_model

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def load_model(modelID, categories):
    """Load the ResNet-50 checkpoint from moments.csail.mit.edu in eval mode.

    The checkpoint file is fetched with ``wget`` into the current directory
    if it is not already present, and it is always loaded onto CPU storage.

    Args:
        modelID: model selector; only ``1`` is supported.
        categories: class labels; ``len(categories)`` sets the number of
            output classes.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: if ``modelID`` is not ``1``.
    """
    # Fail fast: without this, an unknown id leaves `model` unbound below.
    if modelID != 1:
        raise ValueError(
            'unsupported modelID {!r}; only modelID=1 is available'.format(modelID))

    weight_file = 'moments_RGB_resnet50_imagenetpretrained.pth.tar'
    # Download only when the file is missing. A write-permission check would
    # needlessly re-download a checkpoint that is present but read-only.
    if not os.path.isfile(weight_file):
        weight_url = 'http://moments.csail.mit.edu/moments_models/' + weight_file
        os.system('wget ' + weight_url)
    model = models.__dict__['resnet50'](num_classes=len(categories))

    # Map all tensors to CPU storage so a GPU-saved checkpoint loads anywhere.
    checkpoint = torch.load(weight_file, map_location=lambda storage,
                            loc: storage)  # allow cpu

    # Drop 'module.' from key names (typically present when the model was saved
    # from an nn.DataParallel wrapper).
    state_dict = {str.replace(str(k), 'module.', ''): v for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)

    model.eval()
    return model
开发者ID:zhoubolei,项目名称:moments_models,代码行数:22,代码来源:test_model.py

示例12: get_cnn

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_cnn(self, arch, pretrained, fusion):
        """Build the image CNN, optionally with pretrained weights.

        ``"resnet152"`` is built by the ``resnet152`` helper in scope (which
        accepts a ``fusion`` argument -- see its definition). Every other name
        is looked up in ``models.__dict__``. No GPU placement or parallelization
        happens here.

        Args:
            arch: architecture name.
            pretrained: truthy requests pretrained weights.
            fusion: forwarded to ``resnet152`` only.

        Returns:
            The constructed model.
        """
        use_pretrained = bool(pretrained)
        message = ("=> using pre-trained model '{}'" if use_pretrained
                   else "=> creating model '{}'")
        print(message.format(arch))

        if arch == "resnet152":
            return resnet152(pretrained=use_pretrained, fusion=fusion)

        constructor = models.__dict__[arch]
        return constructor(pretrained=True) if use_pretrained else constructor()
开发者ID:ZihaoWang-CV,项目名称:CAMP_iccv19,代码行数:22,代码来源:model.py

示例13: setupRun

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def setupRun(self, state, arg):
        """Populate ``state`` for one benchmark run of an untrained convnet on random data.

        ``arg[("arch", "size")]`` yields ``(arch, sizes)``. The first four
        entries of ``sizes`` are batch, channels, height, and width. The batch
        size is forced to 1 when ``arg.single_batch_size`` is truthy.

        Sets on ``state``: ``net`` (in eval mode), ``optimizer`` (SGD,
        lr=0.01), ``criterion`` (CrossEntropyLoss), ``data``/``target``
        (Variable-wrapped random input and labels 1..batch), and zeroes
        ``steps``, ``time_fwd``, ``time_bwd`` and ``time_upt``.
        """
        arch, sizes = arg[("arch", "size")]
        batch_size, c, h, w = (sizes[i] for i in range(4))
        if arg.single_batch_size:
            batch_size = 1

        # Keep tensor creation before model construction: both draw from the RNG.
        inputs = torch.randn(batch_size, c, h, w)
        labels = torch.arange(1, batch_size + 1).long()
        state.net = models.__dict__[arch]()  # random data: pre-trained weights are pointless

        state.optimizer = optim.SGD(state.net.parameters(), lr=0.01)
        state.criterion = nn.CrossEntropyLoss()

        state.net.eval()

        state.data, state.target = Variable(inputs), Variable(labels)

        # Reset the step counter and the per-phase timing accumulators.
        state.steps = state.time_fwd = state.time_bwd = state.time_upt = 0
开发者ID:pytorch,项目名称:benchmark,代码行数:24,代码来源:cpu_convnet_benchmark.py

示例14: __str__

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__) 
开发者ID:tczhangzhi,项目名称:pytorch-distributed,代码行数:5,代码来源:dataparallel.py

示例15: main

# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def main():
    """Decode a compressed checkpoint into a plain ``state_dict`` file.

    Parses CLI arguments and creates ``checkpoints/coding/<arch>_<MMDD_HHMM>``
    (``os.makedirs`` without ``exist_ok``, so a rerun within the same minute
    fails). It then builds the untrained model ``args.arch``; ``inception*``
    variants are built with ``transform_input=True``.

    If ``args.pretrained`` names an existing file, its ``'state_dict'`` is
    passed through ``Codec.decode`` and the decoded model's state dict is
    saved as ``decode.pth.tar`` in that directory.
    """
    args = parser.parse_args()

    stamp = datetime.datetime.now().strftime('%m%d_%H%M')
    checkpoint_dir = os.path.join('checkpoints', 'coding', args.arch + '_' + stamp)
    os.makedirs(checkpoint_dir)

    banner = "=" * 89
    print(banner)
    print("=> creating model '{}'".format(args.arch))

    build = models.__dict__[args.arch]
    model = build(transform_input=True) if args.arch.startswith('inception') else build()

    if not args.pretrained:
        print("=> no checkpoint")
    elif not os.path.isfile(args.pretrained):
        print("=> no checkpoint found at '{}'".format(args.pretrained))
    else:
        print("=> using pre-trained model '{}'".format(args.pretrained))
        checkpoint = torch.load(args.pretrained)

        model = Codec.decode(model=model, state_dict=checkpoint['state_dict'])

        torch.save({'state_dict': model.state_dict()},
                   os.path.join(checkpoint_dir, 'decode.pth.tar'),
                   pickle_protocol=4)

    print(banner)
开发者ID:synxlin,项目名称:nn-compression,代码行数:33,代码来源:decode.py


注:本文中的torchvision.models.__dict__方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。