当前位置: 首页>>代码示例>>Python>>正文


Python torch.save方法代码示例

本文整理汇总了Python中torch.save方法的典型用法代码示例。如果您正苦于以下问题:Python torch.save方法的具体用法?Python torch.save怎么用?Python torch.save使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch的用法示例。


在下文中一共展示了torch.save方法的15个代码示例,这些例子默认根据受欢迎程度排序。注意:示例7、示例14和示例15中的函数本身并未直接调用torch.save,它们只是摘自使用了torch.save的源文件。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: save

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save(self, ckpt_path, epoch):
    '''
    Save checkpoint.

    Writes the state dict of every entry in ``self.nets`` to
    ``net_<name>_<epoch>.pth`` and of every entry in ``self.optimizers`` to
    ``optimizer_<name>_<epoch>.pth``, both inside *ckpt_path*.

    Args:
        ckpt_path: destination directory; created if it does not exist.
            An empty string writes into the current working directory.
        epoch: epoch number embedded in every file name.
    '''
    # os.makedirs('') raises, so only create a real directory name.
    if ckpt_path:
        os.makedirs(ckpt_path, exist_ok=True)

    for name, net in self.nets.items():
        # DataParallel registers the wrapped model as its ``module`` child;
        # saving the inner module keeps the keys free of a "module." prefix.
        module = net.module if isinstance(net, torch.nn.DataParallel) else net
        path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
        torch.save(module.state_dict(), path)

    for name, optimizer in self.optimizers.items():
        path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
        torch.save(optimizer.state_dict(), path)
开发者ID:jthsieh,项目名称:DDPAE-video-prediction,代码行数:18,代码来源:base_model.py

示例2: convert

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def convert(src, dst):
    """Convert keys in pycls pretrained RegNet models to mmdet style.

    Loads the checkpoint at *src*, reads its ``'model_state'`` mapping and
    hands every key to ``convert_stem`` / ``convert_head`` /
    ``convert_reslayer`` (helpers defined elsewhere), which fill
    ``state_dict`` and record the keys they handled in ``converted_names``.
    Keys that were not recorded are printed. The result is saved to *dst* as
    ``{'state_dict': state_dict}``.

    Args:
        src: path of the pycls checkpoint.
        dst: path the converted checkpoint is written to.
    """
    # Load onto the CPU so checkpoints saved from GPU runs can be converted
    # on machines without CUDA; the converted file then holds CPU tensors.
    regnet_model = torch.load(src, map_location='cpu')
    blobs = regnet_model['model_state']
    # convert to pytorch (mmdet) style
    state_dict = OrderedDict()
    converted_names = set()
    for key, weight in blobs.items():
        if 'stem' in key:
            convert_stem(key, weight, state_dict, converted_names)
        elif 'head' in key:
            convert_head(key, weight, state_dict, converted_names)
        elif key.startswith('s'):
            convert_reslayer(key, weight, state_dict, converted_names)

    # report every source key no helper claimed
    for key in blobs:
        if key not in converted_names:
            print(f'not converted: {key}')
    # save checkpoint
    checkpoint = dict()
    checkpoint['state_dict'] = state_dict
    torch.save(checkpoint, dst)
开发者ID:open-mmlab,项目名称:mmdetection,代码行数:26,代码来源:regnet2mmdet.py

示例3: save_model_all

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_model_all(model, save_dir, model_name, epoch):
    """
    Save *model*'s state dict as ``<save_dir>/<model_name>_epoch_<epoch>.pt``.

    :param model:  nn model whose ``state_dict()`` is saved
    :param save_dir: directory to save into; created if missing ('' = cwd)
    :param model_name:  file-name prefix
    :param epoch:  epoch number embedded in the file name
    :return:  None
    """
    # os.makedirs('') raises, and exist_ok avoids a race with other writers.
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    save_prefix = os.path.join(save_dir, model_name)
    save_path = '{}_epoch_{}.pt'.format(save_prefix, epoch)
    print("save all model to {}".format(save_path))
    # The context manager closes the handle even if serialization fails.
    with open(save_path, mode="wb") as output:
        torch.save(model.state_dict(), output)
开发者ID:bamtercelboo,项目名称:pytorch_NER_BiLSTM_CNN_CRF,代码行数:19,代码来源:utils.py

示例4: save_best_model

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_best_model(model, save_dir, model_name, best_eval):
    """
    Save *model* as ``<save_dir>/<model_name>.pt`` when the dev score is best.

    Nothing happens unless ``best_eval.current_dev_score`` is at least
    ``best_eval.best_dev_score``. When it is, the state dict is written,
    overwriting any earlier file, and ``best_eval.early_current_patience``
    is reset to 0.

    :param model:  nn model whose ``state_dict()`` is saved
    :param save_dir:  directory to save into; created if missing ('' = cwd)
    :param model_name:  file name without the ``.pt`` suffix
    :param best_eval:  object with current_dev_score / best_dev_score /
                       early_current_patience attributes
    :return:  None
    """
    if best_eval.current_dev_score < best_eval.best_dev_score:
        return
    # os.makedirs('') raises, and exist_ok avoids a race with other writers.
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    file_name = "{}.pt".format(model_name)
    save_path = os.path.join(save_dir, file_name)
    print("save best model to {}".format(save_path))
    # The context manager closes the handle even if serialization fails.
    with open(save_path, mode="wb") as output:
        torch.save(model.state_dict(), output)
    best_eval.early_current_patience = 0


# adjust lr 
开发者ID:bamtercelboo,项目名称:pytorch_NER_BiLSTM_CNN_CRF,代码行数:24,代码来源:utils.py

示例5: save_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(state, weights_dir=''):
    """Save a training-state dict as ``model-<epoch:04d>.pth.tar``.

    Arguments:
        state {dict} -- checkpoint contents; must contain an ``'epoch'``
            entry convertible with ``int()``, used to build the file name.

    Keyword Arguments:
        weights_dir {str} -- destination directory, created if missing.
            The default ``''`` writes into the current working directory.
    """
    # os.makedirs('') raises FileNotFoundError, which made the default
    # argument unusable; only create a directory when one is named.
    if weights_dir and not os.path.exists(weights_dir):
        os.makedirs(weights_dir, exist_ok=True)

    epoch = int(state['epoch'])
    file_path = os.path.join(weights_dir, 'model-{:04d}.pth.tar'.format(epoch))
    torch.save(state, file_path)
    

#############################################
# loss function
############################################# 
开发者ID:songdejia,项目名称:DeepLab_v3_plus,代码行数:26,代码来源:util.py

示例6: init_truncated_normal

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def init_truncated_normal(model, aux_str=''):
    """Initialise *model* with truncated-normal weights, cached on disk.

    The cache file is ``<path>/<in_features>_<out_features><aux_str>.pth``.
    Here ``path`` is a module-level name defined elsewhere in the file (not
    visible in this block). If the file exists, its state dict is loaded into
    *model*. Otherwise ``truncated_normal`` (also defined elsewhere) is
    applied, and the resulting state dict is saved to that file so later
    calls load the same initial weights.

    Args:
        model: module exposing ``in_features`` / ``out_features``, or None.
        aux_str: suffix appended to the cache file name.

    Returns:
        *model* after initialisation, or None when *model* is None.
    """
    if model is None:
        return None
    init_path = '{path}/{in_dim:d}_{out_dim:d}{aux_str}.pth'.format(
        path=path, in_dim=model.in_features, out_dim=model.out_features,
        aux_str=aux_str)
    if os.path.isfile(init_path):
        model.load_state_dict(torch.load(init_path))
        print('load init weight: {init_path}'.format(init_path=init_path))
    else:
        if isinstance(model, nn.ModuleList):
            # Plain loop: truncated_normal is called only for its side effect.
            for sub in model:
                truncated_normal(sub)
        else:
            truncated_normal(model)
        print('generate init weight: {init_path}'.format(init_path=init_path))
        torch.save(model.state_dict(), init_path)
        print('save init weight: {init_path}'.format(init_path=init_path))

    return model
开发者ID:kibok90,项目名称:cvpr2018-hnd,代码行数:19,代码来源:models.py

示例7: print_mutation

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def print_mutation(hyp, results, bucket=''):
    """Report one evolution result and fold it into ``evolve.txt``.

    Intended for use with ``train.py --evolve``. Prints the hyperparameter
    names, their values and the result metrics. It then appends *results*
    followed by the hyperparameter values as one row of ``evolve.txt`` in the
    current working directory. Finally it rewrites that file with duplicate
    rows removed and rows ordered by descending ``fitness`` (a helper defined
    elsewhere).

    When *bucket* is non-empty, ``evolve.txt`` is first downloaded from and
    finally uploaded to ``gs://<bucket>`` with ``gsutil``.

    NOTE(review): *bucket* is interpolated into a shell command string, so it
    must come from a trusted source.
    """
    n_hyp = len(hyp)
    header_row = ('%10s' * n_hyp) % tuple(hyp.keys())
    value_row = ('%10.3g' * n_hyp) % tuple(hyp.values())
    # results is expected to be a tuple, e.g. (P, R, mAP, F1, test_loss)
    result_row = ('%10.3g' * len(results)) % results
    print('\n%s\n%s\nEvolved fitness: %s\n' % (header_row, value_row, result_row))

    if bucket:
        os.system('gsutil cp gs://%s/evolve.txt .' % bucket)

    with open('evolve.txt', 'a') as log_file:
        log_file.write(result_row + value_row + '\n')

    table = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)
    ranking = np.argsort(-fitness(table))
    np.savetxt('evolve.txt', table[ranking], '%10.3g')

    if bucket:
        os.system('gsutil cp evolve.txt gs://%s' % bucket)
开发者ID:zbyuan,项目名称:pruning_yolov3,代码行数:19,代码来源:utils.py

示例8: save_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
        """
        Persist the latest training state and optionally mark it as best.

        The checkpoint holds the current epoch, iteration, model state dict
        and optimizer state dict. It is written to
        ``self.config.checkpoint_dir + filename`` (plain concatenation, so
        the directory string is expected to end with a separator).

        :param filename: name of the checkpoint file inside checkpoint_dir
        :param is_best: when truthy, also copy the file to model_best.pth.tar
        :return: None
        """
        snapshot = {}
        snapshot['epoch'] = self.current_epoch
        snapshot['iteration'] = self.current_iteration
        snapshot['state_dict'] = self.model.state_dict()
        snapshot['optimizer'] = self.optimizer.state_dict()

        ckpt_dir = self.config.checkpoint_dir
        latest_path = ckpt_dir + filename
        torch.save(snapshot, latest_path)

        if not is_best:
            return
        shutil.copyfile(latest_path, ckpt_dir + 'model_best.pth.tar')
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:21,代码来源:condensenet.py

示例9: save_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, filename='checkpoint.pth.tar', is_best=0):
        """
        Persist the latest training state and optionally mark it as best.

        The stored ``'epoch'`` is ``self.current_epoch + 1`` (presumably the
        epoch to resume from; confirm against the matching loader). The file
        goes to ``self.config.checkpoint_dir + filename`` by plain string
        concatenation, so the directory is expected to end with a separator.

        :param filename: name of the checkpoint file inside checkpoint_dir
        :param is_best: when truthy, also copy the file to model_best.pth.tar
        :return: None
        """
        snapshot = {}
        snapshot['epoch'] = self.current_epoch + 1
        snapshot['iteration'] = self.current_iteration
        snapshot['state_dict'] = self.model.state_dict()
        snapshot['optimizer'] = self.optimizer.state_dict()

        ckpt_dir = self.config.checkpoint_dir
        latest_path = ckpt_dir + filename
        torch.save(snapshot, latest_path)

        if not is_best:
            return
        shutil.copyfile(latest_path, ckpt_dir + 'model_best.pth.tar')
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:21,代码来源:erfnet.py

示例10: save_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
        """
        Persist generator/discriminator training state; optionally mark best.

        Stores epoch, iteration, both networks' and both optimizers' state
        dicts, ``self.fixed_noise`` and ``self.manual_seed``. The file goes
        to ``self.config.checkpoint_dir + file_name`` by plain string
        concatenation, so the directory is expected to end with a separator.

        :param file_name: name of the checkpoint file inside checkpoint_dir
        :param is_best: when truthy, also copy the file to model_best.pth.tar
        :return: None
        """
        snapshot = {}
        snapshot['epoch'] = self.current_epoch
        snapshot['iteration'] = self.current_iteration
        snapshot['G_state_dict'] = self.netG.state_dict()
        snapshot['G_optimizer'] = self.optimG.state_dict()
        snapshot['D_state_dict'] = self.netD.state_dict()
        snapshot['D_optimizer'] = self.optimD.state_dict()
        snapshot['fixed_noise'] = self.fixed_noise
        snapshot['manual_seed'] = self.manual_seed

        ckpt_dir = self.config.checkpoint_dir
        latest_path = ckpt_dir + file_name
        torch.save(snapshot, latest_path)

        if not is_best:
            return
        shutil.copyfile(latest_path, ckpt_dir + 'model_best.pth.tar')
开发者ID:moemen95,项目名称:Pytorch-Project-Template,代码行数:19,代码来源:dcgan.py

示例11: _save_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def _save_checkpoint(self, epoch, acc):
        """
        Save a checkpoint for *epoch*, keeping only the best and latest ones.

        The file for the previous evaluation point (``epoch - eval_freq``) is
        deleted unless that epoch is ``self.best_epoch``. The new checkpoint
        is written to ``<logdir>/<NetClassName>_<epoch:03d>.pkl``.

        Returns:
            True once the checkpoint has been written.
        """
        net_type = type(self.net).__name__

        def checkpoint_file(ep):
            return os.path.join(self.logdir, '{}_{:03d}.pkl'.format(net_type, ep))

        previous_epoch = epoch - self.eval_freq
        if previous_epoch != self.best_epoch:
            stale = checkpoint_file(previous_epoch)
            if os.path.isfile(stale):
                os.remove(stale)

        target = checkpoint_file(epoch)
        payload = dict()
        payload['epoch'] = epoch
        payload['acc'] = acc
        payload['net_type'] = net_type
        payload['net'] = self.net.state_dict()
        payload['optimizer'] = self.optimizer.state_dict()
        payload['use_gpu'] = self.use_gpu
        payload['save_time'] = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        torch.save(payload, target)
        return True
开发者ID:miraiaroha,项目名称:ACAN,代码行数:25,代码来源:trainer.py

示例12: save_checkpoint

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save_checkpoint(self, f_name, metric, score, show_msg=True):
        """
        Write model/optimizer state plus one tracked metric to ``self.ckpdir``.

        Args:
            f_name: file name (no directory) to write inside ``self.ckpdir``;
                an existing file with that name is overwritten.
            metric: key under which *score* is stored in the checkpoint.
            score: value of *metric* for the current model.
            show_msg: when true, report the save through ``self.verbose``.
        """
        payload = dict()
        payload["model"] = self.model.state_dict()
        payload["optimizer"] = self.optimizer.get_opt_state_dict()
        payload["global_step"] = self.step
        payload[metric] = score
        # Optional extra module saved alongside the main model.
        if self.emb_decoder is not None:
            payload['emb_decoder'] = self.emb_decoder.state_dict()

        destination = os.path.join(self.ckpdir, f_name)
        torch.save(payload, destination)

        if not show_msg:
            return
        message = "Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".format(
            human_format(self.step), metric, score, destination)
        self.verbose(message)
开发者ID:Alexander-H-Liu,项目名称:End-to-end-ASR-Pytorch,代码行数:25,代码来源:solver.py

示例13: train

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def train(self, optimizer=None, epoches=10, save_name=None):
        """
        Fine-tune for *epoches* epochs while tracking the best test score.

        After every epoch ``self.test()`` is called and its result is
        compared twice:

        - ``>= self.littlemax_correct``: record it and keep an independent
          snapshot of the model in ``self.cur_model``;
        - ``> self.max_correct``: record it and, if *save_name* is given,
          save the whole model object there with ``torch.save``.

        Args:
            optimizer: passed straight through to ``self.train_epoch``.
            epoches: number of epochs to run.
            save_name: optional path for the best full-model checkpoint.
        """
        import copy

        for i in range(epoches):
            epoch = i + 1
            print("Epoch: ", epoch)
            # NOTE(review): the third argument is epoches + 1 as in the
            # original code; confirm what train_epoch expects here.
            self.train_epoch(optimizer, epoch, epoches + 1)
            cur_correct = self.test()
            if cur_correct >= self.littlemax_correct:
                self.littlemax_correct = cur_correct
                # Snapshot instead of aliasing: a plain assignment would keep
                # tracking every later update to self.model. This duplicates
                # the parameters on the model's current device.
                self.cur_model = copy.deepcopy(self.model)
                print("write cur bset model")

            if cur_correct > self.max_correct:
                self.max_correct = cur_correct
                if save_name:
                    torch.save(self.model, str(save_name))
            print('amazon to webcam max correct: {} max accuracy{: .2f}%\n'.format(
                self.max_correct, 100.0 * self.max_correct / self.len_target_dataset))

        print("Finished fine tuning.")
开发者ID:jindongwang,项目名称:transferlearning,代码行数:20,代码来源:finetune.py

示例14: save

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def save(self, modules, global_step, force=False):
        """
        Save a checkpoint when forced or when *global_step* is on the schedule.

        A save happens only if *force* is truthy or ``global_step`` is a
        multiple of ``self.keep_tmp_itr``. Every ``self.keep_every``-th save
        also calls ``self._remove_previous`` with the new checkpoint path and
        resets the counter.

        :param modules: dictionary name -> nn.Module
        :param global_step: current step
        :param force: save regardless of the step schedule
        :return: bool, True only when previous checkpoints were removed
        """
        due = force or global_step % self.keep_tmp_itr == 0
        if not due:
            return False
        assert self._out_dir is not None

        ckpt_path = self._save(modules, global_step)
        self.ckpts_since_last_permanent += 1
        if self.ckpts_since_last_permanent != self.keep_every:
            return False

        self._remove_previous(ckpt_path)
        self.ckpts_since_last_permanent = 0
        return True
开发者ID:fab-jul,项目名称:L3C-PyTorch,代码行数:19,代码来源:saver.py

示例15: __init__

# 需要导入模块: import torch [as 别名]
# 或者: from torch import save [as 别名]
def __init__(self, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None, wsddn_premodel=None):
    """Store the training inputs and make sure the validation summary dir exists.

    Validation summaries use the TensorBoard directory name with ``'_val'``
    appended; that directory is created if it is missing.
    """
    # network and dataset handles
    self.net, self.imdb = network, imdb
    self.roidb, self.valroidb = roidb, valroidb
    # output / summary locations
    self.output_dir = output_dir
    self.tbdir = tbdir
    val_summary_dir = tbdir + '_val'
    self.tbvaldir = val_summary_dir
    if not os.path.exists(val_summary_dir):
        os.makedirs(val_summary_dir)
    # optional pretrained weights
    self.pretrained_model, self.wsddn_premodel = pretrained_model, wsddn_premodel
开发者ID:Sunarker,项目名称:Collaborative-Learning-for-Weakly-Supervised-Object-Detection,代码行数:15,代码来源:train_val.py


注:本文中的torch.save方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。