本文整理汇总了Python中torch.save方法的典型用法代码示例。如果您正苦于以下问题：Python torch.save方法的具体用法？Python torch.save怎么用？Python torch.save使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch的用法示例。
在下文中一共展示了torch.save方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。注意：示例7、示例14和示例15的片段本身并未直接调用torch.save（示例14调用的是self._save）。
示例1: save
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save(self, ckpt_path, epoch):
    '''
    Save checkpoint.

    Writes ``net_<name>_<epoch>.pth`` for every entry of ``self.nets`` and
    ``optimizer_<name>_<epoch>.pth`` for every entry of ``self.optimizers``
    into the directory ``ckpt_path`` (which must already exist).

    Networks wrapped in ``DataParallel`` or ``DistributedDataParallel`` are
    unwrapped first, so the saved state dict has the same keys as the bare
    network and can be loaded into an unwrapped copy of it.
    '''
    # Both wrappers keep the real network in ``.module``.
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    for name, net in self.nets.items():
        module = net.module if isinstance(net, wrappers) else net
        path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
        torch.save(module.state_dict(), path)
    for name, optimizer in self.optimizers.items():
        path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
        torch.save(optimizer.state_dict(), path)
示例2: convert
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Args:
        src: path of the pycls checkpoint; the weights are read from its
            ``'model_state'`` entry.
        dst: path the converted checkpoint (a dict with a single
            ``'state_dict'`` entry) is written to.
    """
    # Load onto the CPU so the conversion also works on machines that lack
    # the GPU the checkpoint was originally saved from.
    regnet_model = torch.load(src, map_location='cpu')
    blobs = regnet_model['model_state']
    # Dispatch every weight to the converter for its part of the network.
    # The converters presumably record each handled key in converted_names
    # (the check below relies on that) -- confirm in convert_stem & co.
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)
    # Report every source key that no converter claimed.
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # Save the converted weights under 'state_dict'.
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)
示例3: save_model_all
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_model_all(model, save_dir, model_name, epoch):
    """
    Save ``model.state_dict()`` to ``<save_dir>/<model_name>_epoch_<epoch>.pt``.

    :param model: nn model whose ``state_dict()`` is saved
    :param save_dir: directory to save into; created if it does not exist
    :param model_name: file-name prefix
    :param epoch: epoch number embedded in the file name
    :return: None
    """
    os.makedirs(save_dir, exist_ok=True)
    save_prefix = os.path.join(save_dir, model_name)
    save_path = '{}_epoch_{}.pt'.format(save_prefix, epoch)
    print("save all model to {}".format(save_path))
    # The context manager closes the file even if torch.save raises.
    with open(save_path, mode="wb") as output:
        torch.save(model.state_dict(), output)
示例4: save_best_model
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_best_model(model, save_dir, model_name, best_eval):
    """
    Save ``model.state_dict()`` to ``<save_dir>/<model_name>.pt`` when the
    current dev score is at least the best dev score so far.

    After saving, ``best_eval.early_current_patience`` is reset to 0.
    Otherwise nothing is written and ``best_eval`` is left untouched.

    :param model: nn model whose ``state_dict()`` is saved
    :param save_dir: directory to save into; created if it does not exist
    :param model_name: file name without the ``.pt`` extension
    :param best_eval: object exposing ``current_dev_score``,
        ``best_dev_score`` and ``early_current_patience``
    :return: None
    """
    # Written as "not >=" (not "<") so incomparable scores such as NaN still
    # skip saving, exactly as before.
    if not (best_eval.current_dev_score >= best_eval.best_dev_score):
        return
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "{}.pt".format(model_name))
    print("save best model to {}".format(save_path))
    # The context manager closes the file even if torch.save raises.
    with open(save_path, mode="wb") as output:
        torch.save(model.state_dict(), output)
    best_eval.early_current_patience = 0
# adjust lr
示例5: save_checkpoint
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(state, weights_dir=''):
    """Save a training-state dict as ``model-<epoch>.pth.tar``.

    Arguments:
        state {dict} -- checkpoint contents; must contain an ``'epoch'``
            entry convertible to ``int``, zero-padded to four digits in the
            file name.

    Keyword Arguments:
        weights_dir {str} -- directory to write into, created if missing
            (default: '' -- the current working directory)
    """
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when one was actually requested. exist_ok also tolerates the directory
    # appearing between a check and the call.
    if weights_dir:
        os.makedirs(weights_dir, exist_ok=True)
    epoch = int(state['epoch'])
    file_path = os.path.join(weights_dir, 'model-{:04d}.pth.tar'.format(epoch))
    torch.save(state, file_path)
#############################################
# loss function
#############################################
示例6: init_truncated_normal
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def init_truncated_normal(model, aux_str=''):
    """Initialise ``model`` with truncated-normal weights, caching them on disk.

    The cache file is ``<path>/<in_features>_<out_features><aux_str>.pth``,
    where ``path`` is a module-level name defined outside this function. If
    the file exists, its state dict is loaded into ``model``. Otherwise
    ``truncated_normal`` is applied and the resulting state dict is saved
    there, so later calls with the same shape reuse the same weights.

    Args:
        model: module exposing ``in_features``/``out_features``, or None.
        aux_str: suffix appended to the cache file name.

    Returns:
        ``model`` itself, or None when ``model`` is None.
    """
    if model is None:
        return None
    # NOTE(review): the file name reads model.in_features/out_features, so a
    # plain nn.ModuleList (which has no such attributes) raises AttributeError
    # here, before the ModuleList branch below is reached -- confirm intent.
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth'.format(
        path=path, in_dim=model.in_features, out_dim=model.out_features,
        aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        if isinstance(model, nn.ModuleList):
            # Plain loop: truncated_normal is called only for its side effect,
            # so there is no reason to build a throwaway list.
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))
    return model
示例7: print_mutation
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def print_mutation(hyp, results, bucket=''):
    """Record one hyperparameter-evolution result in evolve.txt.

    Prints the hyperparameter names, their values and the results. Then it
    appends a row (results followed by hyperparameter values) to
    ``evolve.txt``, removes duplicate rows, and rewrites the file sorted by
    descending ``fitness``. When ``bucket`` is non-empty, evolve.txt is first
    downloaded from and finally uploaded to ``gs://<bucket>`` via gsutil
    (for use with train.py --evolve).
    """
    header = '%10s' * len(hyp) % tuple(hyp)  # hyperparameter names
    hyp_row = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparameter values
    result_row = '%10.3g' * len(results) % results  # e.g. (P, R, mAP, F1, test_loss)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (header, hyp_row, result_row))

    # NOTE(review): bucket is interpolated into a shell command string; only
    # pass trusted bucket names.
    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)

    with open('evolve.txt', 'a') as log:
        log.write(result_row + hyp_row + '\n')

    rows = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)
    best_first = np.argsort(-fitness(rows))
    np.savetxt('evolve.txt', rows[best_first], '%10.3g')

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)
示例8: save_checkpoint
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
    """
    Save the latest checkpoint of the training.

    The checkpoint holds the current epoch and iteration counters plus the
    model and optimizer state dicts.

    :param filename: name of the file, inside ``self.config.checkpoint_dir``,
        that will contain the state
    :param is_best: truthy to also copy the file to ``model_best.pth.tar``
    :return: None
    """
    import os  # local import keeps the method independent of module imports

    state = {
        'epoch': self.current_epoch,
        'iteration': self.current_iteration,
        'state_dict': self.model.state_dict(),
        'optimizer': self.optimizer.state_dict(),
    }
    # os.path.join works whether or not checkpoint_dir ends with a separator;
    # plain concatenation would write to "<dir>checkpoint.pth.tar" otherwise.
    checkpoint_dir = self.config.checkpoint_dir
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save(state, checkpoint_path)
    # If it is the best, copy it to 'model_best.pth.tar' in the same directory.
    if is_best:
        shutil.copyfile(checkpoint_path,
                        os.path.join(checkpoint_dir, 'model_best.pth.tar'))
示例9: save_checkpoint
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
    """
    Save the latest checkpoint of the training.

    The stored ``'epoch'`` is ``self.current_epoch + 1`` (presumably the epoch
    to resume from -- confirm against the loader). The checkpoint also holds
    the iteration counter and the model and optimizer state dicts.

    :param filename: name of the file, inside ``self.config.checkpoint_dir``,
        that will contain the state
    :param is_best: truthy to also copy the file to ``model_best.pth.tar``
    :return: None
    """
    import os  # local import keeps the method independent of module imports

    state = {
        'epoch': self.current_epoch + 1,
        'iteration': self.current_iteration,
        'state_dict': self.model.state_dict(),
        'optimizer': self.optimizer.state_dict(),
    }
    # os.path.join works whether or not checkpoint_dir ends with a separator;
    # plain concatenation would write to "<dir>checkpoint.pth.tar" otherwise.
    checkpoint_dir = self.config.checkpoint_dir
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save(state, checkpoint_path)
    # If it is the best, copy it to 'model_best.pth.tar' in the same directory.
    if is_best:
        shutil.copyfile(checkpoint_path,
                        os.path.join(checkpoint_dir, 'model_best.pth.tar'))
示例10: save_checkpoint
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
    """
    Save the latest GAN training checkpoint.

    The checkpoint holds the epoch and iteration counters, the generator and
    discriminator state dicts with their optimizers' state dicts, and
    ``self.fixed_noise`` and ``self.manual_seed``.

    :param file_name: name of the file, inside ``self.config.checkpoint_dir``,
        that will contain the state
    :param is_best: truthy to also copy the file to ``model_best.pth.tar``
    :return: None
    """
    import os  # local import keeps the method independent of module imports

    state = {
        'epoch': self.current_epoch,
        'iteration': self.current_iteration,
        'G_state_dict': self.netG.state_dict(),
        'G_optimizer': self.optimG.state_dict(),
        'D_state_dict': self.netD.state_dict(),
        'D_optimizer': self.optimD.state_dict(),
        'fixed_noise': self.fixed_noise,
        'manual_seed': self.manual_seed
    }
    # os.path.join works whether or not checkpoint_dir ends with a separator;
    # plain concatenation would write to "<dir>checkpoint.pth.tar" otherwise.
    checkpoint_dir = self.config.checkpoint_dir
    checkpoint_path = os.path.join(checkpoint_dir, file_name)
    torch.save(state, checkpoint_path)
    # If it is the best, copy it to 'model_best.pth.tar' in the same directory.
    if is_best:
        shutil.copyfile(checkpoint_path,
                        os.path.join(checkpoint_dir, 'model_best.pth.tar'))
示例11: _save_checkpoint
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def _save_checkpoint(self, epoch, acc):
    """
    Save a checkpoint of the network and other variables.

    Only the best and the latest epochs are kept. The method writes
    ``<logdir>/<NetClass>_<epoch:03d>.pkl`` and then deletes the checkpoint
    of the previous evaluation (``epoch - self.eval_freq``), unless that
    epoch is ``self.best_epoch``.

    Returns:
        True once the checkpoint has been written.
    """
    net_type = type(self.net).__name__
    cur_save = os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, epoch))
    state = {
        'epoch': epoch,
        'acc': acc,
        'net_type': net_type,
        'net': self.net.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        #'scheduler': self.scheduler.state_dict(),
        'use_gpu': self.use_gpu,
        'save_time': datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    }
    # Write the new checkpoint before deleting the old one, so a failing save
    # (e.g. a full disk) never leaves the run without a recent checkpoint.
    torch.save(state, cur_save)
    prev_epoch = epoch - self.eval_freq
    if prev_epoch != self.best_epoch:
        pre_save = os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, prev_epoch))
        # pre_save equals cur_save only when eval_freq is 0; never delete the
        # file that was just written.
        if pre_save != cur_save and os.path.isfile(pre_save):
            os.remove(pre_save)
    return True
示例12: save_checkpoint
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, f_name, metric, score, show_msg=True):
    '''
    Ckpt saver
    f_name - <str> the name of ckpt file (w/o prefix) to store, overwrite if existed
    metric - key under which ``score`` is stored in the checkpoint
    score - <float> The value of metric used to evaluate model
    show_msg - <bool> whether to report the save through ``self.verbose``

    The checkpoint is written to ``os.path.join(self.ckpdir, f_name)`` and holds
    the model state dict, the optimizer state (``get_opt_state_dict()``), the
    global step, the metric/score pair and, when present, the embedding
    decoder's state dict.
    '''
    ckpt_path = os.path.join(self.ckpdir, f_name)
    full_dict = {
        "model": self.model.state_dict(),
        "optimizer": self.optimizer.get_opt_state_dict(),
        "global_step": self.step,
        metric: score
    }
    # Additional modules to save
    # if self.amp:
    #     full_dict['amp'] = self.amp_lib.state_dict()
    if self.emb_decoder is not None:
        full_dict['emb_decoder'] = self.emb_decoder.state_dict()
    torch.save(full_dict, ckpt_path)
    if show_msg:
        self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
                     format(human_format(self.step), metric, score, ckpt_path))
示例13: train
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def train(self, optimizer=None, epoches=10, save_name=None):
    """Fine-tune for ``epoches`` epochs while tracking the best test result.

    After every epoch the model is evaluated with ``self.test()``, whose
    result is treated as a count of correct predictions (it is divided by
    ``self.len_target_dataset`` for the accuracy printout). When the count
    ties or beats ``self.littlemax_correct``, an independent snapshot of the
    model is stored in ``self.cur_model``. When it strictly beats
    ``self.max_correct``, the whole model is also written to ``save_name``
    (if given) with ``torch.save``.
    """
    import copy  # local import: only needed for the best-model snapshot

    for i in range(epoches):
        epoch = i + 1
        print("Epoch: ", epoch)
        self.train_epoch(optimizer, epoch, epoches + 1)
        cur_correct = self.test()
        if cur_correct >= self.littlemax_correct:
            self.littlemax_correct = cur_correct
            # Deep-copy: a plain assignment would alias self.model, so the
            # stored "best" model would keep changing as training continues.
            self.cur_model = copy.deepcopy(self.model)
            print("write cur bset model")
        if cur_correct > self.max_correct:
            self.max_correct = cur_correct
            if save_name:
                torch.save(self.model, str(save_name))
        print('amazon to webcam max correct: {} max accuracy{: .2f}%\n'.format(
            self.max_correct, 100.0 * self.max_correct / self.len_target_dataset))
    print("Finished fine tuning.")
示例14: save
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save(self, modules, global_step, force=False):
    """
    Checkpoint ``modules`` if ``force`` is set or ``global_step`` is a
    multiple of ``self.keep_tmp_itr``.

    Each save increments ``self.ckpts_since_last_permanent``. When it reaches
    ``self.keep_every``, ``self._remove_previous`` is called with the newest
    checkpoint and the counter is reset to 0.

    :param modules: dictionary name -> nn.Module
    :param global_step: current step
    :param force: save regardless of ``global_step``
    :return: bool, Whether previous checkpoints were removed
    """
    # Guard clause: nothing to do off-schedule unless forced.
    if not force and global_step % self.keep_tmp_itr != 0:
        return False
    assert self._out_dir is not None
    newest_ckpt = self._save(modules, global_step)
    self.ckpts_since_last_permanent += 1
    if self.ckpts_since_last_permanent != self.keep_every:
        return False
    self._remove_previous(newest_ckpt)
    self.ckpts_since_last_permanent = 0
    return True
示例15: __init__
# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None, wsddn_premodel=None):
    """Store the training inputs and ensure the validation summary dir exists.

    Every argument is kept as an attribute (``network`` as ``self.net``).
    Validation summaries go to ``tbdir + '_val'``, which is created if it
    does not exist yet.
    """
    self.net, self.imdb = network, imdb
    self.roidb, self.valroidb = roidb, valroidb
    self.output_dir, self.tbdir = output_dir, tbdir
    # Suffixing '_val' keeps validation-set summaries apart from training ones.
    self.tbvaldir = tbdir + '_val'
    if not os.path.exists(self.tbvaldir):
        os.makedirs(self.tbvaldir)
    self.pretrained_model, self.wsddn_premodel = pretrained_model, wsddn_premodel
开发者ID:Sunarker,项目名称:Collaborative-Learning-for-Weakly-Supervised-Object-Detection,代码行数:15,代码来源:train_val.py