This article collects typical usage examples of the torch.load function in Python. If you have been wondering how to use the Python load function, what it does, or where to find usage examples, the curated examples below may help.
15 code examples of the load function are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
Example 1: _load
def _load(checkpoint_path):
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint
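The map_location=lambda storage, loc: storage idiom remaps every saved tensor onto the CPU. Recent PyTorch versions also accept a device string or torch.device for map_location, so a device-agnostic loader can be sketched as follows (use_cuda is assumed to be defined, as in the example):

import torch

def _load(checkpoint_path, use_cuda=False):
    # 'cpu' remaps GPU-saved tensors onto the CPU; 'cuda' keeps them on GPU
    return torch.load(checkpoint_path,
                      map_location='cuda' if use_cuda else 'cpu')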
Example 2: init_model
def init_model(word2id, opt):
    model = Seq2SeqLSTMAttention(
        emb_dim=opt.word_vec_size,
        vocab_size=opt.vocab_size,
        src_hidden_dim=opt.rnn_size,
        trg_hidden_dim=opt.rnn_size,
        ctx_hidden_dim=opt.rnn_size,
        attention_mode='dot',
        batch_size=opt.batch_size,
        bidirectional=opt.bidirectional,
        pad_token_src=word2id[pykp.io.PAD_WORD],
        pad_token_trg=word2id[pykp.io.PAD_WORD],
        nlayers_src=opt.enc_layers,
        nlayers_trg=opt.dec_layers,
        dropout=opt.dropout,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        scheduled_sampling=opt.scheduled_sampling,
        scheduled_sampling_batches=opt.scheduled_sampling_batches
    )
    logging.info('====================== Model Parameters =========================')
    if opt.train_from:
        logging.info("loading previous checkpoint from %s" % opt.train_from)
        if torch.cuda.is_available():
            model.load_state_dict(torch.load(open(opt.train_from, 'rb')))
        else:
            model.load_state_dict(torch.load(
                open(opt.train_from, 'rb'), map_location=lambda storage, loc: storage
            ))
    utils.tally_parameters(model)
    return model
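A small simplification, as a hedged sketch: torch.load accepts a file path (or a file-like object) directly, so wrapping the checkpoint in open() is unnecessary, and the two branches can be folded into the map_location argument:

if opt.train_from:
    logging.info("loading previous checkpoint from %s" % opt.train_from)
    # None keeps the default device mapping; the lambda forces CPU
    map_location = None if torch.cuda.is_available() else (lambda storage, loc: storage)
    model.load_state_dict(torch.load(opt.train_from, map_location=map_location))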
Example 3: generate
def generate(**kwargs):
    """
    Randomly generate anime avatars, and use netd's scores to pick the better ones.
    """
    for k_, v_ in kwargs.items():
        setattr(opt, k_, v_)
    device = t.device('cuda') if opt.gpu else t.device('cpu')
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)
    # Generate images and compute their scores under the discriminator
    fake_img = netg(noises)
    scores = netd(fake_img).detach()
    # Pick the top-scoring few
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])
    # Save the images
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))
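Since generate only runs inference, the forward passes could also be wrapped in torch.no_grad() to avoid building the autograd graph; a minimal sketch using the same names as above:

with t.no_grad():
    fake_img = netg(noises)
    scores = netd(fake_img)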
Example 4: load_model
def load_model(self):
    if len(glob.glob(os.path.join(args.save_dir, args.corpus) + '-selector-*.pth')) == 0:
        return
    if args.load_iter is None:
        f_list = glob.glob(os.path.join(args.save_dir, args.corpus) + '-selector-*.pth')
        iter_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        start_iter = sorted(iter_list)[-1]
    else:
        start_iter = args.load_iter
    name = args.corpus + '-selector-{}.pth'.format(start_iter)
    model_file_path = os.path.join(args.save_dir, name)
    print("loading model", model_file_path)
    if opt.device == torch.device('cuda'):
        state = torch.load(model_file_path)
    else:
        state = torch.load(model_file_path, map_location=opt.device)
    self._epoch = state['epoch']
    self._iter = state['iter']
    self.running_avg_loss = state['current_loss']
    self.min_loss = state['min_loss']
    self.model.sentence_selector.load_state_dict(state['selector_state_dict'])
    if not args.is_coverage:
        self.optimizer.load_state_dict(state['optimizer'])
        if opt.device == torch.device('cuda'):
            for state in list(self.optimizer.state.values()):
                for k, v in list(state.items()):
                    if torch.is_tensor(v):
                        state[k] = v.cuda()
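The checkpoint layout read above implies a matching save routine. The following is an assumption reconstructed from the keys this loader consumes, not code from the original project:

def save_model(self, model_file_path):
    # keys mirror what load_model reads back
    state = {
        'epoch': self._epoch,
        'iter': self._iter,
        'current_loss': self.running_avg_loss,
        'min_loss': self.min_loss,
        'selector_state_dict': self.model.sentence_selector.state_dict(),
        'optimizer': self.optimizer.state_dict(),
    }
    torch.save(state, model_file_path)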
Example 5: load_checkpoint
def load_checkpoint(checkpoint):
    if torch.cuda.is_available():
        checkpoint = torch.load(checkpoint)
    else:
        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    return checkpoint
Example 6: run
def run(args, run_args, rank=0, world_size=1):
    set_seed(args, rank=rank)
    logger = initialize_logger(args, rank)
    field, train_sets, val_sets, save_dict = run_args
    logger.start = time.time()
    logger.info(f'Preparing iterators')
    train_iters = [(name, to_iter(args, world_size, tok, x, token_testing=args.token_testing))
                   for name, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
    val_iters = [(name, to_iter(args, world_size, tok, x, train=False, token_testing=args.token_testing,
                                sort=False if 'sql' in name else None))
                 for name, x, tok in zip(args.val_tasks, val_sets, args.val_batch_size)]
    logger.info(f'Initializing Writer')
    writer = SummaryWriter(log_dir=args.log_dir)
    model = init_model(args, field, logger, world_size)
    opt = init_opt(args, model)
    start_iteration = 1
    if save_dict is not None:
        logger.info(f'Loading model from {os.path.join(args.save, args.load)}')
        save_dict = torch.load(os.path.join(args.save, args.load))
        model.load_state_dict(save_dict['model_state_dict'])
        if args.resume:
            logger.info(f'Resuming Training from {os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')
            opt.load_state_dict(torch.load(
                os.path.join(args.save, f'{os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')))
            start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1])
    logger.info(f'Begin Training')
    train(args, model, opt, train_iters, args.train_iterations, field, val_iters=val_iters,
          rank=rank, world_size=world_size,
          log_every=args.log_every, val_every=args.val_every, rounds=len(train_iters) > 1,
          writer=writer if rank == 0 else None, save_every=args.save_every, start_iteration=start_iteration)
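The resume path assumes the model and per-rank optimizer states were saved to separate files whose names encode the iteration. A hedged sketch of the corresponding save step (the iteration_{...} filename pattern is inferred from the split('_')[1] parsing above and is an assumption):

iteration = 12000  # example value
torch.save({'model_state_dict': model.state_dict()},
           os.path.join(args.save, f'iteration_{iteration}.pth'))
torch.save(opt.state_dict(),
           os.path.join(args.save, f'iteration_{iteration}_rank_{rank}_optim.pth'))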
Example 7: restore_model
def restore_model(self, resume_iters):
    """Restore the trained generator and discriminator."""
    print('Loading the trained models from step {}...'.format(resume_iters))
    G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
    D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
    self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
    self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
Example 8: get_pretrained_net
def get_pretrained_net(name):
    """Loads a pretrained network."""
    if name == 'alexnet_caffe':
        if not os.path.exists('alexnet-torch_py3.pth'):
            print('Downloading AlexNet')
            os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download')
        return torch.load('alexnet-torch_py3.pth')
    elif name == 'vgg19_caffe':
        if not os.path.exists('vgg19-caffe-py3.pth'):
            print('Downloading VGG-19')
            os.system('wget -O vgg19-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download')
        vgg = get_vgg19_caffe()
        return vgg
    elif name == 'vgg16_caffe':
        if not os.path.exists('vgg16-caffe-py3.pth'):
            print('Downloading VGG-16')
            os.system('wget -O vgg16-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download')
        vgg = get_vgg16_caffe()
        return vgg
    elif name == 'vgg19_pytorch_modified':
        # os.system('wget -O data/feature_inversion/vgg19-caffe.pth --no-check-certificate -nc https://www.dropbox.com/s/xlbdo688dy4keyk/vgg19-caffe.pth?dl=1')
        model = VGGModified(vgg19(pretrained=False), 0.2)
        model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict'])
        return model
    else:
        assert False
Example 9: get_vanilla_vgg_features
def get_vanilla_vgg_features(cut_idx=-1):
    if not os.path.exists('vgg_features.pth'):
        os.system(
            'wget --no-check-certificate -N https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth')
        vgg_weights = torch.load('vgg19-d01eb7cb.pth')
        # fix compatibility issues: inserting View() below shifts the
        # classifier indices by one, so remap the affected keys
        key_map = {'classifier.6.weight': 'classifier.7.weight',
                   'classifier.6.bias': 'classifier.7.bias'}
        vgg_weights = OrderedDict([(key_map.get(k, k), v) for k, v in vgg_weights.items()])
        model = models.vgg19()
        model.classifier = nn.Sequential(View(), *model.classifier._modules.values())
        model.load_state_dict(vgg_weights)
        torch.save(model.features, 'vgg_features.pth')
        torch.save(model.classifier, 'vgg_classifier.pth')
    vgg = torch.load('vgg_features.pth')
    if cut_idx > 36:
        vgg_classifier = torch.load('vgg_classifier.pth')
        vgg = nn.Sequential(*(list(vgg._modules.values()) + list(vgg_classifier._modules.values())))
    vgg.eval()
    return vgg
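Instead of shelling out to wget, newer PyTorch releases provide torch.hub.load_state_dict_from_url, which downloads and caches a checkpoint in one call; a sketch using the same URL as above:

from torch.hub import load_state_dict_from_url

vgg_weights = load_state_dict_from_url(
    'https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth',
    map_location='cpu', progress=True)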
Example 10: load
def load(self, filename, legacy=False, ignore_d=False):
    """
    ignore_d: if `True`, then don't load in the discriminator.
    """
    if not self.use_cuda:
        map_location = lambda storage, loc: storage
    else:
        map_location = None
    if legacy:
        g, d = torch.load(filename, map_location=map_location)
        self.g.load_state_dict(g)
        if not ignore_d:
            self.d.load_state_dict(d)
    else:
        dd = torch.load(filename, map_location=map_location)
        self.g.load_state_dict(dd['g'])
        if not ignore_d:
            self.d.load_state_dict(dd['d'])
        for key in self.optim:
            if ignore_d and key == 'd':
                continue
            self.optim[key].load_state_dict(dd['optim_' + key])
        self.last_epoch = dd['epoch']
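The non-legacy branch expects a dict with 'g', 'd', 'optim_<key>', and 'epoch' entries; the matching save method would look roughly like the following (an assumption based on the keys read above, not code from the original project):

def save(self, filename, epoch):
    dd = {'g': self.g.state_dict(),
          'd': self.d.state_dict(),
          'epoch': epoch}
    for key in self.optim:
        dd['optim_' + key] = self.optim[key].state_dict()
    torch.save(dd, filename)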
Example 11: load_network_stageI
def load_network_stageI(self):
    from model import STAGE1_G, STAGE1_D
    netG = STAGE1_G()
    netG.apply(weights_init)
    print(netG)
    netD = STAGE1_D()
    netD.apply(weights_init)
    print(netD)
    if cfg.NET_G != '':
        state_dict = \
            torch.load(cfg.NET_G,
                       map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load from: ', cfg.NET_G)
    if cfg.NET_D != '':
        state_dict = \
            torch.load(cfg.NET_D,
                       map_location=lambda storage, loc: storage)
        netD.load_state_dict(state_dict)
        print('Load from: ', cfg.NET_D)
    if cfg.CUDA:
        netG.cuda()
        netD.cuda()
    return netG, netD
Example 12: __init__
def __init__(self,
             root, mnist_root="data",
             train=True,
             transform=None, target_transform=None,
             download=False):
    """Init MNIST-M dataset."""
    super(MNISTM, self).__init__()
    self.root = os.path.expanduser(root)
    self.mnist_root = os.path.expanduser(mnist_root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set
    if download:
        self.download()
    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')
    if self.train:
        self.train_data, self.train_labels = \
            torch.load(os.path.join(self.root,
                                    self.processed_folder,
                                    self.training_file))
    else:
        self.test_data, self.test_labels = \
            torch.load(os.path.join(self.root,
                                    self.processed_folder,
                                    self.test_file))
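The loaded tensors would normally be exposed through __getitem__ and __len__ in the MNIST style; a minimal sketch consistent with the attributes set above (hypothetical, not part of the original excerpt):

def __getitem__(self, index):
    if self.train:
        img, target = self.train_data[index], self.train_labels[index]
    else:
        img, target = self.test_data[index], self.test_labels[index]
    if self.transform is not None:
        img = self.transform(img)
    if self.target_transform is not None:
        target = self.target_transform(target)
    return img, target

def __len__(self):
    return len(self.train_data) if self.train else len(self.test_data)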
Example 13: load_models
def load_models(load_path):
    model_args = json.load(open("{}/args.json".format(load_path), "r"))
    word2idx = json.load(open("{}/vocab.json".format(load_path), "r"))
    idx2word = {v: k for k, v in word2idx.items()}
    autoencoder = Seq2Seq(emsize=model_args['emsize'],
                          nhidden=model_args['nhidden'],
                          ntokens=model_args['ntokens'],
                          nlayers=model_args['nlayers'],
                          hidden_init=model_args['hidden_init'])
    gan_gen = MLP_G(ninput=model_args['z_size'],
                    noutput=model_args['nhidden'],
                    layers=model_args['arch_g'])
    gan_disc = MLP_D(ninput=model_args['nhidden'],
                     noutput=1,
                     layers=model_args['arch_d'])
    print('Loading models from ' + load_path)
    ae_path = os.path.join(load_path, "autoencoder_model.pt")
    gen_path = os.path.join(load_path, "gan_gen_model.pt")
    disc_path = os.path.join(load_path, "gan_disc_model.pt")
    autoencoder.load_state_dict(torch.load(ae_path))
    gan_gen.load_state_dict(torch.load(gen_path))
    gan_disc.load_state_dict(torch.load(disc_path))
    return model_args, idx2word, autoencoder, gan_gen, gan_disc
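After load_state_dict the modules are still in training mode; if the loaded models are used for generation, it is common to switch them to eval() so dropout and batch norm behave deterministically:

autoencoder.eval()
gan_gen.eval()
gan_disc.eval()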
Example 14: demo
def demo(data, save, depth=40, growth_rate=12, batch_size=256):
    """
    Applies temperature scaling to a trained model.

    Takes a pretrained DenseNet-CIFAR100 model, and a validation set
    (parameterized by indices on train set).
    Applies temperature scaling, and saves a temperature scaled version.

    NB: the "save" parameter references a DIRECTORY, not a file.
    In that directory, there should be two files:
    - model.pth (model state dict)
    - valid_indices.pth (a list of indices corresponding to the validation set).

    data (str) - path to directory where data should be loaded from/downloaded
    save (str) - directory with necessary files (see above)
    """
    # Load model state dict
    model_filename = os.path.join(save, 'model.pth')
    if not os.path.exists(model_filename):
        raise RuntimeError('Cannot find file %s to load' % model_filename)
    state_dict = torch.load(model_filename)

    # Load validation indices
    valid_indices_filename = os.path.join(save, 'valid_indices.pth')
    if not os.path.exists(valid_indices_filename):
        raise RuntimeError('Cannot find file %s to load' % valid_indices_filename)
    valid_indices = torch.load(valid_indices_filename)

    # Regenerate validation set loader
    mean = [0.5071, 0.4867, 0.4408]
    stdv = [0.2675, 0.2565, 0.2761]
    test_transforms = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=mean, std=stdv),
    ])
    valid_set = tv.datasets.CIFAR100(data, train=True, transform=test_transforms, download=True)
    valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(valid_indices))

    # Load original model
    if (depth - 4) % 3:
        raise Exception('Invalid depth')
    block_config = [(depth - 4) // 6 for _ in range(3)]
    orig_model = DenseNetEfficientMulti(
        growth_rate=growth_rate,
        block_config=block_config,
        num_classes=100
    ).cuda()
    orig_model.load_state_dict(state_dict)

    # Now we're going to wrap the model with a decorator that adds temperature scaling
    model = ModelWithTemperature(orig_model)

    # Tune the model temperature, and save the results
    model.set_temperature(valid_loader)
    model_filename = os.path.join(save, 'model_with_temperature.pth')
    torch.save(model.state_dict(), model_filename)
    print('Temperature scaled model saved to %s' % model_filename)
    print('Done!')
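Because what gets saved here is the wrapper's state dict, reloading later must mirror the wrapping; a hedged sketch reusing the names from the example:

orig_model = DenseNetEfficientMulti(growth_rate=growth_rate,
                                    block_config=block_config,
                                    num_classes=100).cuda()
model = ModelWithTemperature(orig_model)
model.load_state_dict(torch.load(os.path.join(save, 'model_with_temperature.pth')))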
Example 15: __init__
def __init__(self, file, labelFile):
    self.train = torch.load(file)
    self.label = torch.load(labelFile)
    self.len = len(self.train)  # get how many data points.
    for i in range(0, self.len):  # transform the imgs.
        self.train[i] = transforms.Normalize((0.1307,), (0.3081,))(
            self.train[i].view(1, -1))  # do a small transformation
    self.train = self.train.view(-1, 1, 28, 28)
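A dataset like this also needs __getitem__ and __len__ to work with a DataLoader; a minimal sketch matching the attributes above (hypothetical, not in the excerpt):

def __getitem__(self, index):
    return self.train[index], self.label[index]

def __len__(self):
    return self.len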