

Python Trainer.run Method Code Examples

This article collects typical usage examples of the Python method trainer.Trainer.run. If you are unsure what Trainer.run does, how to call it, or what real-world code that uses it looks like, the curated examples below should help. You can also explore further usage examples of the enclosing class, trainer.Trainer.


Four code examples of the Trainer.run method are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
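The four examples come from unrelated projects, each shipping its own trainer.Trainer, so the signature of run varies: no arguments at all, a pair of data loaders, a total training budget, or variable lists plus option strings. As a rough orientation only, here is a minimal hypothetical sketch of the pattern they all share; the class body and signature are illustrative and do not correspond to any one project's actual API:

# Hypothetical minimal shape of the Trainer/run pattern (not any project's real API)
class Trainer:
    def __init__(self, model, **config):
        self.model = model      # whatever object is being trained
        self.config = config    # project-specific training options

    def run(self, *args, **kwargs):
        # project-specific training loop: iterate over data, update
        # the model, log metrics, and save checkpoints
        raise NotImplementedError

# typical call site:
# trainer = Trainer(model, checkpoint="ckpt")
# trainer.run()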

Example 1: train

# Required import: from trainer import Trainer [as alias]
# Or: from trainer.Trainer import run [as alias]
def train(args):
    debug = args.debug
    logger.info(
        "Start training in {} mode".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["dcnet"]

    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training {}".format("per utterance" if batch_size == 1 else
                             "with {} utterances per batch".format(batch_size)))

    train_loader = uttloader(
        config_dict["train_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=True)
    valid_loader = uttloader(
        config_dict["valid_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=False)
    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epoches -> {}...".format(
        args.num_epoches, "default checkpoint"
        if checkpoint is None else checkpoint))

    dcnet = DCNet(num_bins, **dcnnet_conf)
    trainer = Trainer(dcnet, **config_dict["trainer"])
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
Developer ID: jhuiac, Project: deep-clustering, Lines of code: 36, Source file: train_dcnet.py
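The YAML config consumed by parse_yaml above must at least provide the top-level keys the function reads (spectrogram_reader, dataloader, dcnet, trainer, and the three *_scp_conf sections). A hypothetical sketch of the resulting config_dict follows; only the top-level keys are taken from the code above, and every nested key and value is a placeholder rather than the project's defaults:

# Hypothetical shape of the dict returned by parse_yaml;
# all nested keys/values below are illustrative placeholders.
config_dict = {
    "spectrogram_reader": {"frame_length": 512, "frame_shift": 256},  # reader_conf
    "dataloader": {"batch_size": 16, "num_workers": 4},               # loader_conf
    "dcnet": {"rnn_hidden": 600, "num_layers": 2},                    # DCNet kwargs
    "trainer": {"checkpoint": "ckpt/dcnet", "lr": 1e-3},              # Trainer kwargs
    "train_scp_conf": {"mixture": "data/train/mix.scp"},
    "valid_scp_conf": {"mixture": "data/dev/mix.scp"},
    "debug_scp_conf": {"mixture": "data/debug/mix.scp"},              # used with --debug
}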

Example 2: MNIST

# Required import: from trainer import Trainer [as alias]
# Or: from trainer.Trainer import run [as alias]
from sc import SC
from trainer import Trainer
from pylearn2.models.mlp import MLP, Linear
from pylearn2.datasets.mnist import MNIST

print('Loading dataset')
dataset = MNIST(which_set='train', center=True)

hidden_size = 100
input_size = dataset.X_space.dim

print('Creating model')
# The MLP acts as a linear decoder: it maps the hidden code
# (nvis=hidden_size) back to the data dimension (dim=input_size).
model = MLP(
    batch_size=100,
    nvis=hidden_size,
    layers=[Linear(dim=input_size, layer_name='l1', irange=.05)]
)

print('Creating Sparse Coder')
sc = SC(model, hidden_size, dataset.X_space)

print('Training...')
trainer = Trainer(sc, dataset)

trainer.run()
Developer ID: nagyistge, Project: kindergarden-dropout, Lines of code: 32, Source file: experiment.py

Example 3: main

# Required import: from trainer import Trainer [as alias]
# Or: from trainer.Trainer import run [as alias]
def main(params):
    if params['load_dataset']:
        dataset = load_pkl(params['load_dataset'])
    elif params['dataset_class']:
        dataset = globals()[params['dataset_class']](**params[params['dataset_class']])
        if params['save_dataset']:
            save_pkl(params['save_dataset'], dataset)
    else:
        raise Exception('Either load_dataset (path to a pkl) or dataset_class must be specified.')
    result_dir = create_result_subdir(params['result_dir'], params['exp_name'])

    losses = ['G_loss', 'D_loss', 'D_real', 'D_fake']
    stats_to_log = [
        'tick_stat',
        'kimg_stat',
    ]
    if params['progressive_growing']:
        stats_to_log.extend([
            'depth',
            'alpha',
            'lod',
            'minibatch_size'
        ])
    stats_to_log.extend([
        'time',
        'sec.tick',
        'sec.kimg'
    ] + losses)
    logger = TeeLogger(os.path.join(result_dir, 'log.txt'), stats_to_log, [(1, 'epoch')])
    logger.log(params_to_str(params))
    if params['resume_network']:
        G, D = load_models(params['resume_network'], params['result_dir'], logger)
    else:
        G = Generator(dataset.shape, **params['Generator'])
        D = Discriminator(dataset.shape, **params['Discriminator'])
    if params['progressive_growing']:
        assert G.max_depth == D.max_depth
    G.cuda()
    D.cuda()
    latent_size = params['Generator']['latent_size']

    logger.log(str(G))
    logger.log('Total number of parameters in Generator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a*b, x.size()), G.parameters()))
    ))
    logger.log(str(D))
    logger.log('Total number of parameters in Discriminator: {}'.format(
        sum(map(lambda x: reduce(lambda a, b: a*b, x.size()), D.parameters()))
    ))

    def get_dataloader(minibatch_size):
        return DataLoader(dataset, minibatch_size, sampler=InfiniteRandomSampler(dataset),
                          num_workers=params['num_data_workers'], pin_memory=False, drop_last=True)

    def rl(bs):
        return lambda: random_latents(bs, latent_size)

    # Setting up learning rate and optimizers
    opt_g = Adam(G.parameters(), params['G_lr_max'], **params['Adam'])
    opt_d = Adam(D.parameters(), params['D_lr_max'], **params['Adam'])

    def rampup(cur_nimg):
        if cur_nimg < params['lr_rampup_kimg'] * 1000:
            p = max(0.0, 1 - cur_nimg / (params['lr_rampup_kimg'] * 1000))
            return np.exp(-p * p * 5.0)
        else:
            return 1.0
    lr_scheduler_d = LambdaLR(opt_d, rampup)
    lr_scheduler_g = LambdaLR(opt_g, rampup)

    mb_def = params['minibatch_size']
    D_loss_fun = partial(wgan_gp_D_loss, return_all=True, iwass_lambda=params['iwass_lambda'],
                         iwass_epsilon=params['iwass_epsilon'], iwass_target=params['iwass_target'])
    G_loss_fun = wgan_gp_G_loss
    trainer = Trainer(D, G, D_loss_fun, G_loss_fun,
                      opt_d, opt_g, dataset, iter(get_dataloader(mb_def)), rl(mb_def), **params['Trainer'])
    # plugins
    if params['progressive_growing']:
        max_depth = min(G.max_depth, D.max_depth)
        trainer.register_plugin(DepthManager(get_dataloader, rl, max_depth, **params['DepthManager']))
    for i, loss_name in enumerate(losses):
        trainer.register_plugin(EfficientLossMonitor(i, loss_name))

    checkpoints_dir = params['checkpoints_dir'] if params['checkpoints_dir'] else result_dir
    trainer.register_plugin(SaverPlugin(checkpoints_dir, **params['SaverPlugin']))

    def substitute_samples_path(d):
        return {k: (os.path.join(result_dir, v) if k == 'samples_path' else v) for k, v in d.items()}
    postprocessors = [globals()[x](**substitute_samples_path(params[x])) for x in params['postprocessors']]
    trainer.register_plugin(OutputGenerator(lambda x: random_latents(x, latent_size),
                                            postprocessors, **params['OutputGenerator']))
    trainer.register_plugin(AbsoluteTimeMonitor(params['resume_time']))
    trainer.register_plugin(LRScheduler(lr_scheduler_d, lr_scheduler_g))
    trainer.register_plugin(logger)
    init_comet(params, trainer)
    trainer.run(params['total_kimg'])
    dataset.close()
Developer ID: codealphago, Project: pggan-pytorch, Lines of code: 99, Source file: train.py
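The rampup helper in this example warms the learning-rate multiplier up smoothly from exp(-5) (about 0.7% of the maximum) to 1.0 over the first lr_rampup_kimg thousand images, and LambdaLR applies that multiplier to both optimizers. A standalone sketch of the curve, using a made-up value of 1000 for lr_rampup_kimg:

import numpy as np

LR_RAMPUP_KIMG = 1000  # hypothetical; the real value comes from params['lr_rampup_kimg']

def rampup(cur_nimg):
    if cur_nimg < LR_RAMPUP_KIMG * 1000:
        p = max(0.0, 1 - cur_nimg / (LR_RAMPUP_KIMG * 1000))
        return np.exp(-p * p * 5.0)
    return 1.0

for nimg in [0, 250_000, 500_000, 750_000, 1_000_000]:
    print(nimg, round(float(rampup(nimg)), 4))
# prints roughly: 0 0.0067, 250000 0.06, 500000 0.2865, 750000 0.7316, 1000000 1.0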

Example 4:

# Required import: from trainer import Trainer [as alias]
# Or: from trainer.Trainer import run [as alias]
#trainerSub.setWeightExpression("Weight_XS")
##trainerSub.run(variablesSub, 2000,0.001,0.5, 40,2)
#trainerSub.run(variablesSub,SubBDTOptions)
#f.write("SubBDT Parameters:\n")
#f.write(SubBDTOptions+"\n\n")
#samplesTrainingSub=[Sample("tth",trainingPath+"tth_nominal.root",-1,1.),Sample("ttbar_nominal",trainingPath+"Training_ttbar_bb.root",.5,1.),Sample("ttbar_b",trainingPath+"Training_ttbar_b.root",.5,1.),Sample("ttbar_cc",trainingPath+"Training_ttbar_cc.root",.5,1.),Sample("ttbar_light",trainingPath+"Training_ttbar_l.root",.5,1.),Sample("ttbar",trainingPath+"Training_ttbar.root",.1,1.)]
#print "Writing SubBDT output"
#weightfileSub=trainerSub.weightfile
#evaluaterSub=Evaluater(weightfileSub,variablesSub,samplesTrainingSub,[],[])
#evaluaterSub.WriteBDTVars("","_SubBDT","Sub")

print "Training Final BDT"
trainerFinal = Trainer(trainingPath + "tth_nominal.root", trainingPath + "ttbar_nominal.root",
                       evaluationPath + "tth_nominal.root", evaluationPath + "ttbar_nominal.root",
                       FinalBDTOptions, variablesFinal, [], False,
                       "weights/weights_Final_" + category + "_" + name + ".xml")
trainerFinal.useTransformations(False)
trainerFinal.setWeightExpression("Weight_XS")
trainerFinal.run(variablesFinal, FinalBDTOptions)

f.write("FinalBDT Parameters:\n")
f.write(FinalBDTOptions+"\n\n")
f.close()

samplesTrainingFinal = [
    Sample("tth", trainingPath + "tth_nominal.root", -1, 1.),
    Sample("ttbar_light", trainingPath + "ttbar_l_nominal.root", .1, 1.),
    Sample("ttbar_b", trainingPath + "ttbar_b_nominal.root", .5, 1.),
    Sample("ttbar_bb", trainingPath + "ttbar_bb_nominal.root", .5, 1.),
    Sample("ttbar_2b", trainingPath + "ttbar_2b_nominal.root", .5, 1.),
    Sample("ttbar_cc", trainingPath + "ttbar_cc_nominal.root", .5, 1.),
    Sample("ttbar", trainingPath + "ttbar_nominal.root", .1, 1.),
    Sample("SingleT", trainingPath + "SingleT_nominal.root", 0.031, 1.),
    Sample("DiBoson", trainingPath + "DiBoson_nominal.root", 0.07, 1.),
    Sample("ttW", trainingPath + "ttW_nominal.root", 0.139, 1.),
    Sample("ttZ", trainingPath + "ttZ_nominal.root", 0.124, 1.),
]
dataTrain = Sample("data", trainingPath + "MCData.root", -1, 1.)

print "Writing FinalBDT output"
weightfileFinal=trainerFinal.weightfile
evaluaterFinal=Evaluater(weightfileFinal,variablesFinal,samplesTrainingFinal,[],[])
evaluaterFinal.WriteBDTVars("","_FinalBDT","Final")

## plot scripts deprecated
#print "plotting BDTTraining"
Developer ID: kit-cn-cms, Project: TTH_Analysis_Chain, Lines of code: 33, Source file: TrainAndEval.py
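FinalBDTOptions is defined elsewhere in this script; the weights/*.xml output path suggests the wrapper sits on top of ROOT's TMVA, whose BDT settings are passed as a single colon-separated option string. A hypothetical example of what such a string can look like; the values are illustrative, not this analysis's actual configuration:

# Hypothetical TMVA-style BDT option string (illustrative values only)
FinalBDTOptions = ("!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.01:"
                   "UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=2")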


Note: The trainer.Trainer.run examples in this article were compiled by 纯净天空 from open-source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by many developers; copyright belongs to the original authors. Refer to each project's license before distributing or reusing the code, and do not repost without permission.