当前位置: 首页>>代码示例>>Python>>正文


Python extensions.dump_graph方法代码示例

本文整理汇总了Python中chainer.training.extensions.dump_graph方法的典型用法代码示例。如果您正苦于以下问题:Python extensions.dump_graph方法的具体用法?Python extensions.dump_graph怎么用?Python extensions.dump_graph使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在chainer.training.extensions的用法示例。


在下文中一共展示了extensions.dump_graph方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: run_training

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def run_training(args, model):
    """Attach the standard reporting/snapshot extensions and run training.

    The trainer itself is produced by ``create_trainer(args, model)``; this
    function only decorates it with extensions, optionally resumes it from a
    snapshot, and then runs it to completion.

    Args:
        args: Parsed command-line namespace. Attributes read directly here:
            ``epoch``, ``frequency``, ``plot`` and ``resume``.
        model: Model forwarded unchanged to ``create_trainer``.
    """
    trainer = create_trainer(args, model)

    # Dump the computational graph rooted at the 'main/loss' variable. "main"
    # is the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Snapshot cadence: a frequency of -1 means "every args.epoch epochs"
    # (presumably once, at the end of training); any other value is clamped
    # to at least 1 epoch.
    if args.frequency == -1:
        snapshot_every = args.epoch
    else:
        snapshot_every = max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(snapshot_every, 'epoch'))

    # Per-epoch log of the reported statistics.
    trainer.extend(extensions.LogReport())

    # Optionally render loss/accuracy curves into the result directory.
    if args.plot and extensions.PlotReport.available():
        plot_specs = (
            (['main/loss', 'validation/main/loss'], 'loss.png'),
            (['main/accuracy', 'validation/main/accuracy'], 'accuracy.png'),
        )
        for plotted_keys, image_name in plot_specs:
            trainer.extend(
                extensions.PlotReport(plotted_keys, 'epoch',
                                      file_name=image_name))

    # "main/..." entries come from the main optimizer's target link and
    # "validation/..." entries from the default-named Evaluator; everything
    # except 'epoch' and 'elapsed_time' is reported by the Classifier link.
    printed_columns = [
        'epoch',
        'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy',
        'elapsed_time',
    ]
    trainer.extend(extensions.PrintReport(printed_columns))

    # Console progress bar.
    trainer.extend(extensions.ProgressBar())

    # Restore trainer state from a previously saved snapshot, if requested.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
开发者ID:pfnet-research,项目名称:chainer-compiler,代码行数:44,代码来源:gen_resnet50.py

示例2: prepare_trainer

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_data,
                    val_data,
                    logging_dir_path,
                    use_gpus):
    """Create a Chainer trainer for ``net`` with evaluation, logging and snapshots.

    Args:
        net: Link to optimize. It is also the Evaluator target and the object
            saved by ``snapshot_object``.
        optimizer_name: ``"sgd"`` (MomentumSGD) or ``"nag"`` (NesterovAG).
        lr: Learning rate passed to the optimizer.
        momentum: Momentum passed to the optimizer.
        num_epochs: Training stops after this many epochs.
        train_data: Mapping whose ``"iterator"`` entry is the training iterator.
        val_data: Mapping whose ``"iterator"`` entry is the validation iterator.
        logging_dir_path: Output directory of the trainer.
        use_gpus: If truthy, run on device 0; otherwise on the CPU (device -1).

    Returns:
        The configured ``training.Trainer``. It has not been started yet.

    Raises:
        ValueError: If ``optimizer_name`` is not ``"sgd"`` or ``"nag"``.
    """
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers in callers keep working.
        raise ValueError("Unsupported optimizer: {}".format(optimizer_name))
    optimizer.setup(net)

    # Only single-device training is supported: GPU 0 or the CPU.
    device = 0 if use_gpus else -1

    updater = training.updaters.StandardUpdater(
        iterator=train_data["iterator"],
        optimizer=optimizer,
        device=device)
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, "epoch"),
        out=logging_dir_path)

    val_interval = 100000, "iteration"
    log_interval = 1000, "iteration"

    trainer.extend(
        extension=extensions.Evaluator(
            iterator=val_data["iterator"],
            target=net,
            device=device),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph("main/loss"))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            "model_iter_{.updater.iteration}"),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            "epoch", "iteration", "main/loss", "validation/main/loss", "main/accuracy", "validation/main/accuracy",
            "lr"]),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer
开发者ID:osmr,项目名称:imgclsmob,代码行数:57,代码来源:train_ch.py

示例3: prepare_trainer

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0):
    """Create a Chainer trainer for ``net`` with evaluation, logging and snapshots.

    Args:
        net: Link to optimize. It is also the Evaluator target and the object
            saved by ``snapshot_object``.
        optimizer_name: ``'sgd'`` (MomentumSGD) or ``'nag'`` (NesterovAG).
        lr: Learning rate passed to the optimizer.
        momentum: Momentum passed to the optimizer.
        num_epochs: Training stops after this many epochs.
        train_iter: Iterator over training batches.
        val_iter: Iterator over validation batches.
        logging_dir_path: Output directory of the trainer.
        num_gpus: Any positive value selects device 0. Only a single GPU is
            used even when ``num_gpus > 1``. Zero means the CPU (device -1).

    Returns:
        The configured ``training.Trainer``. It has not been started yet.

    Raises:
        ValueError: If ``optimizer_name`` is not ``'sgd'`` or ``'nag'``.
    """
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        # ValueError subclasses Exception, so existing ``except Exception``
        # handlers in callers keep working.
        raise ValueError('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # Only single-device training is supported: GPU 0 or the CPU.
    device = 0 if num_gpus > 0 else -1

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=device)
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    val_interval = 100000, 'iteration'
    log_interval = 1000, 'iteration'

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=device),
        trigger=val_interval)
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy',
            'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    return trainer
开发者ID:osmr,项目名称:imgclsmob,代码行数:57,代码来源:train_ch_cifar.py

示例4: get_trainer

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def get_trainer(net, updater, log_dir, print_fields, curriculum=None, extra_extensions=(), epochs=10, snapshot_interval=20000, print_interval=100, postprocess=None, do_logging=True, model_files=()):
    """Build a trainer around ``updater`` with graph dump, snapshots and reporting.

    Args:
        net: Not referenced in this function; kept for interface compatibility.
        updater: Updater driving the training loop.
        log_dir: Output directory of the trainer.
        print_fields: Keys written by ``Logger`` and shown by ``PrintReport``.
        curriculum: If ``None``, stop after ``epochs`` epochs. Otherwise stop
            via ``EarlyStopIntervalTrigger(epochs, 'epoch', curriculum)``.
        extra_extensions: Iterable of ``(extension, trigger)`` pairs that are
            registered after the built-in extensions.
        epochs: Epoch count handed to the stop trigger.
        snapshot_interval: Also snapshot every this many iterations
            (iteration 0 excluded), in addition to every epoch boundary.
        print_interval: Iteration interval for learning-rate observation and
            for the ``Logger``.
        postprocess: Forwarded to ``Logger``.
        do_logging: If false, neither ``Logger`` nor ``PrintReport`` is added.
        model_files: Forwarded to ``Logger``.

    Returns:
        The configured ``chainer.training.Trainer``.
    """
    if curriculum is None:
        stop_trigger = (epochs, 'epoch')
    else:
        stop_trigger = EarlyStopIntervalTrigger(epochs, 'epoch', curriculum)
    trainer = chainer.training.Trainer(updater, stop_trigger, out=log_dir)

    print_trigger = (print_interval, 'iteration')

    # Write the computational graph rooted at 'main/loss'.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Record the learning rate. The trigger is attached to the extension
    # object itself rather than passed to ``extend``.
    lr_observer = chainer.training.extensions.observe_lr()
    lr_observer.trigger = print_trigger
    trainer.extend(lr_observer)

    def _should_snapshot(current):
        # Fire on every new epoch, and additionally every
        # ``snapshot_interval`` iterations (never at iteration 0).
        state = current.updater
        return state.is_new_epoch or (
            state.iteration > 0
            and state.iteration % snapshot_interval == 0)

    trainer.extend(
        extensions.snapshot(filename="trainer_snapshot"),
        trigger=_should_snapshot,
    )

    if do_logging:
        # Persist the selected statistics, then echo them to the console.
        # PrintReport reads from the extension registered as 'Logger'.
        trainer.extend(Logger(model_files, log_dir, keys=print_fields, trigger=print_trigger, postprocess=postprocess))
        trainer.extend(extensions.PrintReport(print_fields, log_report='Logger'))

    trainer.extend(extensions.ProgressBar(update_interval=1))

    for extension, extension_trigger in extra_extensions:
        trainer.extend(extension, trigger=extension_trigger)

    return trainer
开发者ID:Bartzi,项目名称:see,代码行数:49,代码来源:train_utils.py

示例5: train_one_epoch

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def train_one_epoch(model, train_data, lr, gpu, batchsize, out):
    """Train ``model`` for exactly one epoch, validating on VOC 'val' at its end.

    Args:
        model: Segmentation model. It must expose ``mean``, which is used by
            ``SimpleDoesItTransform``. It is wrapped in
            ``PixelwiseSoftmaxClassifier`` for training and evaluated directly.
        train_data: Dataset whose examples provide ``'img'`` and
            ``'label_map'``. It is re-wrapped with ``TransformDataset``.
        lr: Learning rate for MomentumSGD (momentum 0.9, weight decay 1e-4).
        gpu: GPU id. A negative value keeps everything on the CPU.
        batchsize: Training batch size.
        out: Trainer output directory.
    """
    classifier = PixelwiseSoftmaxClassifier(model)
    if gpu >= 0:
        # Select the requested GPU and move the wrapped model onto it.
        chainer.cuda.get_device_from_id(gpu).use()
        classifier.to_gpu()

    log_trigger = (0.1, 'epoch')
    # The same one-epoch trigger is used both for validation and to stop
    # training.
    epoch_trigger = (1, 'epoch')

    # Datasets
    augmented = TransformDataset(
        train_data, ('img', 'label_map'), SimpleDoesItTransform(model.mean))
    val_dataset = VOCSemanticSegmentationWithBboxDataset(
        split='val').slice[:, ['img', 'label_map']]

    # Iterators. Validation runs one image at a time, exactly once, unshuffled.
    train_iter = iterators.MultiprocessIterator(augmented, batchsize)
    val_iter = iterators.MultiprocessIterator(
        val_dataset, 1, shuffle=False, repeat=False, shared_mem=100000000)

    # Optimizer
    optimizer = optimizers.MomentumSGD(lr=lr, momentum=0.9)
    optimizer.setup(classifier)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(rate=0.0001))

    # Updater and trainer
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=gpu)
    trainer = training.Trainer(updater, epoch_trigger, out=out)

    trainer.extend(extensions.LogReport(trigger=log_trigger))
    trainer.extend(extensions.observe_lr(), trigger=log_trigger)
    trainer.extend(extensions.dump_graph('main/loss'))

    if extensions.PlotReport.available():
        for plotted_key, image_name in (('main/loss', 'loss.png'),
                                        ('validation/main/miou', 'miou.png')):
            trainer.extend(extensions.PlotReport(
                [plotted_key], x_key='iteration', file_name=image_name))

    # NOTE(review): snapshot_object serializes with save_npz by default, so
    # despite the '.npy' name the file is presumably in NPZ format - confirm.
    trainer.extend(
        extensions.snapshot_object(model, filename='snapshot.npy'),
        trigger=epoch_trigger)

    report_columns = [
        'epoch', 'iteration', 'elapsed_time', 'lr',
        'main/loss', 'validation/main/miou',
        'validation/main/mean_class_accuracy',
        'validation/main/pixel_accuracy',
    ]
    trainer.extend(extensions.PrintReport(report_columns), trigger=log_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    evaluator = SemanticSegmentationEvaluator(
        val_iter, model, voc_semantic_segmentation_label_names)
    trainer.extend(evaluator, trigger=epoch_trigger)

    trainer.run()
开发者ID:chainer,项目名称:models,代码行数:63,代码来源:train.py

示例6: register_extensions

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def register_extensions(trainer, model, test_iter, args):
    """Register visualization, logging, snapshot, ChainerUI and plot extensions.

    Args:
        trainer: Trainer to decorate.
        model: Model evaluated by ``PortraitVisEvaluator`` and saved as
            ``model_best`` whenever the mode-specific best trigger fires.
        test_iter: Iterator fed to ``PortraitVisEvaluator``.
        args: Namespace. Attributes read here: ``mode``, ``gpus`` (element 0
            is the evaluator device) and ``out`` (ChainerUI args directory).

    Raises:
        ValueError: If ``args.mode`` starts with neither ``'seg'`` nor
            ``'mat'``. Nothing is registered on ``trainer`` in that case.
    """
    if args.mode.startswith('seg'):
        # Keep the model with the highest validation accuracy
        # (compare(best, new) is true when new > best).
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/accuracy', lambda a, b: a < b, (1, 'epoch'))
    elif args.mode.startswith('mat'):
        # Keep the model with the lowest validation loss
        # (compare(best, new) is true when new < best).
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/loss', lambda a, b: a > b, (1, 'epoch'))
    else:
        logger.error('Invalid training mode')
        # Fail fast. Previously execution continued with ``best_trigger``
        # unbound and crashed later with UnboundLocalError, after several
        # extensions had already been registered.
        raise ValueError('Invalid training mode: {}'.format(args.mode))

    # Per-epoch visualization of predictions on the test iterator.
    trainer.extend(
        custom_extensions.PortraitVisEvaluator(
            test_iter, model, device=args.gpus[0],
            converter=select_converter(args.mode),
            filename='vis_epoch={epoch}_idx={index}.jpg',
            mode=args.mode
        ), trigger=(1, 'epoch'))

    # Basic extensions
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport(trigger=(200, 'iteration')))
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
         'validation/main/accuracy', 'lr', 'elapsed_time']))
    trainer.extend(extensions.observe_lr(), trigger=(200, 'iteration'))

    # Snapshots: full trainer state every 5 epochs, plus the best model so far.
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch_{.updater.epoch}'
    ), trigger=(5, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        model, filename='model_best'
    ), trigger=best_trigger)

    # ChainerUI extensions
    trainer.extend(chainerui.extensions.CommandsExtension())
    chainerui.utils.save_args(args, args.out)

    # Plotting extensions
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(
                ['main/loss', 'validation/main/loss'],
                'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
开发者ID:takiyu,项目名称:portrait_matting,代码行数:54,代码来源:train.py

示例7: create_trainer

# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def create_trainer(
        config,
        project_path,
        updater,
        model,
        eval_func,
        iterator_test,
        iterator_train_varidation,
        loss_names,
        converter=chainer.dataset.convert.concat_examples,
):
    # type: (TrainConfig, str, any, typing.Dict, any, any, any, any, any) -> any
    """Build a trainer with periodic model snapshots, two evaluators and logging.

    No stop trigger is passed to ``chainer.training.Trainer``, so stopping is
    left to Chainer's default for ``stop_trigger=None``.

    Args:
        config: Training config. Fields read here: ``gpu``, ``log_iteration``
            and ``save_iteration``.
        project_path: Trainer output directory.
        updater: Updater driving training.
        model: Mapping of models. ``model['main']`` is the one snapshotted.
            The whole mapping is handed to the evaluators as ``target``.
        eval_func: Evaluation function forwarded to ``NoVariableEvaluator``.
        iterator_test: Iterator for the ``'eval/test'`` evaluator.
        iterator_train_varidation: Iterator for the ``'eval/train'`` evaluator.
        loss_names: Loss keys reported under ``main/``. The first one roots
            the dumped graph, so this must be non-empty.
        converter: Batch converter forwarded to both evaluators.

    Returns:
        The configured trainer.
    """
    def _build_evaluator(source_iterator):
        # Both evaluators differ only in the iterator they consume.
        return utility.chainer_utility.NoVariableEvaluator(
            source_iterator,
            target=model,
            converter=converter,
            eval_func=eval_func,
            device=config.gpu,
        )

    trainer = chainer.training.Trainer(updater, out=project_path)

    log_trigger = (config.log_iteration, 'iteration')
    save_trigger = (config.save_iteration, 'iteration')

    test_evaluator_name = 'eval/test'
    train_evaluator_name = 'eval/train'

    # Periodically save the main model, named by the current iteration.
    trainer.extend(
        extensions.snapshot_object(model['main'], '{.updater.iteration}.model'),
        trigger=save_trigger,
    )

    first_loss_key = 'main/' + loss_names[0]
    trainer.extend(extensions.dump_graph(first_loss_key, out_name='main.dot'))

    trainer.extend(_build_evaluator(iterator_test),
                   name=test_evaluator_name, trigger=log_trigger)
    trainer.extend(_build_evaluator(iterator_train_varidation),
                   name=train_evaluator_name, trigger=log_trigger)

    # One column per (reporter prefix, loss) pair: the raw training values
    # first, then the test evaluator, then the train evaluator.
    prefixes = ('', test_evaluator_name + '/', train_evaluator_name + '/')
    report_target = [
        prefix + 'main/' + loss_name
        for prefix in prefixes
        for loss_name in loss_names
    ]

    trainer.extend(extensions.LogReport(trigger=log_trigger, log_name="log.txt"))
    trainer.extend(extensions.PrintReport(report_target))

    return trainer
开发者ID:DwangoMediaVillage,项目名称:Comicolorization,代码行数:49,代码来源:trainer.py


注:本文中的chainer.training.extensions.dump_graph方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。