本文整理汇总了Python中torchvision.models.__dict__属性的典型用法代码示例。如果您正苦于以下问题:Python models.__dict__属性的具体用法?Python models.__dict__怎么用?Python models.__dict__使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在模块torchvision.models的用法示例。(注:__dict__ 并不是方法,而是模块的属性字典;models.__dict__[name] 按名称取出对应的模型构造函数,作用类似于 getattr(models, name)。)
下文一共展示了models.__dict__属性的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_cnn
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_cnn(self, arch, pretrained):
    """Build the torchvision model ``arch`` and parallelize it over GPUs.

    Args:
        arch: Key into ``models.__dict__`` naming a model constructor.
        pretrained: If truthy, the constructor is called with
            ``pretrained=True``; otherwise it is called with no arguments.

    Returns:
        For names starting with ``'alexnet'`` or ``'vgg'``: the model itself,
        with only its ``features`` submodule wrapped in ``nn.DataParallel``,
        after ``.cuda()``. For every other name: ``nn.DataParallel(model)``
        moved with ``.cuda()``.
    """
    if pretrained:
        banner, ctor_kwargs = "=> using pre-trained model '{}'", {'pretrained': True}
    else:
        banner, ctor_kwargs = "=> creating model '{}'", {}
    print(banner.format(arch))
    net = models.__dict__[arch](**ctor_kwargs)

    if not arch.startswith(('alexnet', 'vgg')):
        return nn.DataParallel(net).cuda()

    # AlexNet/VGG: wrap only the ``features`` submodule; the rest of the
    # model stays unwrapped.
    net.features = nn.DataParallel(net.features)
    net.cuda()
    return net
示例2: get_cnn
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_cnn(self, arch, pretrained):
    """Return the torchvision network ``arch``, wrapped for multi-GPU execution.

    ``pretrained`` selects between ``models.__dict__[arch](pretrained=True)``
    and a plain ``models.__dict__[arch]()`` call. AlexNet/VGG variants get
    ``nn.DataParallel`` applied to their ``features`` submodule only; any other
    architecture is wrapped as a whole. In both cases ``.cuda()`` is called.
    """
    print(("=> using pre-trained model '{}'" if pretrained
           else "=> creating model '{}'").format(arch))
    constructor = models.__dict__[arch]
    model = constructor(pretrained=True) if pretrained else constructor()

    features_only = arch.startswith('alexnet') or arch.startswith('vgg')
    if features_only:
        model.features = nn.DataParallel(model.features)
        model.cuda()
        return model
    return nn.DataParallel(model).cuda()
示例3: initialize_model
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device_id: int
):
    """Create ``models.__dict__[arch]()`` for DistributedDataParallel training.

    Args:
        arch: Name of the model constructor (called with no arguments).
        lr: Learning rate passed to ``SGD``.
        momentum: Momentum passed to ``SGD``.
        weight_decay: Weight decay passed to ``SGD``.
        device_id: The single CUDA device this process uses.

    Returns:
        ``(model, criterion, optimizer)``: the DDP-wrapped model on
        ``device_id``, a ``CrossEntropyLoss`` on the same device, and an
        ``SGD`` optimizer over the wrapped model's parameters.

    Side effect: sets ``cudnn.benchmark = True`` globally.
    """
    print(f"=> creating model: {arch}")
    network = models.__dict__[arch]()
    # Pin the module to one device before wrapping: without a single device
    # scope, DistributedDataParallel would use all available devices.
    network.cuda(device_id)
    cudnn.benchmark = True
    ddp_model = DistributedDataParallel(network, device_ids=[device_id])

    loss_fn = nn.CrossEntropyLoss().cuda(device_id)
    sgd = SGD(
        ddp_model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
    )
    return ddp_model, loss_fn, sgd
示例4: build_model
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def build_model(self):
    """Instantiate ``self._arch`` and load its weights from ``self._ckp``.

    The network is built with ``num_classes=len(self._labels)`` and filled
    from ``checkpoint['state_dict']``; the result is stored in
    ``self._model``. When ``self._cuda`` is falsy, checkpoint tensors are
    mapped to CPU storage so a GPU-saved checkpoint loads on a CPU-only
    machine. The model ends up on CUDA or CPU according to ``self._cuda``.
    """
    # The architecture is always built fresh; its weights come from the checkpoint.
    print("=> using model '{}'".format(self._arch))
    self._model = models.__dict__[self._arch](num_classes=len(self._labels))
    print("=> loading checkpoint '{}'".format(self._ckp))
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    load_kwargs = {} if self._cuda else {'map_location': lambda storage, loc: storage}
    checkpoint = torch.load(self._ckp, **load_kwargs)
    self._model.load_state_dict(checkpoint['state_dict'])
    if not self._cuda:
        self._model.cpu()
    else:
        self._model.cuda()
    # Preprocess Images to be ImageNet-compliant
    # NOTE(review): the code this comment introduced appears truncated in this
    # excerpt -- confirm against the full source.
示例5: get_model
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_model(args):
    """Build the network selected by ``args`` and record its name on ``args``.

    ``parse_model(args)`` (defined elsewhere) runs first. Then one of:

    * ``args.dataset == 'imagenet'``: ``torch_models.__dict__[args.model]()``,
      named ``args.model``;
    * else, if ``args.basic_model``: ``cifar_models.BasicConvNet``, named
      ``'convnet_<planes>'``;
    * otherwise: ``cifar_models.DenseNet3``, named ``'densenet_<depth>_<growth>'``.

    The chosen name is written to ``args.model_name`` (``args`` is mutated)
    and the total parameter count is printed.

    Returns:
        The constructed model.
    """
    parse_model(args)
    if args.dataset == 'imagenet':
        net = torch_models.__dict__[args.model]()
        name = args.model
    elif args.basic_model:
        net = cifar_models.BasicConvNet(args.dataset, args.planes)
        name = 'convnet_{}'.format(args.planes)
    else:
        net = cifar_models.DenseNet3(args.depth, args.num_classes, args.growth)
        name = 'densenet_{}_{}'.format(args.depth, args.growth)
    args.model_name = name

    total_params = sum(p.data.nelement() for p in net.parameters())
    print('Number of model parameters: \t {}'.format(total_params))
    return net
示例6: test_coordConvNet
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def test_coordConvNet(input_image):
    """Smoke-test: wrap an untrained VGG16 in ``CoordConvNet`` and run one forward pass.

    Args:
        input_image: Input tensor; the wrapped network is moved to
            ``input_image.device`` before the forward pass.

    Prints the network before and after wrapping and the input/output sizes.
    ``CoordConvNet`` is defined elsewhere. Its output is iterated here, so it
    presumably returns a sequence of tensors -- confirm.
    """
    print('- CoordConvNet')
    target_device = input_image.device
    import torchvision.models as models
    base = models.__dict__['vgg16'](pretrained=False)
    print('VGG16 :\n', base)
    coord_net = CoordConvNet(base, with_r=True)
    print('CoordVGG16 :\n', coord_net)
    coord_net = coord_net.to(target_device)
    outputs = coord_net(input_image)
    print('Input Size : ', input_image.size())
    print('Output Size : ', [tensor.size() for tensor in outputs])
    print('- CoordConvNet: OK!')
示例7: download
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def download(cls, architecture, path="./"):
    """Save an ImageNet-pretrained torchvision model as ``imagenet_<arch>.pth``.

    Args:
        cls: Object providing ``sanity_check(architecture) -> bool``
            (presumably the enclosing class when used as a classmethod).
        architecture: Key into ``models.__dict__`` naming the constructor.
        path: Target directory. It is joined with ``os.path.join``, so it
            works with or without a trailing separator.

    Returns:
        The path of the ``.pth`` file (existing or newly written), or ``None``
        when ``cls.sanity_check`` rejects ``architecture``.
    """
    if not cls.sanity_check(architecture):
        return None

    # os.path.join instead of string concatenation: "models" + "imagenet_x.pth"
    # used to yield "modelsimagenet_x.pth".
    architecture_file = os.path.join(path, "imagenet_{}.pth".format(architecture))
    if os.path.exists(architecture_file):
        print("File [{}] existed!".format(architecture_file))
        return architecture_file

    kwargs = {}
    if architecture == 'inception_v3':
        # inception_v3 is constructed with transform_input=False.
        kwargs['transform_input'] = False
    model = models.__dict__[architecture](pretrained=True, **kwargs)
    # Saves the whole pickled module, not just its state_dict.
    torch.save(model, architecture_file)
    print("PyTorch pretrained model is saved as [{}].".format(architecture_file))
    return architecture_file
示例8: __init__
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def __init__(self,
             arch,
             pretrained,
             lr: float,
             momentum: float,
             weight_decay: float,
             data_path: str,
             batch_size: int, **kwargs):
    """Store the training configuration and build the torchvision backbone.

    Args:
        arch: Key into ``models.__dict__`` naming the model constructor;
            stored as ``self.arch``.
        pretrained: Forwarded as ``pretrained=`` to that constructor; stored
            as ``self.pretrained``.
        lr: Stored as ``self.lr``.
        momentum: Stored as ``self.momentum``.
        weight_decay: Stored as ``self.weight_decay``. Annotated ``float``
            (previously ``int``): weight-decay values such as ``1e-4`` are
            not integers, matching ``lr``/``momentum``.
        data_path: Stored as ``self.data_path``.
        batch_size: Stored as ``self.batch_size``.
        **kwargs: Accepted but not used by this method.

    The model is available as ``self.model`` after construction.
    """
    super().__init__()
    self.arch = arch
    self.pretrained = pretrained
    self.lr = lr
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.data_path = data_path
    self.batch_size = batch_size
    self.model = models.__dict__[self.arch](pretrained=self.pretrained)
示例9: generic_load
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def generic_load(arch, pretrained, weights, args):
    """Construct model ``arch`` from torchvision or from the local ``models`` package.

    Args:
        arch: Looked up first in ``tmodels.__dict__`` (torchvision). If absent,
            submodule ``models.<arch>`` is imported and its attribute ``<arch>``
            is called with ``args``.
        pretrained: Torchvision models only. If truthy, the model is built with
            ``pretrained=True`` and moved with ``.cuda()``.
        weights: Checkpoint path, or ``''`` to skip loading. A dict holding a
            ``'state_dict'`` key is unwrapped before being passed to
            ``load_partial_state`` (defined elsewhere).
        args: Forwarded to the local model factory.

    Returns:
        The constructed and, optionally, weight-loaded model.
    """
    if arch not in tmodels.__dict__:
        # Not a torchvision model: defined as a script in the local package.
        module = importlib.import_module('.' + arch, package='models')
        model = module.__dict__[arch](args)
    elif pretrained:
        print("=> using pre-trained model '{}'".format(arch))
        # NOTE(review): only this branch calls .cuda(); the other two leave the
        # model wherever its constructor put it.
        model = tmodels.__dict__[arch](pretrained=True).cuda()
    else:
        print("=> creating model '{}'".format(arch))
        model = tmodels.__dict__[arch]()

    if weights != '':
        print('loading pretrained-weights from {}'.format(weights))
        # torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(weights)
        if isinstance(state, dict) and 'state_dict' in state:
            state = state['state_dict']
        load_partial_state(model, state)
    return model
示例10: load_model
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def load_model(modelID, categories):
    """Build a ResNet-50 with Moments-in-Time RGB weights and attach feature hooks.

    Only ``modelID == 1`` is supported. The weight file
    ``moments_RGB_resnet50_imagenetpretrained.pth.tar`` is looked up in the
    current working directory and fetched with ``wget`` if missing. The
    checkpoint is always loaded onto CPU storage.

    Args:
        modelID: Model selector; only ``1`` is supported.
        categories: Class labels; ``len(categories)`` sets ``num_classes``.

    Returns:
        The model in eval mode, with ``hook_feature`` (defined elsewhere)
        registered as a forward hook on ``layer4`` and ``avgpool``.

    Raises:
        ValueError: If ``modelID`` is not supported (previously this surfaced
            as an ``UnboundLocalError`` on ``model``).
        subprocess.CalledProcessError: If the ``wget`` download fails.
    """
    import subprocess

    if modelID != 1:
        raise ValueError('unsupported modelID: {!r}'.format(modelID))

    weight_file = 'moments_RGB_resnet50_imagenetpretrained.pth.tar'
    # Existence check: os.access(..., os.W_OK) also re-downloaded a file that
    # existed but was read-only.
    if not os.path.isfile(weight_file):
        weight_url = 'http://moments.csail.mit.edu/moments_models/' + weight_file
        # Argument list (no shell) + check=True: a failed download raises here
        # instead of being ignored and failing later inside torch.load.
        subprocess.run(['wget', weight_url], check=True)

    model = models.__dict__['resnet50'](num_classes=len(categories))
    # Always map tensors to CPU storage. The former GPU branch was unreachable
    # behind a hard-coded useGPU = 0. torch.load unpickles: trusted files only.
    checkpoint = torch.load(weight_file, map_location=lambda storage, loc: storage)
    # Drop the 'module.' prefix from parameter names (presumably added by
    # nn.DataParallel when the checkpoint was saved).
    state_dict = {k.replace('module.', ''): v
                  for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    model.eval()

    # hook the feature extractor; layer4 is the last conv layer of the resnet
    for name in ('layer4', 'avgpool'):
        model._modules.get(name).register_forward_hook(hook_feature)
    return model
示例11: load_model
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def load_model(modelID, categories):
    """Build a ResNet-50 and load the Moments-in-Time RGB weights into it.

    Only ``modelID == 1`` is supported. The weight file
    ``moments_RGB_resnet50_imagenetpretrained.pth.tar`` is looked up in the
    current working directory and fetched with ``wget`` if missing. The
    checkpoint is always loaded onto CPU storage.

    Args:
        modelID: Model selector; only ``1`` is supported.
        categories: Class labels; ``len(categories)`` sets ``num_classes``.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: If ``modelID`` is not supported (previously this surfaced
            as an ``UnboundLocalError`` on ``model``).
        subprocess.CalledProcessError: If the ``wget`` download fails.
    """
    import subprocess

    if modelID != 1:
        raise ValueError('unsupported modelID: {!r}'.format(modelID))

    weight_file = 'moments_RGB_resnet50_imagenetpretrained.pth.tar'
    # Existence check: os.access(..., os.W_OK) also re-downloaded a file that
    # existed but was read-only.
    if not os.path.isfile(weight_file):
        weight_url = 'http://moments.csail.mit.edu/moments_models/' + weight_file
        # Argument list (no shell) + check=True: a failed download raises here
        # instead of being ignored and failing later inside torch.load.
        subprocess.run(['wget', weight_url], check=True)

    model = models.__dict__['resnet50'](num_classes=len(categories))
    # Always map tensors to CPU storage (allow cpu). The former GPU branch was
    # unreachable behind a hard-coded useGPU = 0. torch.load unpickles:
    # trusted files only.
    checkpoint = torch.load(weight_file, map_location=lambda storage, loc: storage)
    # Drop the 'module.' prefix from parameter names (presumably added by
    # nn.DataParallel when the checkpoint was saved).
    state_dict = {str(k).replace('module.', ''): v
                  for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    model.eval()
    return model
示例12: get_cnn
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def get_cnn(self, arch, pretrained, fusion):
    """Build the CNN named ``arch``.

    ``"resnet152"`` is routed to ``resnet152(pretrained=..., fusion=fusion)``.
    That constructor is defined elsewhere and accepts ``fusion``, so it is
    presumably a project-local variant rather than torchvision's. Every other
    name is looked up in ``models.__dict__``, and ``fusion`` is ignored for
    those. No ``DataParallel`` wrapping is applied here.

    Returns:
        The constructed model.
    """
    banner = ("=> using pre-trained model '{}'" if pretrained
              else "=> creating model '{}'")
    print(banner.format(arch))
    if arch == "resnet152":
        return resnet152(pretrained=bool(pretrained), fusion=fusion)
    factory = models.__dict__[arch]
    return factory(pretrained=True) if pretrained else factory()
示例13: setupRun
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def setupRun(self, state, arg):
    """Populate ``state`` with a freshly built model and random input data.

    ``arg[("arch", "size")]`` yields the architecture name and an input size
    whose first four entries are used as ``(batch, channels, height, width)``.
    ``arg.single_batch_size`` forces a batch of 1. The model is built
    without pre-trained weights, put in eval mode, and paired with an SGD
    optimizer (lr=0.01) and a ``CrossEntropyLoss``. Targets are
    ``1..batch``. The step counter and timers on ``state`` are reset to 0
    (presumably accumulated by a benchmark loop elsewhere).
    """
    arch, sizes = arg[("arch", "size")]
    batch, channels, height, width = (sizes[i] for i in range(4))
    if arg.single_batch_size:
        batch = 1
    inputs = torch.randn(batch, channels, height, width)
    labels = torch.arange(1, batch + 1).long()
    # no need to load pre-trained weights for dummy data
    state.net = models.__dict__[arch]()
    state.optimizer = optim.SGD(state.net.parameters(), lr=0.01)
    state.criterion = nn.CrossEntropyLoss()
    state.net.eval()
    state.data, state.target = Variable(inputs), Variable(labels)
    state.steps = state.time_fwd = state.time_bwd = state.time_upt = 0
示例14: __str__(注:本例使用的是实例自身的 self.__dict__,并未用到 torchvision.models.__dict__)
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
示例15: main
# 需要导入模块: from torchvision import models [as 别名]
# 或者: from torchvision.models import __dict__ [as 别名]
def main():
    """Build ``args.arch`` and, if requested, decode a checkpoint with ``Codec``.

    Creates ``checkpoints/coding/<arch>_<MMDD_HHMM>``. ``os.makedirs`` raises
    if that directory already exists, e.g. for two runs in the same minute.
    Models whose name starts with ``'inception'`` are built with
    ``transform_input=True``. When ``args.pretrained`` names an existing file,
    its ``'state_dict'`` is passed through ``Codec.decode`` (defined
    elsewhere). The decoded model's state dict is then saved to
    ``decode.pth.tar`` in that directory.
    """
    args = parser.parse_args()
    dir_name = args.arch + '_' + datetime.datetime.now().strftime('%m%d_%H%M')
    checkpoint_dir = os.path.join('checkpoints', 'coding', dir_name)
    os.makedirs(checkpoint_dir)

    separator = "=" * 89
    print(separator)
    print("=> creating model '{}'".format(args.arch))
    ctor = models.__dict__[args.arch]
    model = ctor(transform_input=True) if args.arch.startswith('inception') else ctor()

    if not args.pretrained:
        print("=> no checkpoint")
    elif not os.path.isfile(args.pretrained):
        print("=> no checkpoint found at '{}'".format(args.pretrained))
    else:
        print("=> using pre-trained model '{}'".format(args.pretrained))
        # torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(args.pretrained)
        model = Codec.decode(model=model, state_dict=checkpoint['state_dict'])
        torch.save({'state_dict': model.state_dict()},
                   os.path.join(checkpoint_dir, 'decode.pth.tar'),
                   pickle_protocol=4)
    print(separator)