

Python training.Trainer Method Code Examples

This article collects typical usage examples of the training.Trainer method from the Python chainer package. If you are wondering what training.Trainer does, how to call it, or what real-world usage looks like, the curated code examples below should help. You can also explore further usage examples from the chainer.training module that this method belongs to.


The following shows 15 code examples of the training.Trainer method, ordered by popularity by default.
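
Before the examples, here is a minimal, self-contained sketch of the workflow that most of the snippets below follow: build a model and optimizer, wrap dataset iterators in an updater, construct the Trainer with a stop trigger, attach extensions, and call run(). The small MLP and the hyperparameters (batch size, epoch count, output directory) are placeholder assumptions chosen for illustration; the Chainer APIs used are the standard ones.

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


class MLP(chainer.Chain):
    """A small two-layer perceptron used only as a placeholder model."""

    def __init__(self, n_units=100, n_out=10):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units)
            self.l2 = L.Linear(None, n_out)

    def __call__(self, x):
        return self.l2(F.relu(self.l1(x)))


def main():
    # Wrap the model so that loss and accuracy are reported automatically
    model = L.Classifier(MLP())

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist()
    train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
    test_iter = chainer.iterators.SerialIterator(
        test, batch_size=100, repeat=False, shuffle=False)

    # The updater feeds minibatches to the optimizer; the Trainer drives it
    # until the stop trigger (here 5 epochs) fires.
    updater = training.updaters.StandardUpdater(train_iter, optimizer, device=-1)
    trainer = training.Trainer(updater, (5, 'epoch'), out='result')

    trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    trainer.run()


if __name__ == '__main__':
    main()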

Example 1: create_trainer

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def create_trainer(args, model):
    # Setup an optimizer
    # optimizer = chainer.optimizers.Adam()
    optimizer = chainer.optimizers.SGD()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = MyIterator(train, args.batchsize, shuffle=False)
    test_iter = MyIterator(test, args.batchsize, repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))
    return trainer 
Developer: pfnet-research, Project: chainer-compiler, Lines of code: 22, Source: gen_mnist_mlp.py

Example 2: _get_mocked_trainer

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def _get_mocked_trainer(links, stop_trigger=(10, 'iteration')):
    updater = mock.Mock()
    optimizer = mock.Mock()
    target = mock.Mock()
    target.namedlinks.return_value = [
        (str(i), link) for i, link in enumerate(links)]

    optimizer.target = target
    updater.get_all_optimizers.return_value = {'optimizer_name': optimizer}
    updater.iteration = 0
    updater.epoch = 0
    updater.epoch_detail = 0
    updater.is_new_epoch = True
    iter_per_epoch = 10

    def update():
        time.sleep(0.001)
        updater.iteration += 1
        updater.epoch = updater.iteration // iter_per_epoch
        updater.epoch_detail = updater.iteration / iter_per_epoch
        updater.is_new_epoch = updater.epoch == updater.epoch_detail

    updater.update = update

    return training.Trainer(updater, stop_trigger) 
Developer: chainer, Project: chainer, Lines of code: 27, Source: test_parameter_statistics.py

Example 3: train_voxelnet

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def train_voxelnet():
    """Training VoxelNet."""
    config = parse_args()
    model = get_model(config["model"])
    devices = parse_devices(config['gpus'], config['updater']['name'])
    train_data, test_data = load_dataset(config["dataset"])
    train_iter, test_iter = create_iterator(train_data, test_data,
                                            config['iterator'], devices,
                                            config['updater']['name'])
    class_weight = get_class_weight(config)
    optimizer = create_optimizer(config['optimizer'], model)
    updater = create_updater(train_iter, optimizer, config['updater'], devices)
    trainer = training.Trainer(updater, config['end_trigger'], out=config['results'])
    trainer = create_extension(trainer, test_iter, model,
                               config['extension'], devices=devices)
    trainer.run()
    chainer.serializers.save_npz(os.path.join(config['results'], 'model.npz'),
                                 model) 
Developer: yukitsuji, Project: voxelnet_chainer, Lines of code: 20, Source: train.py

Example 4: train

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def train(network, loss, X_tr, Y_tr, X_te, Y_te, n_epochs=30, gamma=1):
    model = Objective(network, loss=loss, gamma=gamma)

    #optimizer = optimizers.SGD()
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    train = tuple_dataset.TupleDataset(X_tr, Y_tr)
    test = tuple_dataset.TupleDataset(X_te, Y_te)

    train_iter = iterators.SerialIterator(train, batch_size=1, shuffle=True)
    test_iter = iterators.SerialIterator(test, batch_size=1, repeat=False,
                                         shuffle=False)
    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (n_epochs, 'epoch'))

    trainer.run() 
Developer: mblondel, Project: soft-dtw, Lines of code: 19, Source: plot_chainer_MLP.py

Example 5: main

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def main(model):
    train = read_data(fold=BUZZER_TRAIN_FOLD)
    valid = read_data(fold=BUZZER_DEV_FOLD)
    print('# train data: {}'.format(len(train)))
    print('# valid data: {}'.format(len(valid)))

    train_iter = chainer.iterators.SerialIterator(train, 64)
    valid_iter = chainer.iterators.SerialIterator(valid, 64, repeat=False, shuffle=False)

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(1e-4))

    updater = training.updaters.StandardUpdater(train_iter, optimizer, converter=convert_seq, device=0)
    trainer = training.Trainer(updater, (20, 'epoch'), out=model.model_dir)

    trainer.extend(extensions.Evaluator(valid_iter, model, converter=convert_seq, device=0))

    record_trigger = training.triggers.MaxValueTrigger('validation/main/accuracy', (1, 'epoch'))
    trainer.extend(extensions.snapshot_object(model, 'buzzer.npz'), trigger=record_trigger)

    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.PrintReport([
        'epoch', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time'
    ]))

    if not os.path.isdir(model.model_dir):
        os.mkdir(model.model_dir)

    trainer.run() 
Developer: Pinafore, Project: qb, Lines of code: 34, Source: train.py

Example 6: create_trainer

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def create_trainer(args, model):
    # Setup an optimizer
    #optimizer = chainer.optimizers.Adam()
    optimizer = chainer.optimizers.SGD()
    optimizer.setup(model)

    # Load the datasets and mean file
    mean = np.load(args.mean)

    train = PreprocessedDataset(args.train, args.root, mean, insize)
    val = PreprocessedDataset(args.val, args.root, mean, insize, False)

    # These iterators load the images with subprocesses running in parallel to
    # the training/validation.
    train_iter = MyIterator(
        train, args.batchsize, n_processes=args.loaderjob)
    val_iter = MyIterator(
        val, args.val_batchsize, repeat=False, n_processes=args.loaderjob)

    # Set up a trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the validation dataset for each epoch
    trainer.extend(extensions.Evaluator(val_iter, model, device=args.gpu))
    return trainer 
Developer: pfnet-research, Project: chainer-compiler, Lines of code: 29, Source: gen_resnet50.py

Example 7: test_linear_network

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def test_linear_network():

    # To ensure repeatability of experiments
    np.random.seed(1042)

    # Load data set
    dataset = get_dataset(True)
    iterator = LtrIterator(dataset, repeat=True, shuffle=True)
    eval_iterator = LtrIterator(dataset, repeat=False, shuffle=False)

    # Create neural network with chainer and apply our loss function
    predictor = links.Linear(None, 1)
    loss = Ranker(predictor, listnet)

    # Build optimizer, updater and trainer
    optimizer = optimizers.Adam(alpha=0.2)
    optimizer.setup(loss)
    updater = training.StandardUpdater(iterator, optimizer)
    trainer = training.Trainer(updater, (10, 'epoch'))

    # Evaluate loss before training
    before_loss = eval(loss, eval_iterator)

    # Train neural network
    trainer.run()

    # Evaluate loss after training
    after_loss = eval(loss, eval_iterator)

    # Assert precomputed values
    assert_almost_equal(before_loss, 0.26958397)
    assert_almost_equal(after_loss, 0.2326711) 
Developer: rjagerman, Project: shoelace, Lines of code: 34, Source: test_linear_network.py

Example 8: _get_mock_trainer

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def _get_mock_trainer(self, out_path, trigger=None, updater=None):
        class _MockTrainer(Trainer):
            def __init__(
                self, out_path, stop_trigger=IntervalTrigger(100, 'epoch'),
                    updater=None):

                self.out = out_path
                self.stop_trigger = stop_trigger

                hyperparam = Hyperparameter()
                hyperparam.lr = 0.005
                optimizer = MagicMock()
                optimizer.__class__.__name__ = 'MomentumSGD'
                optimizer.hyperparam = hyperparam

                if updater is None:
                    updater = MagicMock()
                    updater.epoch = 0
                    updater.iteration = 0
                    updater.get_optimizer.return_value = optimizer
                self.updater = updater

            @property
            def elapsed_time(self):
                return 0

            def serialize(self, serializer):
                pass

        return _MockTrainer(out_path, trigger, updater) 
Developer: chainer, Project: chainerui, Lines of code: 32, Source: test_commands_extension.py

Example 9: _run_test

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def _run_test(self, tempdir, initial_flag):
        n_data = 4
        n_epochs = 3
        outdir = os.path.join(tempdir, 'testresult')

        # Prepare
        model = Model()
        classifier = links.Classifier(model)
        optimizer = chainer.optimizers.Adam()
        optimizer.setup(classifier)

        dataset = Dataset([i for i in range(n_data)])
        iterator = chainer.iterators.SerialIterator(dataset, 1, shuffle=False)
        updater = training.updaters.StandardUpdater(iterator, optimizer)
        trainer = training.Trainer(updater, (n_epochs, 'epoch'), out=outdir)

        extension = c.DumpGraph('main/loss', filename='test.dot')
        trainer.extend(extension)

        # Run
        with chainer.using_config('keep_graph_on_report', initial_flag):
            trainer.run()

        # Check flag history
        self.assertEqual(model.flag_history,
                         [True] + [initial_flag] * (n_data * n_epochs - 1))

        # Check the dumped graph
        graph_path = os.path.join(outdir, 'test.dot')
        with open(graph_path) as f:
            graph_dot = f.read()

        # Check that only the first iteration is dumped
        self.assertIn('Function1', graph_dot)
        self.assertNotIn('Function2', graph_dot)

        if c.is_graphviz_available():
            self.assertTrue(os.path.exists(os.path.join(outdir, 'test.png'))) 
Developer: chainer, Project: chainer, Lines of code: 40, Source: test_computational_graph.py

Example 10: prepare

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def prepare(self, dirname='test', device=None):
        outdir = os.path.join(self.temp_dir, dirname)
        self.updater = training.updaters.StandardUpdater(
            self.iterator, self.optimizer, device=device)
        self.trainer = training.Trainer(
            self.updater, (self.n_epochs, 'epoch'), out=outdir)
        self.trainer.extend(training.extensions.FailOnNonNumber()) 
Developer: chainer, Project: chainer, Lines of code: 9, Source: test_fail_on_nonnumber.py

Example 11: _prepare_multinode_snapshot

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def _prepare_multinode_snapshot(n, result):
    n_units = 100
    batchsize = 10
    comm = create_communicator('naive')
    model = L.Classifier(MLP(n_units, 10))
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), comm)
    optimizer.setup(model)

    if comm.rank == 0:
        train, _ = chainer.datasets.get_mnist()
    else:
        train, _ = None, None

    train = chainermn.scatter_dataset(train, comm, shuffle=True)
    train_iter = chainer.iterators.SerialIterator(train, batchsize)

    updater = StandardUpdater(train_iter, optimizer)
    trainer = Trainer(updater, out=result)

    snapshot = extensions.snapshot(target=updater, autoload=True)
    replica_sets = []
    mn_snapshot = multi_node_snapshot(comm, snapshot, replica_sets)
    mn_snapshot.initialize(trainer)
    for _ in range(n):
        updater.update()

    return updater, mn_snapshot, trainer 
Developer: chainer, Project: chainer, Lines of code: 30, Source: test_multi_node_snapshot.py

Example 12: train

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def train(args):
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # CIFAR-10 images in range [-1, 1] (tanh generator outputs)
    train, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    train -= 1.0
    train_iter = iterators.SerialIterator(train, batch_size)

    z_iter = RandomNoiseIterator(GaussianNoiseGenerator(0, 1, args.nz),
                                 batch_size)

    optimizer_generator = optimizers.RMSprop(lr=0.00005)
    optimizer_critic = optimizers.RMSprop(lr=0.00005)
    optimizer_generator.setup(Generator())
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=gpu)

    trainer = training.Trainer(updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'critic/loss',
            'critic/loss/real', 'critic/loss/fake', 'generator/loss']))
    trainer.run() 
Developer: hvy, Project: chainer-wasserstein-gan, Lines of code: 35, Source: train.py

Example 13: _train_trainer

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def _train_trainer(self, examples):
        """Training with chainer trainer module"""
        train_iter = SerialIterator(examples, args.batch_size)
        optimizer = optimizers.Adam(alpha=args.lr)
        optimizer.setup(self.nnet)

        def loss_func(boards, target_pis, target_vs):
            out_pi, out_v = self.nnet(boards)
            l_pi = self.loss_pi(target_pis, out_pi)
            l_v = self.loss_v(target_vs, out_v)
            total_loss = l_pi + l_v
            chainer.reporter.report({
                'loss': total_loss,
                'loss_pi': l_pi,
                'loss_v': l_v,
            }, observer=self.nnet)
            return total_loss

        updater = training.StandardUpdater(
            train_iter, optimizer, device=args.device, loss_func=loss_func, converter=converter)
        # Set up the trainer.
        trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.out)
        # trainer.extend(extensions.snapshot(), trigger=(args.epochs, 'epoch'))
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.PrintReport([
            'epoch', 'main/loss', 'main/loss_pi', 'main/loss_v', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=10))
        trainer.run() 
Developer: suragnair, Project: alpha-zero-general, Lines of code: 30, Source: NNet.py

Example 14: main

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def main():
  # Training settings
  args = get_args()

  # Set up a neural network to train
  model = L.Classifier(Net())

  if args.gpu >= 0:
    # Make a specified GPU current
    chainer.backends.cuda.get_device_from_id(args.gpu).use()
    model.to_gpu() # Copy the model to the GPU

  # Setup an optimizer
  optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=args.momentum)
  optimizer.setup(model)

  # Load the MNIST dataset
  train, test = chainer.datasets.get_mnist(ndim=3)
  train_iter = chainer.iterators.SerialIterator(train, args.batch_size)
  test_iter = chainer.iterators.SerialIterator(test, args.test_batch_size,
                                               repeat=False, shuffle=False)

  # Set up a trainer
  updater = training.updaters.StandardUpdater(
      train_iter, optimizer, device=args.gpu)
  trainer = training.Trainer(updater, (args.epochs, 'epoch'))

  # Evaluate the model with the test dataset for each epoch
  trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

  # Write a log of evaluation statistics for each epoch
  trainer.extend(extensions.LogReport())

  # Print selected entries of the log to stdout
  trainer.extend(extensions.PrintReport(
      ['epoch', 'main/loss', 'validation/main/loss',
       'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

  # Send selected entries of the log to CMLE HP tuning system
  trainer.extend(
    HpReport(hp_metric_val='validation/main/loss', hp_metric_tag='my_loss'))

  if args.resume:
    # Resume from a snapshot
    tmp_model_file = os.path.join('/tmp', MODEL_FILE_NAME)
    if not os.path.exists(tmp_model_file):
      subprocess.check_call([
        'gsutil', 'cp', os.path.join(args.model_dir, MODEL_FILE_NAME),
        tmp_model_file])
    if os.path.exists(tmp_model_file):
      chainer.serializers.load_npz(tmp_model_file, trainer)
  
  trainer.run()

  if args.model_dir:
    tmp_model_file = os.path.join('/tmp', MODEL_FILE_NAME)
    serializers.save_npz(tmp_model_file, model)
    subprocess.check_call([
        'gsutil', 'cp', tmp_model_file,
        os.path.join(args.model_dir, MODEL_FILE_NAME)]) 
Developer: GoogleCloudPlatform, Project: cloudml-samples, Lines of code: 62, Source: mnist.py

Example 15: prepare_trainer

# Required module import: from chainer import training [as alias]
# Or: from chainer.training import Trainer [as alias]
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_data,
                    val_data,
                    logging_dir_path,
                    use_gpus):
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise Exception("Unsupported optimizer: {}".format(optimizer_name))
    optimizer.setup(net)

    # devices = tuple(range(num_gpus)) if num_gpus > 0 else (-1, )
    devices = (0,) if use_gpus else (-1,)

    updater = training.updaters.StandardUpdater(
        iterator=train_data["iterator"],
        optimizer=optimizer,
        device=devices[0])
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, "epoch"),
        out=logging_dir_path)

    val_interval = 100000, "iteration"
    log_interval = 1000, "iteration"

    trainer.extend(
        extension=extensions.Evaluator(
            iterator=val_data["iterator"],
            target=net,
            device=devices[0]),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph("main/loss"))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            "model_iter_{.updater.iteration}"),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            "epoch", "iteration", "main/loss", "validation/main/loss", "main/accuracy", "validation/main/accuracy",
            "lr"]),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer 
Developer: osmr, Project: imgclsmob, Lines of code: 57, Source: train_ch.py


Note: The chainer.training.Trainer method examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other open-source code and documentation platforms. The code snippets are drawn from open-source projects contributed by various developers; copyright of the source code belongs to the original authors. For distribution and use, please follow the corresponding project's license. Do not reproduce without permission.