当前位置: 首页>>代码示例>>Python>>正文


Python utils.save_checkpoint方法代码示例

本文整理汇总了Python中utils.utils.save_checkpoint方法的典型用法代码示例。如果您正苦于以下问题:Python utils.save_checkpoint方法的具体用法?Python utils.save_checkpoint怎么用?Python utils.save_checkpoint使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils.utils的用法示例。


在下文中一共展示了utils.save_checkpoint方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: save_checkpoint

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import save_checkpoint [as 别名]
def save_checkpoint(self, epoch, best_prec, is_best):
    """Persist the full training state and, when best so far, a weights-only copy.

    Builds a checkpoint holding the epoch counter, best precision, model
    parameters and optimizer state, and hands it to utils.save_checkpoint.
    If `is_best` is set, additionally writes just the model weights to
    'weight_best.pkl' under the run directory.
    """
    checkpoint = {
        'epoch': epoch,
        'best_prec': best_prec,
        'state_dict': self.model.state_dict(),
        'optim': self.optim.state_dict(),
    }
    utils.save_checkpoint(checkpoint, is_best,
                          dirpath=self.dirpath, filename='model_checkpoint.pkl')
    if is_best:
        best_path = os.path.join(self.dirpath, 'weight_best.pkl')
        torch.save({'state_dict': self.model.state_dict()}, best_path)
开发者ID:wyf2017,项目名称:DSMnet,代码行数:13,代码来源:stereo.py

示例2: start

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import save_checkpoint [as 别名]
def start(self):
        """Entry point: run a single validation pass (test mode) or the full training loop.

        In training mode this resumes the loss/metric histories from
        ``loss.pkl`` if present, then for each epoch adjusts the learning
        rate and per-level loss weights, trains, periodically validates,
        checkpoints the best model (lowest validation D1), and refreshes a
        three-panel Loss/EPE/D1 plot saved as a PNG.
        """
        args = self.args
        if args.mode == 'test':
            # Test mode: evaluate once, no training.
            self.validate()
            return
    
        # Per-epoch training curves plus the (sparser) validation curves.
        losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = [], [], [], [], [], [], []
        path_val = os.path.join(self.dirpath, "loss.pkl")
        if(os.path.exists(path_val)):
            # Resume previously recorded curves so plots stay continuous across restarts.
            state_val = torch.load(path_val)
            losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val = state_val
        # Start training the model.
        plt.figure(figsize=(18, 5))
        time_start = time.time()
        epoch0 = self.epoch
        for epoch in range(epoch0, args.epochs):
            self.epoch = epoch
            self.lr_adjust(self.optim, args.lr_epoch0, args.lr_stride, args.lr, epoch) # custom lr_adjust schedule (defined elsewhere in this class)
            self.lossfun.Weight_Adjust_levels(epoch)
            msg = 'lr: %.6f | weight of levels: %s' % (self.optim.param_groups[0]['lr'], str(self.lossfun.weight_levels))
            logging.info(msg)
    
            # train for one epoch
            mloss, mEPE, mD1 = self.train()
            losses.append(mloss)
            EPEs.append(mEPE)
            D1s.append(mD1)
    
            if(epoch % self.args.val_freq == 0) or (epoch == args.epochs-1):
                # evaluate on validation set
                mloss_val, mEPE_val, mD1_val = self.validate()
                epochs_val.append(epoch)
                losses_val.append(mloss_val)
                EPEs_val.append(mEPE_val)
                D1s_val.append(mD1_val)
        
                # remember best prec@1 and save checkpoint
                # (D1 is an error rate here: lower is better, hence `<` and `min`)
                is_best = mD1_val < self.best_prec
                self.best_prec = min(mD1_val, self.best_prec)
                self.save_checkpoint(epoch, self.best_prec, is_best)
                torch.save([losses, EPEs, D1s, epochs_val, losses_val, EPEs_val, D1s_val], path_val)
                
                # Redraw the three train/val curve panels and overwrite the PNG.
                m, n = 1, 3
                ax1 = plt.subplot(m, n, 1)
                ax2 = plt.subplot(m, n, 2)
                ax3 = plt.subplot(m, n, 3)
                plt.sca(ax1); plt.cla(); plt.xlabel("epoch"); plt.ylabel("Loss")
                plt.plot(np.array(losses), label='train'); plt.plot(np.array(epochs_val), np.array(losses_val), label='val'); plt.legend()
                plt.sca(ax2); plt.cla(); plt.xlabel("epoch"); plt.ylabel("EPE")
                plt.plot(np.array(EPEs), label='train'); plt.plot(np.array(epochs_val), np.array(EPEs_val), label='val'); plt.legend()
                plt.sca(ax3); plt.cla(); plt.xlabel("epoch"); plt.ylabel("D1")
                plt.plot(np.array(D1s), label='train'); plt.plot(np.array(epochs_val), np.array(D1s_val), label='val'); plt.legend()
                plt.savefig("check_%s_%s_%s_%s.png" % (args.mode, args.dataset, args.net, args.loss_name))
            
            # Elapsed hours so far and a linear projection of total training time.
            time_curr = (time.time() - time_start)/3600.0
            time_all =  time_curr*(args.epochs - epoch0)/(epoch + 1 - epoch0)
            msg = 'Progress: %.2f | %.2f (hour)\n' % (time_curr, time_all)
            logging.info(msg) 
开发者ID:wyf2017,项目名称:DSMnet,代码行数:61,代码来源:stereo.py

示例3: run

# 需要导入模块: from utils import utils [as 别名]
# 或者: from utils.utils import save_checkpoint [as 别名]
def run(self, epochs, train_loader, val_loader, test_loader, log_interval):
    """Train with ignite, tracking metrics in tensorboard and keeping the best checkpoint.

    Builds a supervised trainer/evaluator pair on the configured device and
    attaches three handlers: per-iteration training-loss logging, per-epoch
    validation with checkpointing when the first metric improves, and a
    final test-set evaluation using the best saved checkpoint.
    """
    use_cuda = self.device != -1
    with torch.cuda.device(self.device):
        trainer = create_supervised_trainer(self.model, self.optimizer, self.loss_fn, cuda=use_cuda)
        evaluator = create_supervised_evaluator(self.model, metrics=self.metrics, y_to_score=self.y_to_score, pred_to_score=self.pred_to_score, cuda=use_cuda)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based iteration index within the current epoch.
        step = (engine.state.iteration - 1) % len(train_loader) + 1
        if step % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, step, len(train_loader), engine.state.output))
            self.writer.add_scalar("train/loss", engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        results = evaluator.state.metrics

        names = list(self.metrics.keys())
        values = [results[name] for name in names]
        template = 'Validation Results - Epoch: {} ' + ' '.join(name + ': {:.4f}' for name in names)
        print(template.format(engine.state.epoch, *values))
        for name, value in zip(names, values):
            self.writer.add_scalar(f'dev/{name}', value, engine.state.epoch)

        # The first configured metric drives model selection (higher is better).
        if values[0] > self.best_score:
            utils.save_checkpoint({
                'epoch': engine.state.epoch,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'eval_metric': values[0]
            }, self.model_id)
            self.best_score = values[0]

    @trainer.on(Events.COMPLETED)
    def log_test_results(engine):
        # Restore the best-scoring weights before the final test evaluation.
        checkpoint = torch.load(self.model_id)
        self.model.load_state_dict(checkpoint['state_dict'])

        evaluator.run(test_loader)
        results = evaluator.state.metrics

        names = list(self.metrics.keys())
        values = [results[name] for name in names]
        template = 'Test Results - Epoch: {} ' + ' '.join(name + ': {:.4f}' for name in names)
        print(template.format(engine.state.epoch, *values))
        for name, value in zip(names, values):
            self.writer.add_scalar(f'test/{name}', value, engine.state.epoch)

    trainer.run(train_loader, max_epochs=epochs)

    self.writer.close()
开发者ID:tuzhucheng,项目名称:sentence-similarity,代码行数:56,代码来源:__init__.py


注:本文中的utils.utils.save_checkpoint方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。