

Python Trainer.register_plugin Method Code Examples

This article collects typical usage examples of the Python method trainer.Trainer.register_plugin. If you are unsure what Trainer.register_plugin does or how to call it, the curated example below should help; you can also explore other usage examples for the trainer.Trainer class. A minimal sketch of the underlying plugin pattern follows.
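Before the full example, here is a minimal sketch of the plugin pattern this Trainer follows. It is an illustration only, assuming an interface in the style of the legacy torch.utils.trainer design that pggan-pytorch's Trainer is modeled on; MiniTrainer, PrintLossPlugin, and the 'iteration' hook name are hypothetical, and the real project's hook names may differ.

class MiniTrainer:
    def __init__(self):
        # Queues of plugins, keyed by the event they subscribe to.
        self.plugin_queues = {'iteration': [], 'epoch': []}

    def register_plugin(self, plugin):
        # Hand the trainer to the plugin; the plugin attaches itself
        # to whichever event queues it wants.
        plugin.register(self)

    def call_plugins(self, queue_name, *args):
        for plugin in self.plugin_queues[queue_name]:
            plugin.call(*args)

class PrintLossPlugin:
    def register(self, trainer):
        trainer.plugin_queues['iteration'].append(self)

    def call(self, iteration, loss):
        print('iteration {}: loss {:.4f}'.format(iteration, loss))

trainer = MiniTrainer()
trainer.register_plugin(PrintLossPlugin())
trainer.call_plugins('iteration', 1, 0.25)  # prints "iteration 1: loss 0.2500"

In the real example below, every component of the training loop (the progressive-growing schedule, loss monitors, checkpointing, sample generation, timing, LR scheduling, even the logger) is wired in through the same register_plugin call.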


The following shows 1 code example of the Trainer.register_plugin method, drawn from an open-source project.

Example 1: main

# Required import: from trainer import Trainer  [as alias]
# Alternative: from trainer.Trainer import register_plugin  [as alias]
# Standard-library and third-party imports needed by this snippet:
import os
from functools import partial, reduce
import numpy as np
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
# The remaining names (Generator, Discriminator, the plugins, the WGAN-GP
# losses, and helpers such as load_pkl, create_result_subdir, random_latents,
# InfiniteRandomSampler, load_models and init_comet) come from pggan-pytorch's
# own modules; their exact import paths are project-specific.
def main(params):
    if params['load_dataset']:
        dataset = load_pkl(params['load_dataset'])
    elif params['dataset_class']:
        dataset = globals()[params['dataset_class']](**params[params['dataset_class']])
        if params['save_dataset']:
            save_pkl(params['save_dataset'], dataset)
    else:
        raise Exception('Either load_dataset (a path to a pkl) or dataset_class must be specified.')
    result_dir = create_result_subdir(params['result_dir'], params['exp_name'])

    losses = ['G_loss', 'D_loss', 'D_real', 'D_fake']
    stats_to_log = [
        'tick_stat',
        'kimg_stat',
    ]
    if params['progressive_growing']:
        stats_to_log.extend([
            'depth',
            'alpha',
            'lod',
            'minibatch_size'
        ])
    stats_to_log.extend([
        'time',
        'sec.tick',
        'sec.kimg'
    ] + losses)
    logger = TeeLogger(os.path.join(result_dir, 'log.txt'), stats_to_log, [(1, 'epoch')])
    logger.log(params_to_str(params))
    if params['resume_network']:
        G, D = load_models(params['resume_network'], params['result_dir'], logger)
    else:
        G = Generator(dataset.shape, **params['Generator'])
        D = Discriminator(dataset.shape, **params['Discriminator'])
    if params['progressive_growing']:
        assert G.max_depth == D.max_depth
    G.cuda()
    D.cuda()
    latent_size = params['Generator']['latent_size']

    logger.log(str(G))
    # Parameter count: product of each tensor's dimensions, summed over all tensors.
    logger.log('Total number of parameters in Generator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a*b, x.size()), G.parameters()))
    ))
    logger.log(str(D))
    logger.log('Total number of parameters in Discriminator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a*b, x.size()), D.parameters()))
    ))

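    # Dataloader factory rather than a fixed loader: with progressive growing
    # the minibatch size changes with depth, so plugins (see DepthManager below)
    # can rebuild the loader on demand.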
    def get_dataloader(minibatch_size):
        return DataLoader(dataset, minibatch_size, sampler=InfiniteRandomSampler(dataset),
                          num_workers=params['num_data_workers'], pin_memory=False, drop_last=True)

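    # Factory for a zero-argument latent sampler: rl(bs)() draws a fresh batch
    # of bs random latent vectors each time it is called.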
    def rl(bs):
        return lambda: random_latents(bs, latent_size)

    # Setting up learning rate and optimizers
    opt_g = Adam(G.parameters(), params['G_lr_max'], **params['Adam'])
    opt_d = Adam(D.parameters(), params['D_lr_max'], **params['Adam'])

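    # LR ramp-up: the multiplier starts near exp(-5) ≈ 0.007 at cur_nimg = 0 and
    # rises smoothly to 1.0 over the first lr_rampup_kimg thousand images.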
    def rampup(cur_nimg):
        if cur_nimg < params['lr_rampup_kimg'] * 1000:
            p = max(0.0, 1 - cur_nimg / (params['lr_rampup_kimg'] * 1000))
            return np.exp(-p * p * 5.0)
        else:
            return 1.0
    lr_scheduler_d = LambdaLR(opt_d, rampup)
    lr_scheduler_g = LambdaLR(opt_g, rampup)

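    # Build the Trainer with WGAN-GP losses; everything else is attached as
    # plugins via register_plugin below.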
    mb_def = params['minibatch_size']
    D_loss_fun = partial(wgan_gp_D_loss, return_all=True, iwass_lambda=params['iwass_lambda'],
                         iwass_epsilon=params['iwass_epsilon'], iwass_target=params['iwass_target'])
    G_loss_fun = wgan_gp_G_loss
    trainer = Trainer(D, G, D_loss_fun, G_loss_fun,
                      opt_d, opt_g, dataset, iter(get_dataloader(mb_def)), rl(mb_def), **params['Trainer'])
    # plugins
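    # DepthManager drives the progressive-growing schedule (depth/alpha/LOD
    # transitions and per-depth minibatch sizes), rebuilding the dataloader
    # and latent sampler as resolution grows.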
    if params['progressive_growing']:
        max_depth = min(G.max_depth, D.max_depth)
        trainer.register_plugin(DepthManager(get_dataloader, rl, max_depth, **params['DepthManager']))
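    # One monitor per tracked loss; the index i ties each monitor to the
    # corresponding output of the loss functions.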
    for i, loss_name in enumerate(losses):
        trainer.register_plugin(EfficientLossMonitor(i, loss_name))

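    # Checkpointing of G and D to disk.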
    checkpoints_dir = params['checkpoints_dir'] if params['checkpoints_dir'] else result_dir
    trainer.register_plugin(SaverPlugin(checkpoints_dir, **params['SaverPlugin']))

    # Resolve each postprocessor's samples_path relative to the result directory.
    def substitute_samples_path(d):
        return {k: (os.path.join(result_dir, v) if k == 'samples_path' else v) for k, v in d.items()}
    postprocessors = [globals()[x](**substitute_samples_path(params[x])) for x in params['postprocessors']]
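    # Periodically generate samples from random latents and run them through
    # the configured postprocessors.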
    trainer.register_plugin(OutputGenerator(lambda x: random_latents(x, latent_size),
                                            postprocessors, **params['OutputGenerator']))
    trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
    trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
    trainer.register_plugin(logger)
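    # Experiment tracking via init_comet (presumably Comet.ml), then train for
    # total_kimg thousand images and release the dataset.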
    init_comet(params, trainer)
    trainer.run(params['total_kimg'])
    dataset.close()
Developer: codealphago · Project: pggan-pytorch · Source file: train.py

