当前位置: 首页>>代码示例>>Python>>正文


Python chainer.training方法代码示例

本文整理汇总了Python中chainer.training方法的典型用法代码示例。如果您正苦于以下问题：Python chainer.training方法的具体用法？Python chainer.training怎么用？Python chainer.training使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在chainer的用法示例。


在下文中一共展示了chainer.training方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __call__

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def __call__(self, trainer):
    """Log progress, generate a text sample, and restart the timer.

    Logs the updater's epoch, the time elapsed since ``self.time_epoch``
    and the ``"main/loss"`` observation, then generates text from a seed
    derived from ``self.text``. The predictor state is captured before
    generation and put back afterwards.

    Args:
        trainer: Trainer whose updater and observation are inspected.
    """
    updater = trainer.updater
    elapsed = time.time() - self.time_epoch
    logger.info("epoch: %s, duration: %ds, loss: %.6g.",
                updater.epoch, elapsed,
                trainer.observation["main/loss"].data)

    # Remember the RNN state so that sampling does not disturb training.
    net = updater.get_optimizer("main").target
    saved_state = net.predictor.get_state()

    # Sample some text starting from a freshly generated seed.
    generate_text(net, generate_seed(self.text))

    # Hand the saved state back to the predictor.
    net.predictor.set_state(saved_state)

    # Start timing the next period.
    self.time_epoch = time.time()
开发者ID:yxtay,项目名称:char-rnn-text-generation,代码行数:20,代码来源:chainer_model.py

示例2: create_trainer

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def create_trainer(args, model):
    """Build a trainer that fits ``model`` on MNIST using SGD.

    Args:
        args: Object providing ``batchsize``, ``gpu``, ``epoch`` and ``out``.
        model: Target handed to the optimizer and to the evaluator.

    Returns:
        The configured ``training.Trainer`` with an ``Evaluator`` extension
        attached.
    """
    # SGD optimizer bound to the model.
    sgd = chainer.optimizers.SGD()
    sgd.setup(model)

    # MNIST train/test splits, both iterated without shuffling.
    train_set, test_set = chainer.datasets.get_mnist()
    train_iterator = MyIterator(train_set, args.batchsize, shuffle=False)
    test_iterator = MyIterator(
        test_set, args.batchsize, repeat=False, shuffle=False)

    # Trainer that stops after ``args.epoch`` epochs.
    trainer = training.Trainer(
        training.updaters.StandardUpdater(
            train_iterator, sgd, device=args.gpu),
        (args.epoch, 'epoch'),
        out=args.out)

    # Evaluation on the test split, registered with the default trigger.
    evaluator = extensions.Evaluator(test_iterator, model, device=args.gpu)
    trainer.extend(evaluator)
    return trainer
开发者ID:pfnet-research,项目名称:chainer-compiler,代码行数:22,代码来源:gen_mnist_mlp.py

示例3: get_model

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def get_model(args, loss_func, vocab, vocab_ngram_tokens, current_utils=utils.word):
    """Instantiate the embedding model selected by ``args``.

    Args:
        args: Object providing ``subword``, ``model`` and ``dimensions``.
        loss_func: Loss function forwarded to the model constructor.
        vocab: Vocabulary; ``vocab.cnt_words`` is used for word-level models.
        vocab_ngram_tokens: N-gram tokens, used only by the subword model.
        current_utils: Module providing ``SkipGram`` and ``ContinuousBoW``
            for the word-level case.

    Returns:
        The constructed model.

    Raises:
        Exception: If no model matches the ``args.model`` /
            ``args.subword`` combination.
    """
    word_level = args.subword == 'none'
    kind = args.model

    model = None
    if word_level and kind == 'skipgram':
        model = current_utils.SkipGram(
            vocab.cnt_words, args.dimensions, loss_func)
    elif word_level and kind == 'cbow':
        # TODO: only skipgram supported
        model = current_utils.ContinuousBoW(
            vocab.cnt_words, args.dimensions, loss_func)
    elif not word_level and kind == 'skipgram':
        model = utils.subword.SkipGram(
            args.subword, vocab, vocab_ngram_tokens, args.dimensions, loss_func)

    if model is not None:
        return model
    raise Exception('Unknown model and word/subword type: {} "and" {}'.format(args.model, args.subword))


#@training.make_extension(trigger=(1, 'epoch'))
#def dump_embs(trainer):
#    print("dumping embeddings") 
开发者ID:vecto-ai,项目名称:vecto,代码行数:22,代码来源:train_word2vec.py

示例4: add_default_arguments

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def add_default_arguments(parser):
    """Register the common training command-line options on ``parser``.

    Adds one positional argument (``log_dir``), the required ``--batch-size``
    option and a set of optional settings with defaults.

    Args:
        parser: An ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same ``parser`` object, so the call can be chained.
    """
    add = parser.add_argument

    # Output location and the mandatory batch size.
    add("log_dir", help='directory where generated models and logs shall be stored')
    add('-b', '--batch-size', dest='batch_size', type=int, required=True,
        help="Number of images per training batch")

    # Hardware, duration and resuming.
    add('-g', '--gpus', type=int, nargs="*", default=[],
        help="Ids of GPU to use [default: (use cpu)]")
    add('-e', '--epochs', type=int, default=20,
        help="Number of epochs to train [default: 20]")
    add('-r', '--resume',
        help="path to previously saved state of trained model from which training shall resume")

    # Snapshots and logging.
    add('-si', '--snapshot-interval', dest='snapshot_interval', type=int, default=20000,
        help="number of iterations after which a snapshot shall be taken [default: 20000]")
    add('-ln', '--log-name', dest='log_name', default='training',
        help="name of the log folder")

    # Learning-rate settings (the log interval keeps its original position).
    add('-lr', '--learning-rate', dest='learning_rate', type=float, default=0.01,
        help="initial learning rate [default: 0.01]")
    add('-li', '--log-interval', dest='log_interval', type=int, default=100,
        help="number of iterations after which an update shall be logged [default: 100]")
    add('--lr-step', dest='learning_rate_step_size', type=float, default=0.1,
        help="Step size for decreasing learning rate [default: 0.1]")

    # Testing and regularisation.
    add('-t', '--test-interval', dest='test_interval', type=int, default=1000,
        help="number of iterations after which testing should be performed [default: 1000]")
    add('--test-iterations', dest='test_iterations', type=int, default=200,
        help="number of test iterations [default: 200]")
    add("-dr", "--dropout-ratio", dest='dropout_ratio', default=0.5, type=float,
        help="ratio for dropout layers")

    return parser
开发者ID:Bartzi,项目名称:see,代码行数:26,代码来源:train_utils.py

示例5: adadelta_eps_decay

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def adadelta_eps_decay(eps_decay):
    """Extension to perform adadelta eps decay.

    The returned extension is created with a ``(1, "epoch")`` trigger and
    forwards the trainer together with ``eps_decay`` to
    ``_adadelta_eps_decay``.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function.

    """

    # The inner function shadows the outer name inside this scope.
    # NOTE(review): make_extension presumably derives the extension's default
    # name from the function name -- confirm before renaming it.
    @training.make_extension(trigger=(1, "epoch"))
    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    return adadelta_eps_decay 
开发者ID:espnet,项目名称:espnet,代码行数:18,代码来源:asr_utils.py

示例6: adam_lr_decay

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def adam_lr_decay(eps_decay):
    """Extension to perform adam lr decay.

    The returned extension is created with a ``(1, "epoch")`` trigger and
    forwards the trainer together with ``eps_decay`` to ``_adam_lr_decay``.

    Args:
        eps_decay (float): Decay rate of lr. Despite the parameter name,
            the value is passed on as the learning-rate decay.

    Returns:
        An extension function.

    """

    # The inner function shadows the outer name inside this scope.
    # NOTE(review): make_extension presumably derives the extension's default
    # name from the function name -- confirm before renaming it.
    @training.make_extension(trigger=(1, "epoch"))
    def adam_lr_decay(trainer):
        _adam_lr_decay(trainer, eps_decay)

    return adam_lr_decay 
开发者ID:espnet,项目名称:espnet,代码行数:18,代码来源:asr_utils.py

示例7: setup_updater

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def setup_updater(mode, gpus, train_iter, optimizer):
    """Create a standard or multiprocess-parallel updater.

    Args:
        mode: Value passed to ``select_converter`` to pick the converter.
        gpus: Non-empty sequence of device ids; the first one is the main
            device.
        train_iter: Training iterator(s) handed to the updater.
        optimizer: Optimizer handed to the updater.

    Returns:
        A ``StandardUpdater`` when exactly one device id is given,
        otherwise a ``MultiprocessParallelUpdater``.
    """
    main_gpu = gpus[0]

    if len(gpus) == 1:
        # One device only (a GPU or the CPU).
        logger.info('Setup single updater (gpu: %d)', main_gpu)
        return training.StandardUpdater(
            train_iter, optimizer, device=main_gpu,
            converter=select_converter(mode))

    # Several devices: every id after the first becomes a numbered slave,
    # and the main device is added last.
    logger.info('Setup parallel updater (gpu: %s)', str(gpus))
    devices = {}
    for index, gpu in enumerate(gpus[1:]):
        devices['slave{}'.format(index)] = gpu
    devices['main'] = main_gpu
    return training.updaters.MultiprocessParallelUpdater(
        train_iter, optimizer, devices=devices,
        converter=select_converter(mode))
开发者ID:takiyu,项目名称:portrait_matting,代码行数:18,代码来源:train.py

示例8: test_chainer_pruning_extension

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def test_chainer_pruning_extension() -> None:
    """Check the trial state produced with ``ChainerPruningExtension``.

    A one-epoch training run is used as the objective. With
    ``DeterministicPruner(True)`` the trial must end up ``PRUNED``; with
    ``DeterministicPruner(False)`` it must be ``COMPLETE`` with value 1.0.
    """

    def objective(trial: optuna.trial.Trial) -> float:

        classifier = L.Classifier(chainer.Sequential(L.Linear(None, 2)))
        adam = chainer.optimizers.Adam()
        adam.setup(classifier)

        iterator = chainer.iterators.SerialIterator(FixedValueDataset(), 16)
        trainer = chainer.training.Trainer(
            chainer.training.StandardUpdater(iterator, adam), (1, "epoch")
        )
        pruning = optuna.integration.chainer.ChainerPruningExtension(
            trial, "main/loss", (1, "epoch")
        )
        trainer.extend(pruning)

        trainer.run(show_loop_exception_msg=False)
        return 1.0

    def first_trial(should_prune):
        # Run a single trial with a pruner of fixed behaviour.
        study = optuna.create_study(pruner=DeterministicPruner(should_prune))
        study.optimize(objective, n_trials=1)
        return study.trials[0]

    assert first_trial(True).state == optuna.trial.TrialState.PRUNED

    finished = first_trial(False)
    assert finished.state == optuna.trial.TrialState.COMPLETE
    assert finished.value == 1.0
开发者ID:optuna,项目名称:optuna,代码行数:27,代码来源:test_chainer.py

示例9: __init__

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def __init__(self, trial, observation_key, pruner_trigger):
    # type: (optuna.trial.Trial, str, TriggerType) -> None
    """Store the trial and observation key and validate the trigger.

    Args:
        trial: Trial object kept on the instance.
        observation_key: Observation key kept on the instance.
        pruner_trigger: Trigger specification, resolved with
            ``chainer.training.get_trigger``.

    Raises:
        TypeError: If the resolved trigger is neither an
            ``IntervalTrigger`` nor a ``ManualScheduleTrigger``.
    """
    _imports.check()

    self._trial = trial
    self._observation_key = observation_key

    resolved = chainer.training.get_trigger(pruner_trigger)
    self._pruner_trigger = resolved

    accepted = (triggers.IntervalTrigger, triggers.ManualScheduleTrigger)
    if isinstance(resolved, accepted):
        return

    raise TypeError(
        "Invalid trigger class: {}\n"
        "Pruner trigger is supposed to be an instance of "
        "IntervalTrigger or ManualScheduleTrigger.".format(type(resolved))
    )
开发者ID:optuna,项目名称:optuna,代码行数:20,代码来源:chainer.py

示例10: ae_reconstruction

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def ae_reconstruction(enc, dec, eval_folder, gpu, data_iter, batch_size=32, img_chan=3, img_size=64):
    """Build an extension that saves real and reconstructed image grids.

    Each call of the extension pulls one batch from ``data_iter``, passes it
    through ``enc`` and ``dec`` and writes two grids to ``eval_folder``,
    named after the updater's current iteration.

    Args:
        enc: Encoder; its ``xp`` attribute selects the array module.
        dec: Decoder applied to the encoder output.
        eval_folder: Folder path used as prefix of the output files.
        gpu: Not used by this function.
        data_iter: Iterator providing batches through ``next()``.
        batch_size: Number of samples copied from each batch.
        img_chan: Channel count of the input buffer.
        img_size: Height and width of the input buffer.

    Returns:
        The extension function.
    """
    @chainer.training.make_extension()
    def sample_reconstruction(trainer):
        xp = enc.xp
        samples = data_iter.next()

        # Copy the first ``batch_size`` samples into one float32 buffer.
        shape = (batch_size, img_chan, img_size, img_size)
        real = xp.zeros(shape).astype("f")
        for index in range(batch_size):
            real[index, :] = xp.asarray(samples[index])

        encoded = enc(Variable(real, volatile=True), test=True)
        reconstructed = dec(encoded, test=True)

        # Both grids are 8 rows high.
        columns = batch_size // 8
        prefix = eval_folder + "/iter_"
        save_images_grid(
            reconstructed,
            path=prefix + str(trainer.updater.iteration) + ".rec.jpg",
            grid_w=columns, grid_h=8)
        save_images_grid(
            real,
            path=prefix + str(trainer.updater.iteration) + ".real.jpg",
            grid_w=columns, grid_h=8)

    return sample_reconstruction
开发者ID:Aixile,项目名称:chainer-gan-experiments,代码行数:18,代码来源:evaluation.py

示例11: preview_convert

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def preview_convert(iterator_a, iterator_b, g_a, g_b, device, gla, dst):
    """Build an extension that writes image and audio previews.

    Each call takes one batch from both iterators, applies ``g_a`` and
    ``g_b`` once and then once more on each other's output, and saves every
    input and output as ``.jpg`` under ``<dst>/image`` and as ``.wav`` under
    ``<dst>/preview``. File names start with the updater's epoch.

    Args:
        iterator_a: Iterator for the ``a`` batches.
        iterator_b: Iterator for the ``b`` batches.
        g_a: Callable applied to the ``a`` batch and to the output of ``g_b``.
        g_b: Callable applied to the ``b`` batch and to the output of ``g_a``.
        device: Device forwarded to ``convert.concat_examples``.
        gla: Object whose ``inverse`` method is applied to each item
            returned by ``dataset.reverse``.
        dst: Output root directory.

    Returns:
        The extension function.
    """
    @chainer.training.make_extension()
    def make_preview(trainer):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_a = chainer.Variable(
                convert.concat_examples(iterator_a.next(), device))
            batch_b = chainer.Variable(
                convert.concat_examples(iterator_b.next(), device))

            out_ab = g_a(batch_a)
            out_ba = g_b(batch_b)

            out_bab = g_a(out_ba)
            out_aba = g_b(out_ab)

            # Create the output folders if they are missing.
            preview_dir = '{}/preview'.format(dst)
            image_dir = '{}/image'.format(dst)
            for folder in (preview_dir, image_dir):
                if not os.path.exists(folder):
                    os.makedirs(folder)

            labelled = (
                ('a', batch_a), ('ab', out_ab), ('aba', out_aba),
                ('b', batch_b), ('ba', out_ba), ('bab', out_bab),
            )
            for tag, variable in labelled:
                # ``padding`` is expected to be defined at module level.
                array = cp.asnumpy(variable.data)[:, :, padding:-padding, :]
                array = array.reshape(1, -1, 128)
                stem = '/{}{}'.format(trainer.updater.epoch, tag)
                image.save(image_dir + stem + '.jpg', array)
                wave = np.concatenate(
                    [gla.inverse(part) for part in dataset.reverse(array)])
                dataset.save(preview_dir + stem + '.wav', 16000, wave)

    return make_preview
开发者ID:pstuvwx,项目名称:Deep_VoiceChanger,代码行数:37,代码来源:trainer.py

示例12: sample_generate_light

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def sample_generate_light(gen, mapping, dst, rows=8, cols=8, z=None, seed=0, subdir='preview'):
    """Build an extension that renders a grid of generated images.

    The extension keeps ``rows``, ``cols`` and ``z`` between calls: the grid
    is capped once the updater's ``stage`` exceeds 11, 13 or 15, and the
    latent ``z`` is created on first use (when not supplied) and truncated
    to the current grid size afterwards.

    Args:
        gen: Generator; called with the mapped latent and the stage.
        mapping: Mapping network; also provides ``make_hidden``.
        dst: Output root directory.
        rows: Initial number of grid rows.
        cols: Initial number of grid columns.
        z: Optional latent to reuse instead of sampling one.
        seed: Seed passed to ``np.random.seed`` before sampling.
        subdir: Sub-directory of ``dst`` receiving the images.

    Returns:
        The extension function.
    """
    @chainer.training.make_extension()
    def make_image(trainer):
        nonlocal rows, cols, z

        # Cap the grid size at later stages; the first matching rule wins.
        stage = trainer.updater.stage
        for threshold, limit in ((15, 2), (13, 3), (11, 4)):
            if stage > threshold:
                rows, cols = min(rows, limit), min(cols, limit)
                break

        np.random.seed(seed)
        n_images = rows * cols
        if z is not None:
            z = z[:n_images]
        else:
            z = Variable(gen.xp.asarray(mapping.make_hidden(n_images)))

        with chainer.using_config('train', False):
            with chainer.using_config('enable_backprop', False):
                generated = gen(mapping(z), stage=stage)
        pixels = chainer.cuda.to_cpu(generated.data)
        # Re-seed NumPy without a fixed seed once sampling is done.
        np.random.seed()

        grid = convert_batch_images(pixels, rows, cols)

        preview_dir = '{}/{}'.format(dst, subdir)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)

        # Write the grid twice: as the latest image and under its iteration.
        file_names = (
            '/image_latest.png',
            '/image{:0>8}.png'.format(trainer.updater.iteration),
        )
        for file_name in file_names:
            Image.fromarray(grid).save(preview_dir + file_name)

    return make_image
开发者ID:pfnet-research,项目名称:chainer-stylegan,代码行数:40,代码来源:train.py

示例13: parse_args

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def parse_args():
    """Parse the training options from the process command line.

    Returns:
        The ``argparse.Namespace`` holding every option value.
    """
    # One (flags, keyword arguments) pair per option, in registration order.
    options = (
        (('--batch_size', '-b'),
         dict(type=int, default=32,
              help='Number of examples in each mini-batch')),
        (('--bproplen', '-l'),
         dict(type=int, default=35,
              help='Number of words in each mini-batch (= length of truncated BPTT)')),
        (('--epoch', '-e'),
         dict(type=int, default=20,
              help='Number of sweeps over the dataset to train')),
        (('--gpu', '-g'),
         dict(type=int, default=0,
              help='GPU ID (negative value indicates CPU)')),
        (('--gradclip', '-c'),
         dict(type=float, default=5,
              help='Gradient norm threshold to clip')),
        (('--out', '-o'),
         dict(default='result',
              help='Directory to output the result')),
        (('--resume', '-r'),
         dict(default='',
              help='Resume the training from snapshot')),
        (('--test',),
         dict(action='store_true',
              help='Use tiny datasets for quick tests')),
        (('--hidden_size',),
         dict(type=int, default=300,
              help='Number of LSTM units in each layer')),
        (('--embed_size',),
         dict(type=int, default=300,
              help='Size of embeddings')),
        (('--model', '-m'),
         dict(default='model.npz',
              help='Model file name to serialize')),
        (('--glove',),
         dict(default='data/glove.6B.300d.txt',
              help='Path to glove embedding file.')),
    )

    parser = argparse.ArgumentParser()
    for flags, spec in options:
        parser.add_argument(*flags, **spec)
    parser.set_defaults(test=False)
    return parser.parse_args()
开发者ID:Pinafore,项目名称:qb,代码行数:32,代码来源:main.py

示例14: main

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def main(model):
    """Train ``model`` on the buzzer folds and snapshot it.

    Reads the training and development folds, trains for 20 epochs with
    Adam plus weight decay on device 0, evaluates on the development fold
    and saves ``buzzer.npz`` whenever the ``MaxValueTrigger`` on
    ``validation/main/accuracy`` fires.

    Args:
        model: Model to train; ``model.model_dir`` is the output directory.
    """
    train_data = read_data(fold=BUZZER_TRAIN_FOLD)
    valid_data = read_data(fold=BUZZER_DEV_FOLD)
    print('# train data: {}'.format(len(train_data)))
    print('# valid data: {}'.format(len(valid_data)))

    batch_size = 64
    train_iter = chainer.iterators.SerialIterator(train_data, batch_size)
    valid_iter = chainer.iterators.SerialIterator(
        valid_data, batch_size, repeat=False, shuffle=False)

    adam = chainer.optimizers.Adam()
    adam.setup(model)
    adam.add_hook(chainer.optimizer.WeightDecay(1e-4))

    trainer = training.Trainer(
        training.updaters.StandardUpdater(
            train_iter, adam, converter=convert_seq, device=0),
        (20, 'epoch'),
        out=model.model_dir)

    # Validation on the development fold.
    evaluator = extensions.Evaluator(
        valid_iter, model, converter=convert_seq, device=0)
    trainer.extend(evaluator)

    # Snapshot the model under control of the accuracy trigger.
    best_accuracy = training.triggers.MaxValueTrigger(
        'validation/main/accuracy', (1, 'epoch'))
    trainer.extend(
        extensions.snapshot_object(model, 'buzzer.npz'),
        trigger=best_accuracy)

    # Logging and console reporting.
    report_columns = [
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
    ]
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PrintReport(report_columns))

    # Make sure the output directory exists before training starts.
    if not os.path.isdir(model.model_dir):
        os.mkdir(model.model_dir)

    trainer.run()
开发者ID:Pinafore,项目名称:qb,代码行数:34,代码来源:train.py

示例15: parse_args

# 需要导入模块: import chainer [as 别名]
# 或者: from chainer import training [as 别名]
def parse_args(generators, discriminators, updaters):
    """Parse the command-line options of the segmentation training script.

    Args:
        generators: Mapping whose keys are the accepted ``--generator``
            values.
        discriminators: Mapping whose keys are the accepted
            ``--discriminator`` values.
        updaters: Mapping whose keys are the accepted ``--updater`` values.

    Returns:
        The ``argparse.Namespace`` parsed from the process command line.
    """
    parser = argparse.ArgumentParser(description='Semantic Segmentation using Adversarial Networks')
    parser.add_argument('--generator', choices=generators.keys(), default='fcn32s',
                        help='Generator(segmentor) architecture')
    parser.add_argument('--discriminator', choices=discriminators.keys(), default='largefov',
                        help='Discriminator architecture')
    parser.add_argument('--updater', choices=updaters.keys(), default='gan',
                        help='Updater')
    parser.add_argument('--initgen_path', default='pretrained_model/vgg16.npz',
                        help='Pretrained model of generator')
    parser.add_argument('--initdis_path', default=None,
                        help='Pretrained model of discriminator')
    parser.add_argument('--batchsize', '-b', type=int, default=1,
                        help='Number of images in each mini-batch')
    # The option counts iterations, not sweeps over the dataset; the help
    # text previously repeated the wording of an epoch option.
    parser.add_argument('--iteration', '-i', type=int, default=100000,
                        help='Number of iterations to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='snapshot',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--evaluate_interval', type=int, default=1000,
                        help='Interval of evaluation')
    parser.add_argument('--snapshot_interval', type=int, default=10000,
                        help='Interval of snapshot')
    parser.add_argument('--display_interval', type=int, default=10,
                        help='Interval of displaying log to console')
    return parser.parse_args()
开发者ID:oyam,项目名称:Semantic-Segmentation-using-Adversarial-Networks,代码行数:31,代码来源:02-train.py


注:本文中的chainer.training方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。