本文整理匯總了 Python 中 utils.save_checkpoint 方法的典型用法代碼示例。如果您正苦於以下問題:Python utils.save_checkpoint 方法的具體用法?Python utils.save_checkpoint 怎麼用?Python utils.save_checkpoint 使用的例子?那麼,這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類 utils 的用法示例。
在下文中一共展示了 utils.save_checkpoint 方法的 3 個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的 Python 代碼示例。
示例1: train_and_eval
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import save_checkpoint [as 別名]
def train_and_eval(net, train_loader, val_loader, optimizer, loss_fn, metrics, params, model_dir, restore=None):
    """
    Train ``net`` for ``params.num_epochs`` epochs and validate after each one.

    After every epoch a checkpoint (epoch index, model state and optimizer
    state) is written through ``utils.save_checkpoint`` and the validation
    metrics are dumped to ``last_acc_metrics.json`` in ``model_dir``. When the
    validation accuracy is at least as high as the best seen so far, the
    metrics are also written to ``best_model_params.json``.

    net: The model.
    train_loader / val_loader: The data loaders passed to ``train`` / ``evaluate``.
    optimizer: Optimizer passed to ``train``; its state is checkpointed.
    loss_fn, metrics: Forwarded unchanged to ``train`` and ``evaluate``.
    params: Parameter object; only ``params.num_epochs`` is read here.
    model_dir: Directory receiving checkpoints and metric JSON files.
    restore: When not None, a checkpoint is loaded before training starts.
    """
    if restore is not None:
        # NOTE(review): the checkpoint path is built from an ``args`` object
        # that is not a parameter of this function; ``restore`` itself is only
        # used as an on/off flag. Confirm ``args`` exists at module level.
        checkpoint_path = os.path.join(args.param_path, args.resume_path + '_pth.tar')
        logging.info("Loaded checkpoints from:{}".format(checkpoint_path))
        utils.load_checkpoint(checkpoint_path, net, optimizer)

    top_acc = 0.0
    for epoch_idx in range(params.num_epochs):
        logging.info("Running epoch: {}/{}".format(epoch_idx + 1, params.num_epochs))

        # One pass over the training data, then one over the validation data.
        train(net, train_loader, loss_fn, params, metrics, optimizer)
        epoch_metrics = evaluate(net, val_loader, loss_fn, params, metrics)
        epoch_acc = epoch_metrics['accuracy']

        # ``>=`` means a tie with the current best also counts as a new best.
        improved = epoch_acc >= top_acc

        snapshot = {
            "epoch": epoch_idx,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        utils.save_checkpoint(snapshot, isBest=improved, ckpt_dir=model_dir)

        if improved:
            logging.info("New best accuracy found!")
            top_acc = epoch_acc
            utils.save_dict_to_json(epoch_metrics, os.path.join(model_dir, "best_model_params.json"))

        # The metrics of the most recent epoch are always kept.
        utils.save_dict_to_json(epoch_metrics, os.path.join(model_dir, 'last_acc_metrics.json'))
示例2: run
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import save_checkpoint [as 別名]
def run(self):
    """
    Procedure of training. This run describes the entire training procedure.

    Sets up the data queues and the model, then for every epoch in
    ``range(self.args.epochs)``: steps the scheduler, trains, validates and
    logs both results. On epochs where ``self.check_should_save(epoch)`` is
    true it additionally saves the architecture pool, runs ``self.evaluate``,
    writes a checkpoint through ``utils.save_checkpoint`` and saves results.

    Side effects: stores the optimizer and scheduler on ``self.optimizer`` and
    ``self.scheduler``.

    :return: tuple ``(best_id, model_spec)`` where ``best_id`` is the
        ``geno_id`` of the last entry in the ranking stored under the last key
        of ``self.ranking_per_epoch`` and ``model_spec`` is
        ``self.search_space.nasbench_model_specs[best_id]``.
    :raises IndexError: if ``self.ranking_per_epoch`` is empty when training
        finishes.
    """
    train_queue, valid_queue, test_queue, criterion = self.initialize_run()
    args = self.args
    model, optimizer, scheduler = self.initialize_model()
    fitness_dict = {}
    self.optimizer = optimizer
    self.scheduler = scheduler
    # The "{}" placeholder is required; without it ``str.format`` silently
    # discards its argument and the method name never reaches the log.
    logging.info(">> Begin the search with supernet method : {}".format(args.supernet_train_method))

    for epoch in range(args.epochs):
        # NOTE(review): the scheduler is stepped before the epoch's training
        # and the rate is read through ``get_lr()``. If this is a
        # ``torch.optim.lr_scheduler`` object, recent PyTorch releases expect
        # ``step()`` after the optimizer updates and ``get_last_lr()`` for
        # reading; left unchanged here because it alters the LR schedule.
        scheduler.step()
        lr = scheduler.get_lr()[0]

        train_acc, train_obj = self.train_fn(train_queue, valid_queue, model, criterion, optimizer, lr)
        self.logging_fn(train_acc, train_obj, epoch, 'Train', display_dict={'lr': lr})

        # validation
        valid_acc, valid_obj = self.validate_model(model, valid_queue, self.model_spec_id, self.model_spec)
        self.logging_fn(valid_acc, valid_obj, epoch, 'Valid')

        if not self.check_should_save(epoch):
            continue

        # evaluate process.
        self.save_duplicate_arch_pool('valid', epoch)
        fitness_dict = self.evaluate(epoch, test_queue, fitnesses_dict=fitness_dict, train_queue=train_queue)
        utils.save_checkpoint(model, optimizer, self.running_stats, self.exp_dir)
        self.save_results(epoch, rank_details=True)

    # add later, return the model specs that is evaluated across the time.
    # Process the ranking in the end, return the best of training.
    # Presumably each ranking is sorted so that its last entry is the best
    # one -- verify against the code that fills ``ranking_per_epoch``.
    ep_k = list(self.ranking_per_epoch.keys())[-1]
    best_id = self.ranking_per_epoch[ep_k][-1][1].geno_id
    return best_id, self.search_space.nasbench_model_specs[best_id]
示例3: train_and_evaluate
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import save_checkpoint [as 別名]
def train_and_evaluate(model, train_data, val_data, optimizer, scheduler, params, model_dir, restore_file=None):
    """Train the model and evaluate it on the train and validation data every epoch.

    Each epoch a checkpoint is written through ``utils.save_checkpoint`` and
    flagged as best when the validation F1 improved. Training stops early once
    ``params.patience_num`` consecutive epochs passed without an F1 gain of at
    least ``params.patience`` (and more than ``params.min_epoch_num`` epochs
    have run), or when ``params.epoch_num`` is reached.

    Side effects: sets ``params.train_steps``, ``params.val_steps`` and
    ``params.eval_steps`` on every epoch.

    model: The network; if it has a ``module`` attribute, that inner object is
        the one whose state is saved.
    train_data / val_data: Datasets handed to ``data_loader.data_iterator``.
    optimizer, scheduler: Forwarded to ``train``; the optimizer state is saved.
    params: Hyper-parameter object (reads ``epoch_num``, ``train_size``,
        ``val_size``, ``batch_size``, ``patience``, ``patience_num``,
        ``min_epoch_num``).
    model_dir: Directory passed to ``utils.save_checkpoint``.
    restore_file: When not None, a checkpoint is loaded before training.
    """
    if restore_file is not None:
        # NOTE(review): the path is built from ``args`` (not a parameter of
        # this function), not from ``model_dir`` / ``restore_file``; confirm
        # that a module-level ``args`` exists and that this is intended.
        checkpoint_path = os.path.join(args.model_dir, args.restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(checkpoint_path))
        utils.load_checkpoint(checkpoint_path, model, optimizer)

    top_f1 = 0.0
    stale_epochs = 0

    for epoch in range(1, params.epoch_num + 1):
        logging.info("Epoch {}/{}".format(epoch, params.epoch_num))

        # Number of batches per epoch for each split.
        params.train_steps = params.train_size // params.batch_size
        params.val_steps = params.val_size // params.batch_size

        # Training pass over shuffled training data.
        shuffled_train_iter = data_loader.data_iterator(train_data, shuffle=True)
        train(model, shuffled_train_iter, optimizer, scheduler, params)

        # Fresh, unshuffled iterators for the evaluation passes.
        train_eval_iter = data_loader.data_iterator(train_data, shuffle=False)
        val_eval_iter = data_loader.data_iterator(val_data, shuffle=False)

        # ``params.eval_steps`` is set before each evaluate call so that it
        # matches the split being evaluated. The training-set result is not
        # used afterwards.
        params.eval_steps = params.train_steps
        evaluate(model, train_eval_iter, params, mark='Train')
        params.eval_steps = params.val_steps
        val_f1 = evaluate(model, val_eval_iter, params, mark='Val')['f1']

        f1_gain = val_f1 - top_f1
        is_improvement = f1_gain > 0

        # Save the wrapped network rather than its wrapper when one is present.
        bare_model = getattr(model, 'module', model)
        # NOTE(review): ``args.fp16`` comes from outside this function; when
        # set, the optimizer is expected to wrap the real one as ``.optimizer``.
        if args.fp16:
            bare_optimizer = optimizer.optimizer
        else:
            bare_optimizer = optimizer
        state = {
            'epoch': epoch + 1,
            'state_dict': bare_model.state_dict(),
            'optim_dict': bare_optimizer.state_dict(),
        }
        utils.save_checkpoint(state, is_best=is_improvement, checkpoint=model_dir)

        if is_improvement:
            logging.info("- Found new best F1")
            top_f1 = val_f1
            # A gain smaller than ``params.patience`` still counts as a stale
            # epoch; only a large enough gain resets the counter.
            stale_epochs = stale_epochs + 1 if f1_gain < params.patience else 0
        else:
            stale_epochs += 1

        # Early stopping and logging best f1
        out_of_patience = stale_epochs >= params.patience_num and epoch > params.min_epoch_num
        if out_of_patience or epoch == params.epoch_num:
            logging.info("Best val f1: {:05.2f}".format(top_f1))
            break