本文整理匯總了Python中utils.load_checkpoint方法的典型用法代碼示例。如果您正苦於以下問題：Python utils.load_checkpoint方法的具體用法？Python utils.load_checkpoint怎麼用？Python utils.load_checkpoint使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊utils的用法示例。
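在下文中一共展示了utils.load_checkpoint方法的9個代碼示例，這些例子默認根據受歡迎程度排序。注意：這些示例來自不同的項目，各自utils.load_checkpoint的參數順序與返回值並不相同（例如示例3返回一個四元組，示例8的參數順序為(model, path)），使用時請以所在項目的實現為準。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。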
示例1: get_params
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def get_params(pretrained_model):
    """Dump every tensor of a checkpoint to its own ``<name>.npy`` file.

    Files are written to the current working directory, one per entry,
    named after the checkpoint key.

    Args:
        pretrained_model: Checkpoint location handed to ``load_checkpoint``.
            The loaded object is iterated with ``.items()``, so it is assumed
            to be a flat name -> tensor mapping (checkpoints that nest the
            weights under a ``'state_dict'`` key are not handled here) —
            TODO confirm against ``utils.load_checkpoint``.

    Raises:
        RuntimeError: if a tensor cannot be converted to NumPy or written to
            disk; the underlying exception is chained as ``__cause__``.
    """
    pretrained_checkpoint = load_checkpoint(pretrained_model)
    for name, param in pretrained_checkpoint.items():
        print('pretrained_model params name and size: ', name, param.size())
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        try:
            # detach() so .numpy() also works on tensors that require grad.
            np.save(name + '.npy', param.detach().cpu().numpy())
            print('############# new_model load params name: ', name)
        except Exception as err:
            # Report what actually failed (saving, not copying) and keep the
            # original error attached for debugging.
            raise RuntimeError(
                'While saving the parameter named {} (size {}) to {}.npy: {}'
                .format(name, param.size(), name, err)
            ) from err
示例2: load_params
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def load_params(new_model, pretrained_model):
    """Copy matching tensors from a checkpoint into ``new_model`` in place.

    Every entry of ``checkpoint['state_dict']`` whose name also appears in
    ``new_model.state_dict()`` is copied with ``Tensor.copy_``. Entries with
    no counterpart in the model are skipped silently.

    Args:
        new_model: Module receiving the weights. If it is wrapped (e.g. in
            ``DataParallel``) its state-dict keys carry a ``module.`` prefix,
            so you may need to pass ``new_model.module`` instead — TODO
            confirm against the checkpoints in use.
        pretrained_model: Checkpoint location handed to ``load_checkpoint``.
            The result must contain a ``'state_dict'`` mapping.

    Raises:
        RuntimeError: if a tensor cannot be copied (typically a shape
            mismatch); the underlying exception is chained as ``__cause__``.
    """
    new_model_dict = new_model.state_dict()
    pretrained_checkpoint = load_checkpoint(pretrained_model)
    for name, param in pretrained_checkpoint['state_dict'].items():
        print('pretrained_model params name and size: ', name, param.size())
        if name not in new_model_dict:
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        try:
            # state_dict() tensors share storage with the module's
            # parameters/buffers, so copy_ updates the model itself.
            new_model_dict[name].copy_(param)
            print('############# new_model load params name: ', name)
        except Exception as err:
            raise RuntimeError(
                'While copying the parameter named {}, '
                'whose dimensions in the model are {} and '
                'whose dimensions in the checkpoint are {}.'
                .format(name, new_model_dict[name].size(), param.size())
            ) from err
示例3: main
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def main(args):
    """Synthesize a waveform from a saved spectrogram with a CNNVocoder.

    Args:
        args: Parsed command-line options. Reads ``model_path`` (checkpoint
            for ``load_checkpoint``), ``spec_path`` (a ``.npy`` file loaded
            with ``np.load``) and ``out_path`` (destination for
            ``audio.save_wav``).
    """
    model = CNNVocoder(
        n_heads=hparams.n_heads,
        layer_channels=hparams.layer_channels,
        pre_conv_channels=hparams.pre_conv_channels,
        pre_residuals=hparams.pre_residuals,
        up_residuals=hparams.up_residuals,
        post_residuals=hparams.post_residuals
    )
    model = model.cuda()
    model, _, _, _ = load_checkpoint(args.model_path, model)
    # Inference only: switch train-mode layers (dropout/batch-norm, if any)
    # to evaluation behaviour.
    model.eval()

    spec = np.load(args.spec_path)
    # Add a leading batch dimension of 1.
    spec = torch.FloatTensor(spec).unsqueeze(0).cuda()

    with torch.no_grad():
        t1 = time()
        _, wav = model(spec)
        # CUDA kernels run asynchronously; wait for them so dt measures the
        # synthesis itself rather than only the kernel launches.
        torch.cuda.synchronize()
        dt = time() - t1
    print('Synthesized audio in {}s'.format(dt))

    # Take the first (only) item of the batch.
    wav = wav.data.cpu()[0].numpy()
    audio.save_wav(wav, args.out_path)
示例4: train_and_eval
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def train_and_eval(net, train_loader, val_loader, optimizer, loss_fn, metrics, params, model_dir, restore=None):
    """Run the train/validate loop for ``params.num_epochs`` epochs.

    After every epoch the checkpoint (epoch index, model and optimizer state)
    is handed to ``utils.save_checkpoint`` with ``ckpt_dir=model_dir``. It is
    flagged as best when validation accuracy ties or beats the best seen so
    far. The metrics of the best epoch are written to
    ``best_model_params.json`` and those of the latest epoch to
    ``last_acc_metrics.json``, both inside ``model_dir``.

    Args:
        net: The model.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        optimizer: Optimizer for ``net``; its state is restored/saved too.
        loss_fn: Loss forwarded to ``train`` and ``evaluate``.
        metrics: Metrics forwarded to ``train`` and ``evaluate``.
        params: Parameters parsed from a JSON file; ``num_epochs`` is read.
        model_dir: Directory for checkpoints and metric JSON files.
        restore: Only compared against ``None``. When set, the checkpoint
            path comes from the module-level ``args.param_path`` and
            ``args.resume_path``, not from this value.
            NOTE(review): relies on a global ``args`` existing at call time.
    """
    best_val_acc = 0.0
    if restore is not None:
        ckpt_path = os.path.join(args.param_path, args.resume_path + '_pth.tar')
        logging.info("Loaded checkpoints from:{}".format(ckpt_path))
        utils.load_checkpoint(ckpt_path, net, optimizer)

    for epoch_idx in range(params.num_epochs):
        logging.info("Running epoch: {}/{}".format(epoch_idx + 1, params.num_epochs))

        # One pass over the training data, then score on the validation set.
        train(net, train_loader, loss_fn, params, metrics, optimizer)
        val_metrics = evaluate(net, val_loader, loss_fn, params, metrics)
        val_acc = val_metrics['accuracy']
        new_best = val_acc >= best_val_acc

        checkpoint_state = {
            "epoch": epoch_idx,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        utils.save_checkpoint(checkpoint_state, isBest=new_best, ckpt_dir=model_dir)

        if new_best:
            logging.info("New best accuracy found!")
            best_val_acc = val_acc
            utils.save_dict_to_json(val_metrics, os.path.join(model_dir, "best_model_params.json"))

        # Always keep the latest epoch's metrics as well.
        utils.save_dict_to_json(val_metrics, os.path.join(model_dir, 'last_acc_metrics.json'))
示例5: __init__
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def __init__(self, args):
    """Build the CycleGAN networks, losses and optimizers, then try to resume.

    Creates generators ``Gab``/``Gba`` and discriminators ``Da``/``Db`` for
    3-channel images, MSE and L1 criteria, Adam optimizers with
    ``betas=(0.5, 0.999)``, and LambdaLR schedulers driven by
    ``utils.LambdaLR(args.epochs, 0, args.decay_epoch).step``.

    If ``<args.checkpoint_dir>/latest.ckpt`` loads successfully, the epoch,
    network states and optimizer states are restored from it. Otherwise
    ``start_epoch`` is 0.

    Args:
        args: Parsed options. Reads ``ngf``, ``ndf``, ``gen_net``,
            ``dis_net``, ``norm``, ``no_dropout``, ``gpu_ids``, ``lr``,
            ``epochs``, ``decay_epoch`` and ``checkpoint_dir``.
    """
    # Define the network
    #####################################################
    self.Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
                          use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    self.Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
                          use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    self.Da = define_Dis(input_nc=3, ndf=args.ndf, netD=args.dis_net, n_layers_D=3, norm=args.norm,
                         gpu_ids=args.gpu_ids)
    self.Db = define_Dis(input_nc=3, ndf=args.ndf, netD=args.dis_net, n_layers_D=3, norm=args.norm,
                         gpu_ids=args.gpu_ids)
    utils.print_networks([self.Gab, self.Gba, self.Da, self.Db], ['Gab', 'Gba', 'Da', 'Db'])

    # Define Loss criterias
    self.MSE = nn.MSELoss()
    self.L1 = nn.L1Loss()

    # Optimizers: one for both generators, one for both discriminators
    #####################################################
    self.g_optimizer = torch.optim.Adam(itertools.chain(self.Gab.parameters(), self.Gba.parameters()),
                                        lr=args.lr, betas=(0.5, 0.999))
    self.d_optimizer = torch.optim.Adam(itertools.chain(self.Da.parameters(), self.Db.parameters()),
                                        lr=args.lr, betas=(0.5, 0.999))
    self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
    self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
    # NOTE(review): the LR schedulers are not restored from the checkpoint
    # below — confirm the training loop re-aligns them with start_epoch.

    # Try loading checkpoint
    #####################################################
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        self.start_epoch = ckpt['epoch']
        self.Da.load_state_dict(ckpt['Da'])
        self.Db.load_state_dict(ckpt['Db'])
        self.Gab.load_state_dict(ckpt['Gab'])
        self.Gba.load_state_dict(ckpt['Gba'])
        self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
        self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
    except Exception:
        # Best effort: any failure to read or apply the checkpoint means a
        # fresh start. Catching Exception (not a bare except) lets Ctrl-C and
        # SystemExit propagate instead of silently restarting training.
        # NOTE: if loading fails part-way, networks loaded before the failure
        # keep the checkpoint weights.
        print(' [*] No checkpoint!')
        self.start_epoch = 0
示例6: test
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def test(args):
    """Run both CycleGAN generators on one test batch and save an image grid.

    One (shuffled) batch is drawn from each of ``testA`` and ``testB``. The
    images are translated in both directions and reconstructed. The grid
    [a_real, b_fake, a_recon, b_real, a_fake, b_recon] is written to
    ``<args.results_dir>/sample.jpg`` with three images per row.

    If no checkpoint can be loaded from ``<args.checkpoint_dir>/latest.ckpt``,
    a message is printed and the untrained generators are used.

    Args:
        args: Parsed options. Reads ``crop_height``, ``crop_width``,
            ``dataset_dir``, ``batch_size``, ``ngf``, ``norm``,
            ``no_dropout``, ``gpu_ids``, ``checkpoint_dir`` and
            ``results_dir``.
    """
    transform = transforms.Compose(
        [transforms.Resize((args.crop_height, args.crop_width)),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    dataset_dirs = utils.get_testdata_link(args.dataset_dir)
    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)

    Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm,
                     use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm,
                     use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception:
        # Best effort: continue with untrained generators. Catching Exception
        # (not a bare except) lets Ctrl-C / SystemExit propagate.
        print(' [*] No checkpoint!')

    # run
    # DataLoader iterators implement the Python 3 protocol (__next__), not the
    # Python 2 style .next() method, so use the next() builtin. [0] selects
    # the image tensor from ImageFolder's (images, labels) batch.
    a_real_test = next(iter(a_test_loader))[0]
    b_real_test = next(iter(b_test_loader))[0]
    a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

    Gab.eval()
    Gba.eval()
    with torch.no_grad():
        a_fake_test = Gab(b_real_test)
        b_fake_test = Gba(a_real_test)
        a_recon_test = Gab(b_fake_test)
        b_recon_test = Gba(a_fake_test)

    # (x + 1) / 2 undoes the Normalize(mean=0.5, std=0.5) applied to inputs.
    pic = (torch.cat([a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0

    os.makedirs(args.results_dir, exist_ok=True)
    torchvision.utils.save_image(pic, os.path.join(args.results_dir, 'sample.jpg'), nrow=3)
示例7: train_and_evaluate
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def train_and_evaluate(model, train_data, val_data, optimizer, scheduler, params, model_dir, restore_file=None):
    """Train the model and evaluate every epoch, with F1-based early stopping.

    Each epoch trains on ``train_data``, evaluates on the training and
    validation sets, and saves a checkpoint to ``model_dir`` (marked best
    when validation F1 improves). An epoch counts toward patience when F1
    does not improve, or improves by less than ``params.patience``. Training
    stops once ``params.patience_num`` such epochs accumulate after
    ``params.min_epoch_num``, or at the last epoch.

    Args:
        model: Network to train. A ``.module`` attribute (e.g. from a
            DataParallel wrapper) is unwrapped before saving.
        train_data: Training data, consumed via ``data_loader.data_iterator``.
        val_data: Validation data, consumed the same way.
        optimizer: Optimizer. When the module-level ``args.fp16`` is set, its
            wrapped ``.optimizer`` is the one whose state is saved.
        scheduler: Learning-rate scheduler forwarded to ``train``.
        params: Hyper-parameters. Reads ``epoch_num``, ``train_size``,
            ``val_size``, ``batch_size``, ``patience``, ``patience_num`` and
            ``min_epoch_num``. Sets ``train_steps``, ``val_steps`` and
            ``eval_steps``.
        model_dir: Directory holding checkpoints and the restore file.
        restore_file: Optional checkpoint name (without the ``.pth.tar``
            suffix) inside ``model_dir`` to resume from.
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        # Build the path from this function's own arguments instead of the
        # module-level ``args`` so the caller controls what is restored.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_f1 = 0.0
    patience_counter = 0

    # Number of batches in one epoch; identical for every epoch.
    params.train_steps = params.train_size // params.batch_size
    params.val_steps = params.val_size // params.batch_size

    for epoch in range(1, params.epoch_num + 1):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch, params.epoch_num))

        # Train for one epoch on a shuffled pass over the training set
        train_data_iterator = data_loader.data_iterator(train_data, shuffle=True)
        train(model, train_data_iterator, optimizer, scheduler, params)

        # Fresh, unshuffled iterators for evaluation
        train_data_iterator = data_loader.data_iterator(train_data, shuffle=False)
        val_data_iterator = data_loader.data_iterator(val_data, shuffle=False)

        # Evaluate on training and validation sets; evaluate() reads
        # params.eval_steps to know how many batches to consume.
        params.eval_steps = params.train_steps
        evaluate(model, train_data_iterator, params, mark='Train')
        params.eval_steps = params.val_steps
        val_metrics = evaluate(model, val_data_iterator, params, mark='Val')

        val_f1 = val_metrics['f1']
        improve_f1 = val_f1 - best_val_f1

        # Save weights of the network (only the model itself, not a wrapper)
        model_to_save = model.module if hasattr(model, 'module') else model
        optimizer_to_save = optimizer.optimizer if args.fp16 else optimizer
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model_to_save.state_dict(),
                               'optim_dict': optimizer_to_save.state_dict()},
                              is_best=improve_f1 > 0,
                              checkpoint=model_dir)

        if improve_f1 > 0:
            logging.info("- Found new best F1")
            best_val_f1 = val_f1
            # An improvement smaller than params.patience still counts
            # toward early stopping.
            if improve_f1 < params.patience:
                patience_counter += 1
            else:
                patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping and logging best f1
        if (patience_counter >= params.patience_num and epoch > params.min_epoch_num) or epoch == params.epoch_num:
            logging.info("Best val f1: {:05.2f}".format(best_val_f1))
            break
示例8: main
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def main():
    """Train a captioning model, then test and re-save the best-CIDEr epoch.

    The config module named by ``args.config`` supplies ``TrainConfig``
    (referred to as ``C``). Every epoch trains, validates, and logs to
    TensorBoard. A checkpoint is saved every ``C.save_every`` epochs from
    ``C.save_from`` on, and additionally whenever a new best validation CIDEr
    is reached, so the best epoch is always on disk. After training, the
    best checkpoint is reloaded, scored on the test set, and saved under the
    ``"best"`` name.

    Raises:
        ValueError: if ``C.epochs`` is less than 1 (no epoch to pick).
    """
    args = parse_args()
    C = importlib.import_module(args.config).TrainConfig
    print("MODEL ID: {}".format(C.model_id))

    summary_writer = SummaryWriter(C.log_dpath)
    train_iter, val_iter, test_iter, vocab = build_loaders(C)
    model = build_model(C, vocab)
    optimizer = torch.optim.Adam(model.parameters(), lr=C.lr, weight_decay=C.weight_decay, amsgrad=True)
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=C.lr_decay_gamma,
                                     patience=C.lr_decay_patience, verbose=True)

    best_val_scores = {'CIDEr': 0.}
    best_epoch = 0
    best_ckpt_fpath = None
    for e in range(1, C.epochs + 1):
        ckpt_fpath = C.ckpt_fpath_tpl.format(e)

        # Train
        print("\n")
        train_loss = train(e, model, optimizer, train_iter, vocab, C.decoder.rnn_teacher_forcing_ratio,
                           C.reg_lambda, C.recon_lambda, C.gradient_clip)
        log_train(C, summary_writer, e, train_loss, get_lr(optimizer))

        # Validation
        val_loss = test(model, val_iter, vocab, C.reg_lambda, C.recon_lambda)
        val_scores = evaluate(val_iter, model, model.vocab)
        log_val(C, summary_writer, e, val_loss, val_scores)

        saved_this_epoch = e >= C.save_from and e % C.save_every == 0
        if saved_this_epoch:
            print("Saving checkpoint at epoch={} to {}".format(e, ckpt_fpath))
            save_checkpoint(e, model, ckpt_fpath, C)

        if e >= C.lr_decay_start_from:
            lr_scheduler.step(val_loss['total'])

        if e == 1 or val_scores['CIDEr'] > best_val_scores['CIDEr']:
            best_epoch = e
            best_val_scores = val_scores
            best_ckpt_fpath = ckpt_fpath
            if not saved_this_epoch:
                # The periodic save above may have skipped this epoch, but the
                # best checkpoint is reloaded for testing, so it must exist.
                print("Saving checkpoint at epoch={} to {}".format(e, ckpt_fpath))
                save_checkpoint(e, model, ckpt_fpath, C)

    if best_ckpt_fpath is None:
        raise ValueError("C.epochs must be >= 1; no checkpoint was produced")

    # Test with Best Model
    print("\n\n\n[BEST]")
    best_model = load_checkpoint(model, best_ckpt_fpath)
    test_scores = evaluate(test_iter, best_model, best_model.vocab)
    log_test(C, summary_writer, best_epoch, test_scores)
    save_checkpoint(best_epoch, best_model, C.ckpt_fpath_tpl.format("best"), C)
示例9: __init__
# 需要導入模塊: import utils [as 別名]
# 或者: from utils import load_checkpoint [as 別名]
def __init__(self, args):
    """Set up the supervised segmentation network, losses, optimizer and resume state.

    ``Gsi`` maps 3-channel images to ``n_channels`` output maps. For every
    dataset except ``'acdc'``, tensors from the module-level
    ``pretrained_loc`` whose names and sizes match are copied into ``Gsi``.
    If ``<args.checkpoint_dir>/latest_supervised_model.ckpt`` loads, the
    epoch, network/optimizer state and best IoU are restored. Otherwise
    training starts at epoch 0 with ``best_iou = -100``.

    Args:
        args: Parsed options. Reads ``dataset``, ``ngf``, ``norm``,
            ``no_dropout``, ``gpu_ids``, ``crop_height``, ``crop_width``,
            ``lr`` and ``checkpoint_dir``.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``'voc2012'``,
            ``'cityscapes'`` or ``'acdc'``.
    """
    # Output channels of Gsi per dataset; also passed to runningScore below
    # (presumably the number of segmentation classes).
    channels_by_dataset = {'voc2012': 21, 'cityscapes': 20, 'acdc': 4}
    if args.dataset not in channels_by_dataset:
        raise ValueError('Unsupported dataset {!r}; expected one of {}'.format(
            args.dataset, sorted(channels_by_dataset)))
    self.n_channels = channels_by_dataset[args.dataset]

    # Define the network: image -> segmentation
    self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab', norm=args.norm,
                          use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)

    # Pretrained weights are only used for VOC and Cityscapes.
    if args.dataset != 'acdc':
        saved_state_dict = torch.load(pretrained_loc)
        # state_dict() returns tensors sharing storage with Gsi's parameters
        # and buffers, so copy_ writes the pretrained values straight into
        # the network and no load_state_dict call is needed. Entries missing
        # from the checkpoint, or of a different size, keep their init values.
        for name, param in self.Gsi.state_dict().items():
            if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                param.copy_(saved_state_dict[name])
    utils.print_networks([self.Gsi], ['Gsi'])

    # Interpolation to match the network output to the target map size
    self.interp = nn.Upsample(size=(args.crop_height, args.crop_width), mode='bilinear', align_corners=True)
    self.interp_val = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True)

    self.CE = nn.CrossEntropyLoss()
    self.activation_softmax = nn.Softmax2d()
    self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(), lr=args.lr, betas=(0.9, 0.999))

    # writer for tensorboard
    self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
    self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)
    self.args = args

    os.makedirs(args.checkpoint_dir, exist_ok=True)
    try:
        ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
        self.start_epoch = ckpt['epoch']
        self.Gsi.load_state_dict(ckpt['Gsi'])
        self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
        self.best_iou = ckpt['best_iou']
    except Exception:
        # Best effort: any failure to read or apply the checkpoint means a
        # fresh start. Catching Exception (not a bare except) lets Ctrl-C and
        # SystemExit propagate instead of silently restarting training.
        print(' [*] No checkpoint!')
        self.start_epoch = 0
        self.best_iou = -100