本文整理匯總了 Python 中 torch.load 方法的典型用法代碼示例。如果您正苦於以下問題:Python torch.load 方法的具體用法?Python torch.load 怎麼用?Python torch.load 使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類 torch 的用法示例。
在下文中一共展示了 torch.load 方法的 15 個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的 Python 代碼示例。
示例1: convert
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Args:
        src: Path of the checkpoint to read; its ``'model_state'`` entry
            holds the weights to convert.
        dst: Path the converted checkpoint is written to, as a dict with a
            single ``'state_dict'`` entry.
    """
    # Read the source checkpoint and pull out its weight mapping.
    source_ckpt = torch.load(src)
    blobs = source_ckpt['model_state']

    # Route every entry to the converter that matches its key.
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            handler = convert_stem
        elif 'head' in key:
            handler = convert_head
        elif key.startswith('s'):
            handler = convert_reslayer
        else:
            # No converter applies; the key is reported below.
            continue
        handler(key, weight, state_dict, converted_names)

    # Report every source key that no converter marked as handled.
    leftovers = [key for key in blobs if key not in converted_names]
    for key in leftovers:
        print(f'not converted: {key}')

    # Write the converted weights out.
    torch.save({'state_dict': state_dict}, dst)
示例2: get_model
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def get_model(load_weights = True):
    """Assemble the ``deepsea_cpu`` network and wrap it behind ``ReCodeAlphabet``.

    Args:
        load_weights: When true, fill the network from
            ``model_files/deepsea_cpu.pth`` (resolved against the current
            working directory).

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), deepsea_cpu)``.
    """
    def flatten(x):
        # Keep the first dimension, collapse everything after it.
        return x.view(x.size(0), -1)

    def as_row(x):
        # A 1-D tensor becomes a single row; anything else passes through.
        return x.view(1, -1) if 1 == len(x.size()) else x

    # Layers are created in the same order as they appear in the network so
    # that the positional names inside the Sequential stay the same.
    layers = [
        nn.Conv2d(4, 320, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320, 480, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480, 960, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
        Lambda(flatten),  # Reshape
        nn.Sequential(Lambda(as_row), nn.Linear(50880, 925)),  # Linear
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(as_row), nn.Linear(925, 919)),  # Linear
        nn.Sigmoid(),
    ]
    deepsea_cpu = nn.Sequential(*layers)

    if load_weights:
        weights = torch.load('model_files/deepsea_cpu.pth')
        deepsea_cpu.load_state_dict(weights)
    return nn.Sequential(ReCodeAlphabet(), deepsea_cpu)
示例3: from_snapshot
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def from_snapshot(self, sfile, nfile):
    """Restore network weights and data-layer bookkeeping from a snapshot pair.

    Args:
        sfile: Path of the weight file, loaded into ``self.net``.
        nfile: Path of a pickle file holding six consecutive records: the
            NumPy random state, the cursor and permutation of the training
            data layer, the same pair for the validation data layer, and the
            iteration of the snapshot.

    Returns:
        The iteration number stored in ``nfile``.
    """
    print('Restoring model snapshots from {:s}'.format(sfile))
    weights = torch.load(str(sfile))
    self.net.load_state_dict(weights)
    print('Restored.')

    # The records were written one after another, so they are read back in
    # the same order.
    # NOTE: pickle can execute arbitrary code while loading; only feed this
    # files produced by a trusted training run.
    with open(nfile, 'rb') as fid:
        records = [pickle.load(fid) for _ in range(6)]
    rng_state, cur, perm, cur_val, perm_val, last_snapshot_iter = records

    # Put the random generator and both data layers back where they were.
    np.random.set_state(rng_state)
    train_layer = self.data_layer
    train_layer._cur = cur
    train_layer._perm = perm
    val_layer = self.data_layer_val
    val_layer._cur = cur_val
    val_layer._perm = perm_val

    return last_snapshot_iter
開發者ID:Sunarker,項目名稱:Collaborative-Learning-for-Weakly-Supervised-Object-Detection,代碼行數:24,代碼來源:train_val.py
示例4: get_seqpred_model
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def get_seqpred_model(load_weights = True):
    """Assemble the ``deepsea_cpu`` network inside the ConcatenateRC/AverageRC pipeline.

    Args:
        load_weights: When true, fill the network from
            ``model_files/deepsea_cpu.pth`` (resolved against the current
            working directory).

    Returns:
        ``nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())``.
    """
    def flatten(x):
        # Keep the first dimension, collapse everything after it.
        return x.view(x.size(0), -1)

    def as_row(x):
        # A 1-D tensor becomes a single row; anything else passes through.
        return x.view(1, -1) if 1 == len(x.size()) else x

    # Three convolution stages, in network order.
    conv_stages = [
        nn.Conv2d(4, 320, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(320, 480, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.MaxPool2d((1, 4), (1, 4)),
        nn.Dropout(0.2),
        nn.Conv2d(480, 960, (1, 8), (1, 1)),
        nn.Threshold(0, 1e-06),
        nn.Dropout(0.5),
    ]
    # Flatten, then two linear layers ending in a sigmoid.
    dense_stages = [
        Lambda(flatten),  # Reshape
        nn.Sequential(Lambda(as_row), nn.Linear(50880, 925)),  # Linear
        nn.Threshold(0, 1e-06),
        nn.Sequential(Lambda(as_row), nn.Linear(925, 919)),  # Linear
        nn.Sigmoid(),
    ]
    deepsea_cpu = nn.Sequential(*conv_stages, *dense_stages)

    if load_weights:
        weights = torch.load('model_files/deepsea_cpu.pth')
        deepsea_cpu.load_state_dict(weights)
    return nn.Sequential(ReCodeAlphabet(), ConcatenateRC(), deepsea_cpu, AverageRC())
示例5: load_test_model
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def load_test_model(model, config):
    """Load evaluation weights into ``model`` and hand it back.

    :param model: initial model; its ``load_state_dict`` receives the weights
    :param config: config; ``config.t_model`` is used as the weight file when
        it is not ``None``, otherwise the file is
        ``<config.save_best_model_dir>/<config.model_name>.pt``
    :return: the same ``model`` object, after loading
    """
    user_path = config.t_model
    if user_path is not None:
        # An explicitly supplied file wins over the default location.
        test_model_path = user_path
        print("load user model from {}".format(test_model_path))
    else:
        default_name = "{}.pt".format(config.model_name)
        test_model_path = os.path.join(config.save_best_model_dir, default_name)
        print("load default model from {}".format(test_model_path))

    weights = torch.load(test_model_path)
    model.load_state_dict(weights)
    return model
示例6: load_checkpoint
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def load_checkpoint(self, checkpoint):
    """Rebuild the VSE model from a checkpoint file and freeze its encoders.

    Args:
        checkpoint: Path of a file whose contents provide the saved options
            under ``'opt'`` and the model weights under ``'model'``.

    Side effects:
        Sets ``self.model`` to the restored model and ``self.projector`` to
        the vocabulary loaded from ``<opt.vocab_path>/<opt.data_name>_vocab.pkl``.
    """
    saved = torch.load(checkpoint)

    # Adjust the stored options before the model is constructed from them.
    opt = saved['opt']
    opt.use_external_captions = False
    vocab_file = pjoin(opt.vocab_path, '%s_vocab.pkl' % opt.data_name)
    vocab = Vocab.from_pickle(vocab_file)
    opt.vocab_size = len(vocab)

    # Imported here, as in the original layout, rather than at module level.
    from model import VSE
    self.model = VSE(opt)
    self.model.load_state_dict(saved['model'])
    self.projector = vocab

    # Both encoders go to eval mode and stop tracking gradients.
    for encoder in (self.model.img_enc, self.model.txt_enc):
        encoder.eval()
        for p in encoder.parameters():
            p.requires_grad = False
示例7: load_model
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def load_model(model, optimizer, scheduler, path, num_epochs, start_time=None):
    """Resume from the newest checkpoint numbered ``num_epochs`` or lower.

    Checkpoints are files named ``<path>_model_<epoch>.pth``, with matching
    ``<path>_optimizer_<epoch>.pth`` and ``<path>_scheduler_<epoch>.pth``
    files. The search starts at ``num_epochs`` and counts down to 1.

    Args:
        model: Object whose ``load_state_dict`` receives the model file.
        optimizer: Optional; when not ``None`` its ``load_state_dict``
            receives the optimizer file of the same epoch.
        scheduler: Optional; when not ``None`` its ``best``,
            ``cooldown_counter``, ``num_bad_epochs`` and ``last_epoch``
            attributes are overwritten from the scheduler file.
        path: Filename prefix shared by all checkpoint files.
        num_epochs: Highest epoch number to look for.
        start_time: Reference timestamp (``time.time()`` scale) for the
            elapsed time in the log line. Defaults to the moment of the call.

    Returns:
        The epoch that was loaded, or the non-positive value the search
        ended on when no model file exists (0 for a positive ``num_epochs``).
    """
    # A `time.time()` default would be evaluated once, when the function is
    # defined, so the reported time would count from import. Resolve it here.
    if start_time is None:
        start_time = time.time()

    def checkpoint_file(kind, epoch):
        # One place for the naming scheme shared by all three files.
        return '{path}_{kind}_{epoch:d}.pth'.format(path=path, kind=kind, epoch=epoch)

    # Walk down from the requested epoch to the newest saved model file.
    epoch = num_epochs
    while epoch > 0 and not os.path.isfile(checkpoint_file('model', epoch)):
        epoch -= 1
    if epoch <= 0:
        return epoch

    model_path = checkpoint_file('model', epoch)
    model.load_state_dict(torch.load(model_path))

    if optimizer is not None:
        optimizer.load_state_dict(torch.load(checkpoint_file('optimizer', epoch)))

    if scheduler is not None:
        scheduler_state_dict = torch.load(checkpoint_file('scheduler', epoch))
        # Only these four fields are restored; the rest of the scheduler is
        # left as constructed by the caller.
        for field in ('best', 'cooldown_counter', 'num_bad_epochs', 'last_epoch'):
            setattr(scheduler, field, scheduler_state_dict[field])

    print('{epoch:4d}/{num_epochs:4d} e; '.format(epoch=epoch, num_epochs=num_epochs), end='')
    print('load {path}; '.format(path=model_path), end='')
    print('{time:8.3f} s'.format(time=time.time()-start_time))
    return epoch
示例8: init_truncated_normal
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def init_truncated_normal(model, aux_str=''):
    """Give ``model`` its initial weights, reusing a saved copy when one exists.

    The weight file is ``<path>/<in_features>_<out_features><aux_str>.pth``,
    where ``path`` is a name taken from the surrounding scope. If the file
    exists it is loaded; otherwise ``truncated_normal`` is applied (to each
    member when ``model`` is an ``nn.ModuleList``) and the result is saved
    to that file.

    Args:
        model: Module to initialise, or ``None``.
        aux_str: Extra suffix for the weight file name.

    Returns:
        ``model`` itself, or ``None`` when ``model`` is ``None``.
    """
    if model is None:
        return None

    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth'.format(
        path=path,
        in_dim=model.in_features,
        out_dim=model.out_features,
        aux_str=aux_str,
    )

    if not os.path.isfile(init_path):
        # Nothing saved yet: initialise now and store the result for reuse.
        targets = model if isinstance(model, nn.ModuleList) else [model]
        for sub in targets:
            truncated_normal(sub)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
        return model

    saved = torch.load(init_path)
    model.load_state_dict(saved)
    print('load init weight: {init_path}'.format(init_path=init_path))
    return model
示例9: print_mutation
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def print_mutation(hyp, results, bucket=''):
    """Print one mutation result and record it in ``evolve.txt``.

    Args:
        hyp: Mapping of hyperparameter name to numeric value.
        results: Tuple of numeric results for this mutation.
        bucket: Optional storage bucket name. When non-empty, ``evolve.txt``
            is copied from the bucket with ``gsutil`` before the update and
            copied back afterwards.

    The row appended to ``evolve.txt`` is the results followed by the
    hyperparameter values. The file is then rewritten with duplicate rows
    removed and rows ordered by descending ``fitness``.
    """
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    keys_row = ''.join('%10s' % (name,) for name in hyp.keys())  # hyperparam keys
    values_row = ''.join('%10.3g' % (value,) for value in hyp.values())  # hyperparam values
    results_row = '%10.3g' * len(results) % results  # results (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (keys_row, values_row, results_row))

    if bucket:
        # download evolve.txt
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)

    # append result
    with open('evolve.txt', 'a') as f:
        f.write(results_row + values_row + '\n')

    # load unique rows, then save them sorted by fitness
    all_rows = np.loadtxt('evolve.txt', ndmin=2)
    unique_rows = np.unique(all_rows, axis=0)
    order = np.argsort(-fitness(unique_rows))
    np.savetxt('evolve.txt', unique_rows[order], '%10.3g')

    if bucket:
        # upload evolve.txt
        os.system('gsutil cp evolve.txt gs://%s' % bucket)
示例10: load_checkpoint
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def load_checkpoint(self, file_name):
    """Resume episode, iteration, policy and optimizer state from a checkpoint.

    Args:
        file_name: Name of the checkpoint file. It is appended directly to
            ``self.config.checkpoint_dir``; no path separator is inserted.

    An ``OSError`` raised while loading (for example a missing file) is
    logged as "no checkpoint" and otherwise ignored, so a first run simply
    starts from scratch.
    """
    ckpt_dir = self.config.checkpoint_dir
    ckpt_path = ckpt_dir + file_name
    try:
        self.logger.info("Loading checkpoint '{}'".format(ckpt_path))
        state = torch.load(ckpt_path)

        # Progress counters first, then the two state dicts.
        self.current_episode = state['episode']
        self.current_iteration = state['iteration']
        self.policy_model.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])

        loaded_msg = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
        self.logger.info(loaded_msg.format(ckpt_dir, state['episode'], state['iteration']))
    except OSError:
        missing_msg = "No checkpoint exists from '{}'. Skipping..."
        self.logger.info(missing_msg.format(ckpt_dir))
        self.logger.info("**First time to train**")
示例11: load_checkpoint
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def load_checkpoint(self, filename):
    """Resume epoch, iteration, model and optimizer state from a checkpoint.

    Args:
        filename: Name of the checkpoint file. It is appended directly to
            ``self.config.checkpoint_dir``; no path separator is inserted.

    An ``OSError`` raised while loading (for example a missing file) is
    logged as "no checkpoint" and otherwise ignored, so a first run simply
    starts from scratch.
    """
    ckpt_dir = self.config.checkpoint_dir
    ckpt_path = ckpt_dir + filename
    try:
        self.logger.info("Loading checkpoint '{}'".format(ckpt_path))
        state = torch.load(ckpt_path)

        # Progress counters first, then the two state dicts.
        self.current_epoch = state['epoch']
        self.current_iteration = state['iteration']
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])

        loaded_msg = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
        self.logger.info(loaded_msg.format(ckpt_dir, state['epoch'], state['iteration']))
    except OSError:
        missing_msg = "No checkpoint exists from '{}'. Skipping..."
        self.logger.info(missing_msg.format(ckpt_dir))
        self.logger.info("**First time to train**")
示例12: load_checkpoint
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def load_checkpoint(self, filename):
    """Restore training progress, model weights and optimizer state.

    Args:
        filename: Checkpoint file name, concatenated as-is onto
            ``self.config.checkpoint_dir`` (no separator is added).

    If loading raises ``OSError`` (a missing file, for instance), the
    failure is only logged and the trainer state is left untouched.
    """
    directory = self.config.checkpoint_dir
    full_path = directory + filename
    log = self.logger.info
    try:
        log("Loading checkpoint '{}'".format(full_path))
        restored = torch.load(full_path)

        self.current_epoch = restored['epoch']
        self.current_iteration = restored['iteration']
        self.model.load_state_dict(restored['state_dict'])
        self.optimizer.load_state_dict(restored['optimizer'])

        log("Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n".format(
            directory, restored['epoch'], restored['iteration']))
    except OSError:
        # Treated as "nothing to resume from".
        log("No checkpoint exists from '{}'. Skipping...".format(directory))
        log("**First time to train**")
示例13: load_checkpoint
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def load_checkpoint(self, file_name):
    """Resume generator/discriminator training from a checkpoint.

    Args:
        file_name: Name of the checkpoint file. It is appended directly to
            ``self.config.checkpoint_dir``; no path separator is inserted.

    Restores the epoch and iteration counters, the state of ``netG``,
    ``optimG``, ``netD`` and ``optimD``, plus ``fixed_noise`` and
    ``manual_seed``. An ``OSError`` raised while loading (for example a
    missing file) is logged and otherwise ignored.
    """
    ckpt_dir = self.config.checkpoint_dir
    ckpt_path = ckpt_dir + file_name
    try:
        self.logger.info("Loading checkpoint '{}'".format(ckpt_path))
        state = torch.load(ckpt_path)

        self.current_epoch = state['epoch']
        self.current_iteration = state['iteration']

        # Each component is restored from its own checkpoint entry.
        restore_plan = (
            (self.netG, 'G_state_dict'),
            (self.optimG, 'G_optimizer'),
            (self.netD, 'D_state_dict'),
            (self.optimD, 'D_optimizer'),
        )
        for component, entry in restore_plan:
            component.load_state_dict(state[entry])

        self.fixed_noise = state['fixed_noise']
        self.manual_seed = state['manual_seed']

        loaded_msg = "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
        self.logger.info(loaded_msg.format(ckpt_dir, state['epoch'], state['iteration']))
    except OSError:
        missing_msg = "No checkpoint exists from '{}'. Skipping..."
        self.logger.info(missing_msg.format(ckpt_dir))
        self.logger.info("**First time to train**")
示例14: __getitem__
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def __getitem__(self, idx):
    """Return the sample stored at position ``idx`` of ``self.data_set_list``.

    When ``self.read_features`` is set, ``self.sequence_length`` saved
    feature files are loaded and stacked. Otherwise an image and its
    segmentation are loaded from the ``images`` and ``walkable`` folders
    under ``self.root_dir``.

    NOTE(review): the indentation of this method was lost in the listing it
    was recovered from. Placing both the image and the segmentation load in
    the ``else`` branch is an assumption - confirm against the upstream file.
    """
    fid = self.data_set_list[idx]
    if self.read_features:
        # Load one saved feature file per step of the sequence.
        features = []
        for i in range(self.sequence_length):
            # NOTE(review): `fid + i` is used as a key here while `fid` is
            # joined into paths and concatenated with a string below; confirm
            # what type `fid` has in each mode.
            feature_path = os.path.join(
                self.features_dir,
                self.frames_metadata[fid + i]['cur_frame'] + '.pytar')
            features.append(torch.load(feature_path))
        # NOTE(review): `input` shadows the builtin and is not used by the
        # return statement below; this branch does not assign `image`, which
        # the return statement reads - verify against upstream.
        input = torch.stack(features)
    else:
        image = self.load_and_resize(
            os.path.join(self.root_dir, 'images', fid))
        segment = self.load_and_resize_segmentation(
            os.path.join(self.root_dir, 'walkable', fid))
    # The two 0s are just place holders. They can be replaced by any values
    return (image, segment, 0, 0, ['images' + fid])
示例15: set_model
# 需要導入模塊: import torch [as 別名]
# 或者: from torch import load [as 別名]
def set_model(self):
    """Set up the RNNLM model, its loss and optimizer, optionally resuming.

    Builds the model from ``self.config['model']`` on ``self.device``,
    creates a cross-entropy loss that ignores index 0, builds the optimizer
    from ``self.config['hparas']`` and calls ``self.enable_apex()``. When
    ``self.paras.load`` is set, model weights, optimizer state and the
    global step are restored from that checkpoint.
    """
    # Model
    model_cfg = self.config['model']
    self.model = RNNLM(self.vocab_size, **model_cfg).to(self.device)
    self.verbose(self.model.create_msg())

    # Losses
    self.seq_loss = torch.nn.CrossEntropyLoss(ignore_index=0)

    # Optimizer
    optim_cfg = self.config['hparas']
    self.optimizer = Optimizer(self.model.parameters(), **optim_cfg)

    # Enable AMP if needed
    self.enable_apex()

    # Nothing more to do unless a pre-trained checkpoint was requested.
    ckpt_path = self.paras.load
    if not ckpt_path:
        return

    # NOTE(review): load_ckpt() is called and the checkpoint is then loaded
    # by hand as well; presumably one of the two is redundant - confirm.
    self.load_ckpt()
    ckpt = torch.load(self.paras.load, map_location=self.device)
    self.model.load_state_dict(ckpt['model'])
    self.optimizer.load_opt_state_dict(ckpt['optimizer'])
    self.step = ckpt['global_step']
    resume_msg = 'Load ckpt from {}, restarting at step {}'
    self.verbose(resume_msg.format(self.paras.load, self.step))