本文整理汇总了 Python 中 trainer.Trainer.register_plugin 方法的典型用法代码示例。如果您正苦于以下问题：Python Trainer.register_plugin 方法的具体用法？Python Trainer.register_plugin 怎么用？Python Trainer.register_plugin 使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类 trainer.Trainer 的用法示例。
在下文中一共展示了Trainer.register_plugin方法的1个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: main
# 需要导入模块: from trainer import Trainer [as 别名]
# 或者: from trainer.Trainer import register_plugin [as 别名]
def main(params):
    """Assemble dataset, networks, optimizers and plugins from ``params`` and train.

    ``params`` is a dict-like configuration. The keys read here are:
    ``load_dataset``, ``dataset_class`` (plus a sub-dict stored under the
    class name itself), ``save_dataset``, ``result_dir``, ``exp_name``,
    ``progressive_growing``, ``resume_network``, ``Generator``,
    ``Discriminator``, ``num_data_workers``, ``G_lr_max``, ``D_lr_max``,
    ``Adam``, ``lr_rampup_kimg``, ``minibatch_size``, ``iwass_lambda``,
    ``iwass_epsilon``, ``iwass_target``, ``Trainer``, ``DepthManager``,
    ``checkpoints_dir``, ``SaverPlugin``, ``postprocessors`` (plus one
    sub-dict per listed postprocessor name), ``OutputGenerator``,
    ``resume_time`` and ``total_kimg``.

    Raises:
        Exception: if neither ``load_dataset`` nor ``dataset_class`` is set.
    """
    dataset = _load_dataset(params)
    # Everything after the dataset exists runs under try/finally so that the
    # dataset is closed even when setup or training raises.
    try:
        result_dir = create_result_subdir(params['result_dir'], params['exp_name'])

        losses = ['G_loss', 'D_loss', 'D_real', 'D_fake']
        stats_to_log = ['tick_stat', 'kimg_stat']
        if params['progressive_growing']:
            stats_to_log.extend(['depth', 'alpha', 'lod', 'minibatch_size'])
        stats_to_log.extend(['time', 'sec.tick', 'sec.kimg'] + losses)

        logger = TeeLogger(os.path.join(result_dir, 'log.txt'), stats_to_log, [(1, 'epoch')])
        logger.log(params_to_str(params))

        if params['resume_network']:
            G, D = load_models(params['resume_network'], params['result_dir'], logger)
        else:
            G = Generator(dataset.shape, **params['Generator'])
            D = Discriminator(dataset.shape, **params['Discriminator'])
        if params['progressive_growing']:
            assert G.max_depth == D.max_depth
        G.cuda()
        D.cuda()
        latent_size = params['Generator']['latent_size']

        logger.log(str(G))
        logger.log('Total nuber of parameters in Generator: {}'.format(_count_parameters(G)))
        logger.log(str(D))
        logger.log('Total nuber of parameters in Discriminator: {}'.format(_count_parameters(D)))

        def get_dataloader(minibatch_size):
            """Return a DataLoader over ``dataset`` with the given minibatch size."""
            return DataLoader(dataset, minibatch_size, sampler=InfiniteRandomSampler(dataset),
                              num_workers=params['num_data_workers'], pin_memory=False,
                              drop_last=True)

        def rl(bs):
            """Return a zero-argument callable producing ``bs`` random latents."""
            return lambda: random_latents(bs, latent_size)

        # Setting up learning rate and optimizers
        opt_g = Adam(G.parameters(), params['G_lr_max'], **params['Adam'])
        opt_d = Adam(D.parameters(), params['D_lr_max'], **params['Adam'])

        def rampup(cur_nimg):
            """Return the learning-rate multiplier for ``cur_nimg``; 1.0 once ramp-up is over."""
            rampup_nimg = params['lr_rampup_kimg'] * 1000
            if cur_nimg < rampup_nimg:
                # float() forces true division; with two ints and integer
                # division semantics the ratio would collapse to 0 and the
                # multiplier would stay at exp(-5) for the whole ramp-up.
                p = max(0.0, 1 - float(cur_nimg) / rampup_nimg)
                return np.exp(-p * p * 5.0)
            return 1.0

        lr_scheduler_d = LambdaLR(opt_d, rampup)
        lr_scheduler_g = LambdaLR(opt_g, rampup)

        mb_def = params['minibatch_size']
        D_loss_fun = partial(wgan_gp_D_loss, return_all=True,
                             iwass_lambda=params['iwass_lambda'],
                             iwass_epsilon=params['iwass_epsilon'],
                             iwass_target=params['iwass_target'])
        G_loss_fun = wgan_gp_G_loss
        trainer = Trainer(D, G, D_loss_fun, G_loss_fun,
                          opt_d, opt_g, dataset, iter(get_dataloader(mb_def)), rl(mb_def),
                          **params['Trainer'])

        # plugins
        if params['progressive_growing']:
            max_depth = min(G.max_depth, D.max_depth)
            trainer.register_plugin(
                DepthManager(get_dataloader, rl, max_depth, **params['DepthManager']))
        for i, loss_name in enumerate(losses):
            trainer.register_plugin(EfficientLossMonitor(i, loss_name))

        # Checkpoints go to the result directory unless a dedicated one is configured.
        checkpoints_dir = params['checkpoints_dir'] if params['checkpoints_dir'] else result_dir
        trainer.register_plugin(SaverPlugin(checkpoints_dir, **params['SaverPlugin']))

        def substitute_samples_path(d):
            """Return a copy of ``d`` with ``samples_path`` joined onto ``result_dir``."""
            return {k: (os.path.join(result_dir, v) if k == 'samples_path' else v)
                    for k, v in d.items()}

        # Postprocessor classes are looked up by name among this module's globals.
        postprocessors = [globals()[x](**substitute_samples_path(params[x]))
                          for x in params['postprocessors']]
        trainer.register_plugin(OutputGenerator(lambda x: random_latents(x, latent_size),
                                                postprocessors, **params['OutputGenerator']))
        trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
        trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
        trainer.register_plugin(logger)
        init_comet(params, trainer)
        trainer.run(params['total_kimg'])
    finally:
        dataset.close()


def _load_dataset(params):
    """Return the dataset loaded from a pickle or built from ``params['dataset_class']``.

    A freshly built dataset is also saved when ``params['save_dataset']`` is set.

    Raises:
        Exception: if neither ``load_dataset`` nor ``dataset_class`` is set.
    """
    if params['load_dataset']:
        return load_pkl(params['load_dataset'])
    if params['dataset_class']:
        # The class is resolved by name among this module's globals; its
        # keyword arguments live in params under that same name.
        dataset = globals()[params['dataset_class']](**params[params['dataset_class']])
        if params['save_dataset']:
            save_pkl(params['save_dataset'], dataset)
        return dataset
    raise Exception('One of either load_dataset (path to pkl) or dataset_class needs to be specified.')


def _count_parameters(net):
    """Return the total number of elements across ``net.parameters()``."""
    # The initial value 1 makes a parameter with an empty size() count as one
    # element instead of making reduce() raise TypeError on an empty sequence.
    return sum(reduce(lambda a, b: a * b, p.size(), 1) for p in net.parameters())