当前位置: 首页>>代码示例>>Python>>正文


Python utils.save_checkpoint方法代码示例

本文整理汇总了Python中utils.save_checkpoint方法的典型用法代码示例。如果您想了解Python utils.save_checkpoint方法的具体用法、调用方式或使用示例,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块utils的用法示例。


在下文中一共展示了utils.save_checkpoint方法的3个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train_and_eval

# 需要导入模块: import utils [as 别名]
# 或者: from utils import save_checkpoint [as 别名]
def train_and_eval(net, train_loader, val_loader, optimizer, loss_fn, metrics, params, model_dir, restore=None):
    """
    Train ``net`` for ``params.num_epochs`` epochs, validating after each one.

    net: the model being trained.
    train_loader / val_loader: data loaders for the training and validation sets.
    optimizer, loss_fn, metrics: forwarded to ``train`` / ``evaluate``.
    params: hyper-parameters parsed from the JSON file (``num_epochs`` is read here).
    model_dir: directory that receives checkpoints and the metric JSON files.
    restore: when not None, model and optimizer state are first loaded from a
        checkpoint. NOTE(review): the checkpoint path is built from the
        module-level ``args`` (``param_path`` / ``resume_path``), not from the
        value of ``restore`` itself -- confirm this is intended.

    Every epoch a checkpoint is saved. When validation accuracy ties or beats the
    best seen so far, the metrics are also written to ``best_model_params.json``.
    The latest metrics are always written to ``last_acc_metrics.json``.
    """
    best_val_acc = 0.0
    if restore is not None:
        ckpt_path = os.path.join(args.param_path, args.resume_path + '_pth.tar')
        logging.info("Loaded checkpoints from:{}".format(ckpt_path))
        utils.load_checkpoint(ckpt_path, net, optimizer)

    for epoch_idx in range(params.num_epochs):
        logging.info("Running epoch: {}/{}".format(epoch_idx + 1, params.num_epochs))

        # One pass over the training data, then score on the validation set.
        train(net, train_loader, loss_fn, params, metrics, optimizer)
        scores = evaluate(net, val_loader, loss_fn, params, metrics)

        acc = scores['accuracy']
        # Ties count as "best", so the most recent equally-good epoch wins.
        improved = best_val_acc <= acc

        checkpoint = {
            "epoch": epoch_idx,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        utils.save_checkpoint(checkpoint, isBest=improved, ckpt_dir=model_dir)

        if improved:
            logging.info("New best accuracy found!")
            best_val_acc = acc
            utils.save_dict_to_json(scores, os.path.join(model_dir, "best_model_params.json"))

        # Always record the latest validation metrics, best or not.
        utils.save_dict_to_json(scores, os.path.join(model_dir, 'last_acc_metrics.json'))
开发者ID:aicaffeinelife,项目名称:Pytorch-STN,代码行数:39,代码来源:train.py

示例2: run

# 需要导入模块: import utils [as 别名]
# 或者: from utils import save_checkpoint [as 别名]
def run(self):
        """
        Run the whole supernet training/search procedure.

        For each of ``self.args.epochs`` epochs:
        - step the LR scheduler;
        - train one epoch with ``self.train_fn``;
        - validate ``self.model_spec`` with ``self.validate_model``.

        On epochs for which ``self.check_should_save(epoch)`` is true, it also:
        - saves the duplicate architecture pool;
        - evaluates architectures via ``self.evaluate``;
        - writes a checkpoint with ``utils.save_checkpoint``;
        - saves the results.

        :return: ``(best_id, model_spec)``.
            ``best_id`` is the ``geno_id`` of the last ranking entry stored under
            the most recently inserted key of ``self.ranking_per_epoch``. This is
            presumably the best-ranked architecture -- confirm against
            ``self.evaluate``.
            ``model_spec`` is ``self.search_space.nasbench_model_specs[best_id]``.
        :raises IndexError: if ``self.ranking_per_epoch`` is empty.
        """
        train_queue, valid_queue, test_queue, criterion = self.initialize_run()
        args = self.args
        model, optimizer, scheduler = self.initialize_model()
        fitness_dict = {}
        self.optimizer = optimizer
        self.scheduler = scheduler
        # The message previously lacked a '{}' placeholder, so the method name was dropped.
        logging.info(">> Begin the search with supernet method : {}".format(args.supernet_train_method))

        for epoch in range(args.epochs):
            # NOTE(review): stepping the scheduler before training and reading
            # get_lr() follows the pre-1.1 PyTorch convention. Newer versions warn
            # and recommend step() after the optimizer updates plus get_last_lr().
            # Kept as-is because changing the order shifts the LR schedule.
            scheduler.step()
            lr = scheduler.get_lr()[0]

            train_acc, train_obj = self.train_fn(train_queue, valid_queue, model, criterion, optimizer, lr)
            self.logging_fn(train_acc, train_obj, epoch, 'Train', display_dict={'lr': lr})

            # validation
            valid_acc, valid_obj = self.validate_model(model, valid_queue, self.model_spec_id, self.model_spec)
            self.logging_fn(valid_acc, valid_obj, epoch, 'Valid')

            if not self.check_should_save(epoch):
                continue
            # Evaluation, checkpointing and result saving only on selected epochs.
            self.save_duplicate_arch_pool('valid', epoch)
            fitness_dict = self.evaluate(epoch, test_queue, fitnesses_dict=fitness_dict, train_queue=train_queue)
            utils.save_checkpoint(model, optimizer, self.running_stats, self.exp_dir)
            self.save_results(epoch, rank_details=True)

        # Use the most recently inserted epoch key (dicts keep insertion order);
        # list(...)[-1] raises IndexError on an empty mapping, as before.
        last_epoch = list(self.ranking_per_epoch)[-1]
        best_id = self.ranking_per_epoch[last_epoch][-1][1].geno_id
        return best_id, self.search_space.nasbench_model_specs[best_id]
开发者ID:kcyu2014,项目名称:eval-nas,代码行数:39,代码来源:nasbench_weight_sharing_policy.py

示例3: train_and_evaluate

# 需要导入模块: import utils [as 别名]
# 或者: from utils import save_checkpoint [as 别名]
def train_and_evaluate(model, train_data, val_data, optimizer, scheduler, params, model_dir, restore_file=None):
    """Train the model, evaluating on the training and validation sets every epoch.

    Args:
        model: network to train. If it exposes a ``module`` attribute (e.g. a
            data-parallel wrapper), only the wrapped module's weights are saved.
        train_data: training data, consumed through ``data_loader.data_iterator``.
        val_data: validation data, consumed through ``data_loader.data_iterator``.
        optimizer: optimizer whose state is checkpointed. When ``args.fp16`` is
            set, its inner ``optimizer`` attribute is saved instead.
        scheduler: learning-rate scheduler forwarded to ``train``.
        params: hyper-parameter object.
            Reads: ``epoch_num``, ``train_size``, ``val_size``, ``batch_size``,
            ``patience``, ``patience_num``, ``min_epoch_num``.
            Sets: ``train_steps``, ``val_steps``, ``eval_steps``.
        model_dir: directory checkpoints are written to and restored from.
        restore_file: checkpoint name inside ``model_dir``, without the
            ``.pth.tar`` suffix, to load before training. Use None to start fresh.

    Training stops early once ``patience_counter`` reaches ``params.patience_num``
    after more than ``params.min_epoch_num`` epochs. The best validation F1 is
    logged when training stops.
    """
    # Reload weights from restore_file if specified. The path is built from this
    # function's own arguments; it previously read the module-level
    # args.model_dir / args.restore_file and ignored the parameters.
    if restore_file is not None:
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_f1 = 0.0
    patience_counter = 0

    for epoch in range(1, params.epoch_num + 1):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch, params.epoch_num))

        # Number of full batches in one epoch (floor division drops a trailing partial batch).
        params.train_steps = params.train_size // params.batch_size
        params.val_steps = params.val_size // params.batch_size

        # Shuffled iterator for training.
        train_data_iterator = data_loader.data_iterator(train_data, shuffle=True)
        train(model, train_data_iterator, optimizer, scheduler, params)

        # Fresh, unshuffled iterators for evaluation.
        train_data_iterator = data_loader.data_iterator(train_data, shuffle=False)
        val_data_iterator = data_loader.data_iterator(val_data, shuffle=False)

        # params.eval_steps is set before each call, presumably so evaluate()
        # knows how many batches to read -- confirm against evaluate().
        params.eval_steps = params.train_steps
        evaluate(model, train_data_iterator, params, mark='Train')
        params.eval_steps = params.val_steps
        val_metrics = evaluate(model, val_data_iterator, params, mark='Val')

        val_f1 = val_metrics['f1']
        improve_f1 = val_f1 - best_val_f1
        is_best = improve_f1 > 0

        # Save weights of the network: only the model itself, not its wrapper.
        model_to_save = model.module if hasattr(model, 'module') else model
        # NOTE(review): the fp16 flag still comes from the module-level ``args``.
        optimizer_to_save = optimizer.optimizer if args.fp16 else optimizer
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model_to_save.state_dict(),
                               'optim_dict': optimizer_to_save.state_dict()},
                              is_best=is_best,
                              checkpoint=model_dir)
        if is_best:
            logging.info("- Found new best F1")
            best_val_f1 = val_f1
            # An improvement smaller than params.patience still counts toward early stopping.
            if improve_f1 < params.patience:
                patience_counter += 1
            else:
                patience_counter = 0
        else:
            patience_counter += 1

        # Stop early once patience is exhausted (only after min_epoch_num epochs).
        # The best F1 is always reported on the final epoch.
        if (patience_counter >= params.patience_num and epoch > params.min_epoch_num) or epoch == params.epoch_num:
            logging.info("Best val f1: {:05.2f}".format(best_val_f1))
            break
开发者ID:lemonhu,项目名称:NER-BERT-pytorch,代码行数:61,代码来源:train.py


注:本文中的utils.save_checkpoint方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。