当前位置: 首页>>代码示例>>Python>>正文


Python utils.load_checkpoint方法代码示例

本文整理汇总了Python中utils.load_checkpoint方法的典型用法代码示例。如果您正苦于以下问题:Python utils.load_checkpoint方法的具体用法?Python utils.load_checkpoint怎么用?Python utils.load_checkpoint使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在utils的用法示例。


在下文中一共展示了utils.load_checkpoint方法的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_params

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def get_params(pretrained_model):
    """Dump every entry of a pretrained checkpoint to its own ``.npy`` file.

    The object returned by ``load_checkpoint(pretrained_model)`` is iterated
    as a flat ``name -> tensor`` mapping; each tensor is moved to the CPU,
    converted to a NumPy array and written to ``<name>.npy`` (relative to the
    current working directory).

    Args:
        pretrained_model: Forwarded unchanged to ``load_checkpoint``
            (presumably a checkpoint path -- confirm against ``utils``).

    Raises:
        RuntimeError: If a parameter cannot be converted or written. The
            underlying error is chained as ``__cause__``.
    """
    pretrained_checkpoint = load_checkpoint(pretrained_model)
    # NOTE(review): checkpoints that nest their weights under a 'state_dict'
    # key would need pretrained_checkpoint['state_dict'].items() here.
    for name, param in pretrained_checkpoint.items():
        print('pretrained_model params name and size: ', name, param.size())
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        out_path = name + '.npy'
        try:
            np.save(out_path, param.cpu().numpy())
        except Exception as err:
            # The previous handler formatted new_model_dict[name], a name that
            # does not exist in this function, so every failure was reported
            # as a NameError instead of the real problem.
            raise RuntimeError(
                'While saving the parameter named {}, whose dimensions in '
                'the checkpoint are {}, to {}.'
                .format(name, param.size(), out_path)
            ) from err
        print('############# new_model load params name: ', name)
开发者ID:aliyun,项目名称:alibabacloud-quantization-networks,代码行数:18,代码来源:main.py

示例2: load_params

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def load_params(new_model, pretrained_model):
    """Copy same-named weights from a pretrained checkpoint into ``new_model``.

    Every entry under the checkpoint's ``'state_dict'`` key whose name also
    appears in ``new_model.state_dict()`` is copied in place with
    ``Tensor.copy_``; entries the new model does not have are skipped.

    Args:
        new_model: Model whose ``state_dict()`` tensors receive the weights.
        pretrained_model: Forwarded unchanged to ``load_checkpoint``
            (presumably a checkpoint path -- confirm against ``utils``).

    Raises:
        RuntimeError: If a tensor cannot be copied (for example because the
            shapes differ). The underlying error is chained as ``__cause__``.
    """
    # NOTE(review): the original kept new_model.module.state_dict() as a
    # commented-out alternative, presumably for wrapped models -- confirm.
    new_model_dict = new_model.state_dict()
    pretrained_checkpoint = load_checkpoint(pretrained_model)
    for name, param in pretrained_checkpoint['state_dict'].items():
        print('pretrained_model params name and size: ', name, param.size())
        if name not in new_model_dict:
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        try:
            new_model_dict[name].copy_(param)
        except Exception as err:
            # Catch Exception rather than everything so Ctrl-C is not turned
            # into a RuntimeError, and keep the original error as the cause.
            raise RuntimeError(
                'While copying the parameter named {}, '
                'whose dimensions in the model are {} and '
                'whose dimensions in the checkpoint are {}.'
                .format(name, new_model_dict[name].size(), param.size())
            ) from err
        print('############# new_model load params name: ', name)
开发者ID:aliyun,项目名称:alibabacloud-quantization-networks,代码行数:23,代码来源:quan_all_main.py

示例3: main

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def main(args):
    """Synthesize a waveform from a stored spectrogram and write it to disk.

    Builds a ``CNNVocoder`` from the module-level ``hparams``, moves it to the
    GPU, restores it via ``load_checkpoint(args.model_path, model)``, runs it
    on the array loaded from ``args.spec_path`` and saves the result with
    ``audio.save_wav`` to ``args.out_path``.

    Args:
        args: Object providing ``model_path``, ``spec_path`` and ``out_path``.
    """
    vocoder_config = dict(
        n_heads=hparams.n_heads,
        layer_channels=hparams.layer_channels,
        pre_conv_channels=hparams.pre_conv_channels,
        pre_residuals=hparams.pre_residuals,
        up_residuals=hparams.up_residuals,
        post_residuals=hparams.post_residuals,
    )
    model = CNNVocoder(**vocoder_config).cuda()

    # load_checkpoint hands back four values; only the restored model is used.
    model, _, _, _ = load_checkpoint(args.model_path, model)

    # Add a leading axis of size 1 before moving the input to the GPU.
    spec = torch.FloatTensor(np.load(args.spec_path)).unsqueeze(0).cuda()

    started = time()
    _, wav = model(spec)
    dt = time() - started
    print('Synthesized audio in {}s'.format(dt))

    # Take the first item of the output and convert it to a NumPy array.
    samples = wav.data.cpu()[0].numpy()
    audio.save_wav(samples, args.out_path)
开发者ID:tuan3w,项目名称:cnn_vocoder,代码行数:23,代码来源:synthesis.py

示例4: train_and_eval

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def train_and_eval(net, train_loader, val_loader, optimizer, loss_fn, metrics, params, model_dir, restore=None):
    """
    Train for ``params.num_epochs`` epochs, validating and checkpointing after each.

    net: The model.
    train_loader / val_loader: The data loaders.
    optimizer, loss_fn, metrics: Forwarded to ``train`` / ``evaluate``.
    params: Hyper-parameter object; ``num_epochs`` is read here.
    model_dir: Directory receiving checkpoints and the metric JSON files.
    restore: When not None, a checkpoint is loaded before training starts.
    """
    best_val_acc = 0.0
    if restore is not None:
        # NOTE(review): the path comes from the module-level `args`, not from
        # `restore` itself -- confirm this is intended.
        restore_file = os.path.join(args.param_path, args.resume_path + '_pth.tar')
        logging.info("Loaded checkpoints from:{}".format(restore_file))
        utils.load_checkpoint(restore_file, net, optimizer)

    for epoch_idx in range(params.num_epochs):
        logging.info("Running epoch: {}/{}".format(epoch_idx + 1, params.num_epochs))

        # One pass over the training data, then score the validation data.
        train(net, train_loader, loss_fn, params, metrics, optimizer)
        val_metrics = evaluate(net, val_loader, loss_fn, params, metrics)

        val_acc = val_metrics['accuracy']
        # Ties count as an improvement (>=).
        improved = val_acc >= best_val_acc

        snapshot = {
            "epoch": epoch_idx,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        utils.save_checkpoint(snapshot, isBest=improved, ckpt_dir=model_dir)

        if improved:
            # Record the metrics of the best epoch so far.
            logging.info("New best accuracy found!")
            best_val_acc = val_acc
            utils.save_dict_to_json(
                val_metrics, os.path.join(model_dir, "best_model_params.json"))

        # Always overwrite the metrics of the most recent epoch.
        utils.save_dict_to_json(
            val_metrics, os.path.join(model_dir, 'last_acc_metrics.json'))
开发者ID:aicaffeinelife,项目名称:Pytorch-STN,代码行数:39,代码来源:train.py

示例5: __init__

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def __init__(self,args):
        """Build the generators, discriminators, losses and optimizers, then try to resume.

        Creates two generators (``Gab``, ``Gba``) and two discriminators
        (``Da``, ``Db``), Adam optimizers with ``LambdaLR`` schedulers for each
        pair, and finally attempts to restore ``<checkpoint_dir>/latest.ckpt``.
        When the restore fails for any ordinary reason, training starts from
        epoch 0.

        Args:
            args: Options object; reads ``ngf``, ``ndf``, ``gen_net``,
                ``dis_net``, ``norm``, ``no_dropout``, ``gpu_ids``, ``lr``,
                ``epochs``, ``decay_epoch`` and ``checkpoint_dir``.
        """

        # Define the network 
        #####################################################
        self.Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm, 
                                                    use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
        self.Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm, 
                                                    use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
        self.Da = define_Dis(input_nc=3, ndf=args.ndf, netD= args.dis_net, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids)
        self.Db = define_Dis(input_nc=3, ndf=args.ndf, netD= args.dis_net, n_layers_D=3, norm=args.norm, gpu_ids=args.gpu_ids)

        utils.print_networks([self.Gab,self.Gba,self.Da,self.Db], ['Gab','Gba','Da','Db'])

        # Define Loss criterias

        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        # Optimizers
        #####################################################
        # One optimizer drives both generators, another both discriminators.
        self.g_optimizer = torch.optim.Adam(itertools.chain(self.Gab.parameters(),self.Gba.parameters()), lr=args.lr, betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(self.Da.parameters(),self.Db.parameters()), lr=args.lr, betas=(0.5, 0.999))
        

        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception as err:
            # Still best-effort, but no longer swallows KeyboardInterrupt /
            # SystemExit, and the real cause (missing file, missing key,
            # mismatched weights, ...) is reported instead of hidden.
            print(' [*] No checkpoint!')
            print(' [*] Reason: {!r}'.format(err))
            self.start_epoch = 0
开发者ID:arnab39,项目名称:cycleGAN-PyTorch,代码行数:46,代码来源:model.py

示例6: test

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def test(args):
    """Translate one batch from each test set and save a comparison grid.

    Loads the ``testA`` / ``testB`` image folders, builds both generators,
    tries to restore them from ``<checkpoint_dir>/latest.ckpt``, runs one
    batch from each domain through translation and reconstruction, and writes
    the results to ``<results_dir>/sample.jpg``.

    Args:
        args: Options object; reads ``crop_height``, ``crop_width``,
            ``dataset_dir``, ``batch_size``, ``ngf``, ``norm``,
            ``no_dropout``, ``gpu_ids``, ``checkpoint_dir`` and
            ``results_dir``.
    """

    transform = transforms.Compose(
        [transforms.Resize((args.crop_height,args.crop_width)),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)


    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)

    Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm, 
                                                    use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)
    Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm, 
                                                    use_dropout= not args.no_dropout, gpu_ids=args.gpu_ids)

    utils.print_networks([Gab,Gba], ['Gab','Gba'])

    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception as err:
        # Best-effort restore: execution continues with whatever weights the
        # generators currently hold. Only ordinary errors are caught, and the
        # cause is reported so a corrupt checkpoint is not mistaken for a
        # missing one.
        print(' [*] No checkpoint!')
        print(' [*] Reason: {!r}'.format(err))


    # run
    # next(iter(...)) is the Python 3 iterator protocol; the former
    # iter(...).next() only works on iterators that still expose a Python 2
    # style .next alias.
    a_real_test = Variable(next(iter(a_test_loader))[0], requires_grad=True)
    b_real_test = Variable(next(iter(b_test_loader))[0], requires_grad=True)
    a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])
            

    Gab.eval()
    Gba.eval()

    with torch.no_grad():
        a_fake_test = Gab(b_real_test)
        b_fake_test = Gba(a_real_test)
        a_recon_test = Gab(b_fake_test)
        b_recon_test = Gba(a_fake_test)

    # Inputs were normalized with mean 0.5 / std 0.5 above; (x + 1) / 2 undoes that.
    pic = (torch.cat([a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0

    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)

    torchvision.utils.save_image(pic, args.results_dir+'/sample.jpg', nrow=3)
开发者ID:arnab39,项目名称:cycleGAN-PyTorch,代码行数:54,代码来源:test.py

示例7: train_and_evaluate

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def train_and_evaluate(model, train_data, val_data, optimizer, scheduler, params, model_dir, restore_file=None):
    """Fit the model epoch by epoch, scoring train and validation data after each one.

    A checkpoint is written after every epoch. The loop ends early once the
    patience counter reaches ``params.patience_num`` (and the epoch exceeds
    ``params.min_epoch_num``), or after ``params.epoch_num`` epochs.
    """
    if restore_file is not None:
        # NOTE(review): the path is assembled from the module-level `args`,
        # not from the `restore_file` / `model_dir` arguments -- confirm.
        restore_path = os.path.join(args.model_dir, args.restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_f1 = 0.0
    patience_counter = 0

    for epoch in range(1, params.epoch_num + 1):
        logging.info("Epoch {}/{}".format(epoch, params.epoch_num))

        # Whole batches per epoch; any remainder is dropped by the floor division.
        params.train_steps = params.train_size // params.batch_size
        params.val_steps = params.val_size // params.batch_size

        # Training pass over shuffled data.
        shuffled_train_iter = data_loader.data_iterator(train_data, shuffle=True)
        train(model, shuffled_train_iter, optimizer, scheduler, params)

        # Evaluation passes over unshuffled data.
        train_eval_iter = data_loader.data_iterator(train_data, shuffle=False)
        val_eval_iter = data_loader.data_iterator(val_data, shuffle=False)

        # params.eval_steps is set before each evaluate() call.
        params.eval_steps = params.train_steps
        evaluate(model, train_eval_iter, params, mark='Train')
        params.eval_steps = params.val_steps
        val_metrics = evaluate(model, val_eval_iter, params, mark='Val')

        val_f1 = val_metrics['f1']
        improve_f1 = val_f1 - best_val_f1
        found_better = improve_f1 > 0

        # Save the wrapped object when the model exposes `.module`.
        model_to_save = getattr(model, 'module', model)
        if args.fp16:
            optimizer_to_save = optimizer.optimizer
        else:
            optimizer_to_save = optimizer
        checkpoint_state = {
            'epoch': epoch + 1,
            'state_dict': model_to_save.state_dict(),
            'optim_dict': optimizer_to_save.state_dict(),
        }
        utils.save_checkpoint(checkpoint_state,
                              is_best=found_better,
                              checkpoint=model_dir)

        if found_better:
            logging.info("- Found new best F1")
            best_val_f1 = val_f1

        # Only a gain of at least params.patience resets the counter; a
        # smaller gain, or none at all, counts against patience.
        if found_better and not improve_f1 < params.patience:
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping and logging best f1
        out_of_patience = patience_counter >= params.patience_num and epoch > params.min_epoch_num
        if out_of_patience or epoch == params.epoch_num:
            logging.info("Best val f1: {:05.2f}".format(best_val_f1))
            break
开发者ID:lemonhu,项目名称:NER-BERT-pytorch,代码行数:61,代码来源:train.py

示例8: main

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def main():
    """Train, validate every epoch, then test the best-scoring checkpoint.

    The training configuration is the ``TrainConfig`` attribute of the module
    named by ``args.config``. After each epoch the validation CIDEr score is
    compared with the best seen so far; once training ends, the checkpoint of
    the best epoch is reloaded, scored on the test set and saved again under
    the ``"best"`` name.
    """
    args = parse_args()
    C = importlib.import_module(args.config).TrainConfig
    print("MODEL ID: {}".format(C.model_id))

    summary_writer = SummaryWriter(C.log_dpath)

    train_iter, val_iter, test_iter, vocab = build_loaders(C)

    model = build_model(C, vocab)

    optimizer = torch.optim.Adam(model.parameters(), lr=C.lr, weight_decay=C.weight_decay, amsgrad=True)
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=C.lr_decay_gamma,
                                     patience=C.lr_decay_patience, verbose=True)

    best_val_scores = { 'CIDEr': 0. }
    best_epoch = 0
    best_ckpt_fpath = None
    for e in range(1, C.epochs + 1):
        ckpt_fpath = C.ckpt_fpath_tpl.format(e)

        # Train
        print("\n")
        train_loss = train(e, model, optimizer, train_iter, vocab, C.decoder.rnn_teacher_forcing_ratio,
                           C.reg_lambda, C.recon_lambda, C.gradient_clip)
        log_train(C, summary_writer, e, train_loss, get_lr(optimizer))

        # Validation
        val_loss = test(model, val_iter, vocab, C.reg_lambda, C.recon_lambda)
        val_scores = evaluate(val_iter, model, model.vocab)
        log_val(C, summary_writer, e, val_loss, val_scores)

        saved_this_epoch = False
        if e >= C.save_from and e % C.save_every == 0:
            print("Saving checkpoint at epoch={} to {}".format(e, ckpt_fpath))
            save_checkpoint(e, model, ckpt_fpath, C)
            saved_this_epoch = True

        if e >= C.lr_decay_start_from:
            lr_scheduler.step(val_loss['total'])
        # The first epoch always becomes the initial best.
        if e == 1 or val_scores['CIDEr'] > best_val_scores['CIDEr']:
            if not saved_this_epoch:
                # The best epoch is reloaded from disk after training, so its
                # checkpoint must exist even when the periodic save schedule
                # (save_from / save_every) skipped this epoch.
                print("Saving checkpoint at epoch={} to {}".format(e, ckpt_fpath))
                save_checkpoint(e, model, ckpt_fpath, C)
            best_epoch = e
            best_val_scores = val_scores
            best_ckpt_fpath = ckpt_fpath

    # Test with Best Model
    # NOTE(review): when C.epochs < 1 the loop never runs and best_ckpt_fpath
    # is still None here -- confirm how load_checkpoint handles that.
    print("\n\n\n[BEST]")
    best_model = load_checkpoint(model, best_ckpt_fpath)
    test_scores = evaluate(test_iter, best_model, best_model.vocab)
    log_test(C, summary_writer, best_epoch, test_scores)
    save_checkpoint(best_epoch, best_model, C.ckpt_fpath_tpl.format("best"), C)
开发者ID:hobincar,项目名称:RecNet,代码行数:51,代码来源:train.py

示例9: __init__

# 需要导入模块: import utils [as 别名]
# 或者: from utils import load_checkpoint [as 别名]
def __init__(self, args):
        """Build the segmentation generator, its loss/optimizer, and try to resume.

        Picks the number of output channels from ``args.dataset``, builds the
        ``Gsi`` generator, optionally copies pretrained weights into it, sets
        up the interpolation layers, loss, optimizer, tensorboard writer and
        running metrics, and finally attempts to restore
        ``<checkpoint_dir>/latest_supervised_model.ckpt``. When the restore
        fails for any ordinary reason, training starts from epoch 0 with
        ``best_iou = -100``.

        Args:
            args: Options object; reads ``dataset``, ``ngf``, ``norm``,
                ``no_dropout``, ``gpu_ids``, ``crop_height``, ``crop_width``,
                ``lr`` and ``checkpoint_dir``.
        """

        # NOTE(review): any other dataset name leaves n_channels unset by this
        # method -- confirm whether a class-level default exists.
        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4

        # Define the network 
        self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab', norm=args.norm,
                              use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)  # for image to segmentation

        ### Now we put in the pretrained weights in Gsi
        ### These will only be used in the case of VOC and cityscapes
        if args.dataset != 'acdc':
            # pretrained_loc is not defined in this method; it is expected to
            # exist at module level.
            saved_state_dict = torch.load(pretrained_loc)
            # dict.copy() is shallow, so the values are the same tensor objects
            # that state_dict() returned; copy_ overwrites them in place.
            # Presumably that is why load_state_dict stays disabled below --
            # confirm.
            new_params = self.Gsi.state_dict().copy()
            for name, param in new_params.items():
                # print(name)
                # Only entries present in the file with an identical size are copied.
                if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                    new_params[name].copy_(saved_state_dict[name])
                    # print('copy {}'.format(name))
            # self.Gsi.load_state_dict(new_params)

        utils.print_networks([self.Gsi], ['Gsi'])

        ###Defining an interpolation function so as to match the output of network to feature map size
        self.interp = nn.Upsample(size = (args.crop_height, args.crop_width), mode='bilinear', align_corners=True)
        self.interp_val = nn.Upsample(size = (512, 512), mode='bilinear', align_corners=True)

        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()
        self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(), lr=args.lr, betas=(0.9, 0.999))

        ### writer for tensorboard
        # tensorboard_loc is likewise expected to exist at module level.
        self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
        self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)

        self.args = args

        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception as err:
            # Still best-effort, but no longer swallows KeyboardInterrupt /
            # SystemExit, and the real cause is reported instead of hidden.
            print(' [*] No checkpoint!')
            print(' [*] Reason: {!r}'.format(err))
            self.start_epoch = 0
            self.best_iou = -100
开发者ID:arnab39,项目名称:Semi-supervised-segmentation-cycleGAN,代码行数:56,代码来源:model.py


注:本文中的utils.load_checkpoint方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。