本文整理汇总了Python中chainer.training.extensions.dump_graph方法的典型用法代码示例。如果您正苦于以下问题：Python extensions.dump_graph方法的具体用法？Python extensions.dump_graph怎么用？Python extensions.dump_graph使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块chainer.training.extensions的用法示例。
在下文中一共展示了extensions.dump_graph方法的7个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: run_training
# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def run_training(args, model):
    """Build a trainer for ``model``, attach the standard extensions and run it.

    Args:
        args: Parsed command-line options. This function reads ``args.epoch``,
            ``args.frequency``, ``args.plot`` and ``args.resume``; the rest is
            consumed by ``create_trainer``.
        model: Model forwarded to ``create_trainer`` together with ``args``.
    """
    trainer = create_trainer(args, model)

    # Dump the computational graph reachable from the loss variable (the
    # extension only dumps on its first invocation). "main" refers to the
    # target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Snapshot interval in epochs: -1 means "every args.epoch epochs",
    # any other value is clamped to at least one epoch.
    if args.frequency == -1:
        snapshot_every = args.epoch
    else:
        snapshot_every = max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(snapshot_every, 'epoch'))

    # Accumulate reported statistics into the log (LogReport's default
    # trigger is once per epoch).
    trainer.extend(extensions.LogReport())

    # Save loss/accuracy plot images into the result directory, only when
    # requested and when PlotReport says plotting is available.
    if args.plot and extensions.PlotReport.available():
        plot_specs = (
            (['main/loss', 'validation/main/loss'], 'loss.png'),
            (['main/accuracy', 'validation/main/accuracy'], 'accuracy.png'),
        )
        for plotted_keys, image_name in plot_specs:
            trainer.extend(
                extensions.PlotReport(plotted_keys, 'epoch',
                                      file_name=image_name))

    # Print selected log entries to stdout. "main" is again the target link
    # of the "main" optimizer and "validation" is the default name of the
    # Evaluator extension; the loss/accuracy entries are reported by the
    # model link (presumably a Classifier) from the updater or evaluator.
    report_columns = [
        'epoch',
        'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy',
        'elapsed_time',
    ]
    trainer.extend(extensions.PrintReport(report_columns))

    # Progress bar on stdout.
    trainer.extend(extensions.ProgressBar())

    # Optionally restore the full trainer state from a snapshot file.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
示例2: prepare_trainer
# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_data,
                    val_data,
                    logging_dir_path,
                    use_gpus,
                    val_interval=(100000, "iteration"),
                    log_interval=(1000, "iteration")):
    """Create a Chainer ``Trainer`` for ``net`` with evaluation, snapshots and logging.

    Args:
        net: Link to optimize. It is also the Evaluator target and the object
            saved by ``snapshot_object``.
        optimizer_name: ``"sgd"`` (MomentumSGD) or ``"nag"`` (NesterovAG).
        lr: Learning rate passed to the optimizer.
        momentum: Momentum coefficient passed to the optimizer.
        num_epochs: Training stops after this many epochs.
        train_data: Mapping whose ``"iterator"`` entry is the training iterator.
        val_data: Mapping whose ``"iterator"`` entry is the validation iterator.
        logging_dir_path: Output directory of the trainer.
        use_gpus: If truthy, run on GPU device 0; otherwise on the CPU (-1).
        val_interval: Trigger for evaluation and for both snapshot extensions.
            Defaults to every 100000 iterations (the previously fixed value);
            lower it for short runs, otherwise they never evaluate or snapshot.
        log_interval: Trigger for LogReport, observe_lr and PrintReport.
            Defaults to every 1000 iterations.

    Returns:
        The configured ``training.Trainer`` (not yet run).

    Raises:
        ValueError: If ``optimizer_name`` is not supported.
    """
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise ValueError("Unsupported optimizer: {}".format(optimizer_name))
    optimizer.setup(net)

    # Only a single device is supported: GPU 0 when requested, else CPU (-1).
    device = 0 if use_gpus else -1

    updater = training.updaters.StandardUpdater(
        iterator=train_data["iterator"],
        optimizer=optimizer,
        device=device)
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, "epoch"),
        out=logging_dir_path)

    trainer.extend(
        extension=extensions.Evaluator(
            iterator=val_data["iterator"],
            target=net,
            device=device),
        trigger=val_interval)
    # Dump the computational graph reachable from the loss (first call only).
    trainer.extend(extensions.dump_graph("main/loss"))
    # Full trainer snapshot plus a standalone copy of the model weights.
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            "model_iter_{.updater.iteration}"),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            "epoch", "iteration", "main/loss", "validation/main/loss",
            "main/accuracy", "validation/main/accuracy", "lr"]),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    return trainer
示例3: prepare_trainer
# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def prepare_trainer(net,
                    optimizer_name,
                    lr,
                    momentum,
                    num_epochs,
                    train_iter,
                    val_iter,
                    logging_dir_path,
                    num_gpus=0,
                    val_interval=(100000, 'iteration'),
                    log_interval=(1000, 'iteration')):
    """Create a Chainer ``Trainer`` for ``net`` with evaluation, snapshots and logging.

    Args:
        net: Link to optimize. It is also the Evaluator target and the object
            saved by ``snapshot_object``.
        optimizer_name: ``'sgd'`` (MomentumSGD) or ``'nag'`` (NesterovAG).
        lr: Learning rate passed to the optimizer.
        momentum: Momentum coefficient passed to the optimizer.
        num_epochs: Training stops after this many epochs.
        train_iter: Training iterator.
        val_iter: Validation iterator.
        logging_dir_path: Output directory of the trainer.
        num_gpus: When greater than 0, training runs on GPU device 0 only
            (multi-GPU is not implemented); otherwise on the CPU (-1).
        val_interval: Trigger for evaluation and for both snapshot extensions.
            Defaults to every 100000 iterations (the previously fixed value);
            lower it for short runs, otherwise they never evaluate or snapshot.
        log_interval: Trigger for LogReport, observe_lr and PrintReport.
            Defaults to every 1000 iterations.

    Returns:
        The configured ``training.Trainer`` (not yet run).

    Raises:
        ValueError: If ``optimizer_name`` is not supported.
    """
    if optimizer_name == "sgd":
        optimizer = chainer.optimizers.MomentumSGD(lr=lr, momentum=momentum)
    elif optimizer_name == "nag":
        optimizer = chainer.optimizers.NesterovAG(lr=lr, momentum=momentum)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(optimizer_name))
    optimizer.setup(net)

    # Only a single device is used even if several GPUs are requested.
    device = 0 if num_gpus > 0 else -1

    updater = training.updaters.StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        device=device)
    trainer = training.Trainer(
        updater=updater,
        stop_trigger=(num_epochs, 'epoch'),
        out=logging_dir_path)

    trainer.extend(
        extension=extensions.Evaluator(
            val_iter,
            net,
            device=device),
        trigger=val_interval)
    # Dump the computational graph reachable from the loss (first call only).
    trainer.extend(extensions.dump_graph('main/loss'))
    # Full trainer snapshot plus a standalone copy of the model weights.
    trainer.extend(extensions.snapshot(), trigger=val_interval)
    trainer.extend(
        extensions.snapshot_object(
            net,
            'model_iter_{.updater.iteration}'),
        trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'main/loss', 'validation/main/loss',
            'main/accuracy', 'validation/main/accuracy', 'lr']),
        trigger=log_interval)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    return trainer
示例4: get_trainer
# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def get_trainer(net, updater, log_dir, print_fields, curriculum=None,
                extra_extensions=(), epochs=10, snapshot_interval=20000,
                print_interval=100, postprocess=None, do_logging=True,
                model_files=()):
    """Create a trainer with graph dump, lr observation, snapshots and logging.

    Args:
        net: Not referenced by this function; kept for interface compatibility.
        updater: Updater that drives training.
        log_dir: Output directory of the trainer; also forwarded to ``Logger``.
        print_fields: Keys recorded by ``Logger`` and printed by ``PrintReport``.
        curriculum: When not None, the stop trigger is an
            ``EarlyStopIntervalTrigger(epochs, 'epoch', curriculum)`` instead of
            a plain ``(epochs, 'epoch')`` interval.
        extra_extensions: Iterable of ``(extension, trigger)`` pairs that are
            registered after the built-in extensions.
        epochs: Epoch count used by the stop trigger.
        snapshot_interval: Besides every epoch boundary, the trainer is also
            snapshotted every this many iterations.
        print_interval: Iteration interval for lr observation and ``Logger``.
        postprocess: Forwarded to ``Logger``.
        do_logging: When False, neither ``Logger`` nor ``PrintReport`` is added.
        model_files: Forwarded to ``Logger``.

    Returns:
        The configured ``chainer.training.Trainer`` (not yet run).
    """
    if curriculum is None:
        stop_trigger = (epochs, 'epoch')
    else:
        stop_trigger = EarlyStopIntervalTrigger(epochs, 'epoch', curriculum)
    trainer = chainer.training.Trainer(updater, stop_trigger, out=log_dir)

    # Dump the computational graph reachable from the loss (first call only).
    trainer.extend(extensions.dump_graph('main/loss'))

    # Observe the learning rate. The trigger is stored on the extension
    # object, so trainer.extend (called without trigger=) picks it up.
    lr_observer = chainer.training.extensions.observe_lr()
    lr_observer.trigger = (print_interval, 'iteration')
    trainer.extend(lr_observer)

    def _should_snapshot(t):
        # Fire at every epoch boundary and every `snapshot_interval`
        # iterations (skipping iteration 0).
        current = t.updater
        if current.is_new_epoch:
            return True
        return (current.iteration > 0
                and current.iteration % snapshot_interval == 0)

    trainer.extend(
        extensions.snapshot(filename="trainer_snapshot"),
        trigger=_should_snapshot,
    )

    if do_logging:
        # Write all selected statistics to a file.
        trainer.extend(Logger(model_files, log_dir, keys=print_fields,
                              trigger=(print_interval, 'iteration'),
                              postprocess=postprocess))
        # PrintReport reads its entries from the extension named 'Logger'.
        trainer.extend(extensions.PrintReport(
            print_fields,
            log_report='Logger',
        ))

    trainer.extend(extensions.ProgressBar(update_interval=1))

    for extension, extension_trigger in extra_extensions:
        trainer.extend(extension, trigger=extension_trigger)
    return trainer
示例5: train_one_epoch
# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def train_one_epoch(model, train_data, lr, gpu, batchsize, out):
    """Train ``model`` for exactly one epoch and validate it on VOC ``'val'``.

    ``model`` is wrapped in ``PixelwiseSoftmaxClassifier`` and optimized with
    MomentumSGD (momentum 0.9) plus WeightDecay(rate=0.0001). At the end of
    the epoch it is evaluated with ``SemanticSegmentationEvaluator`` and the
    unwrapped ``model`` is saved via ``snapshot_object``.

    Args:
        model: Segmentation link; ``model.mean`` is passed to
            ``SimpleDoesItTransform``.
        train_data: Training dataset, transformed by ``SimpleDoesItTransform``
            into ``('img', 'label_map')`` examples.
        lr: Learning rate for MomentumSGD.
        gpu: GPU id to use; a negative value keeps everything on the CPU.
        batchsize: Training minibatch size.
        out: Trainer output directory.
    """
    classifier = PixelwiseSoftmaxClassifier(model)
    if gpu >= 0:
        # Make the requested GPU current, then move the parameters onto it.
        chainer.cuda.get_device_from_id(gpu).use()
        classifier.to_gpu()

    # Logging runs ten times per epoch; validation and the stop/snapshot
    # trigger coincide at the end of the single epoch.
    log_trigger = (0.1, 'epoch')
    one_epoch = (1, 'epoch')

    transformed_train = TransformDataset(
        train_data, ('img', 'label_map'), SimpleDoesItTransform(model.mean))
    val_dataset = VOCSemanticSegmentationWithBboxDataset(
        split='val').slice[:, ['img', 'label_map']]

    train_iter = iterators.MultiprocessIterator(transformed_train, batchsize)
    val_iter = iterators.MultiprocessIterator(
        val_dataset, 1, shuffle=False, repeat=False, shared_mem=100000000)

    optimizer = optimizers.MomentumSGD(lr=lr, momentum=0.9)
    optimizer.setup(classifier)
    optimizer.add_hook(chainer.optimizer_hooks.WeightDecay(rate=0.0001))

    updater = training.updaters.StandardUpdater(
        train_iter, optimizer, device=gpu)
    trainer = training.Trainer(updater, one_epoch, out=out)

    trainer.extend(extensions.LogReport(trigger=log_trigger))
    trainer.extend(extensions.observe_lr(), trigger=log_trigger)
    # Dump the computational graph reachable from the loss (first call only).
    trainer.extend(extensions.dump_graph('main/loss'))
    if extensions.PlotReport.available():
        for plotted_key, image_name in (('main/loss', 'loss.png'),
                                        ('validation/main/miou', 'miou.png')):
            trainer.extend(extensions.PlotReport(
                [plotted_key], x_key='iteration', file_name=image_name))
    # NOTE(review): snapshot_object's default serializer writes NPZ data, so
    # the '.npy' suffix is misleading; kept because callers may load this name.
    trainer.extend(
        extensions.snapshot_object(model, filename='snapshot.npy'),
        trigger=one_epoch)
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'iteration', 'elapsed_time', 'lr',
            'main/loss', 'validation/main/miou',
            'validation/main/mean_class_accuracy',
            'validation/main/pixel_accuracy',
        ]),
        trigger=log_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=10))
    trainer.extend(
        SemanticSegmentationEvaluator(
            val_iter, model, voc_semantic_segmentation_label_names),
        trigger=one_epoch)

    trainer.run()
示例6: register_extensions
# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def register_extensions(trainer, model, test_iter, args):
    """Register evaluation, logging, snapshot and plotting extensions on ``trainer``.

    Args:
        trainer: Trainer to extend.
        model: Model that is visualized/evaluated and snapshotted as
            ``model_best``.
        test_iter: Iterator consumed by ``PortraitVisEvaluator``.
        args: Parsed options; reads ``mode``, ``gpus`` and ``out`` (and the
            whole object is saved with ``chainerui.utils.save_args``).

    Raises:
        ValueError: If ``args.mode`` starts with neither ``'seg'`` nor ``'mat'``.
    """
    # BestValueTrigger calls compare(best_so_far, new_value) and fires when
    # it returns True, i.e. when the new value is better.
    if args.mode.startswith('seg'):
        # Segmentation: keep the model with the highest validation accuracy.
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/accuracy', lambda a, b: a < b, (1, 'epoch'))
    elif args.mode.startswith('mat'):
        # Matting: keep the model with the lowest validation loss.
        best_trigger = training.triggers.BestValueTrigger(
            'validation/main/loss', lambda a, b: a > b, (1, 'epoch'))
    else:
        # Without a valid mode there is no best-model trigger; fail here
        # instead of with a NameError when `best_trigger` is used below,
        # after other extensions and side effects have already happened.
        logger.error('Invalid training mode')
        raise ValueError('Invalid training mode: {!r}'.format(args.mode))

    # Segmentation visualization/evaluation once per epoch.
    trainer.extend(
        custom_extensions.PortraitVisEvaluator(
            test_iter, model, device=args.gpus[0],
            converter=select_converter(args.mode),
            filename='vis_epoch={epoch}_idx={index}.jpg',
            mode=args.mode
        ), trigger=(1, 'epoch'))

    # Basic extensions.
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.LogReport(trigger=(200, 'iteration')))
    trainer.extend(extensions.ProgressBar(update_interval=20))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
         'validation/main/accuracy', 'lr', 'elapsed_time']))
    trainer.extend(extensions.observe_lr(), trigger=(200, 'iteration'))

    # Snapshots: the full trainer every 5 epochs, the best model on improvement.
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch_{.updater.epoch}'
    ), trigger=(5, 'epoch'))
    trainer.extend(extensions.snapshot_object(
        model, filename='model_best'
    ), trigger=best_trigger)

    # ChainerUI extensions.
    trainer.extend(chainerui.extensions.CommandsExtension())
    chainerui.utils.save_args(args, args.out)

    # Plotting extensions (only if PlotReport reports it can plot).
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(
                ['main/loss', 'validation/main/loss'],
                'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))
示例7: create_trainer
# 需要导入模块: from chainer.training import extensions [as 别名]
# 或者: from chainer.training.extensions import dump_graph [as 别名]
def create_trainer(
        config,
        project_path,
        updater,
        model,
        eval_func,
        iterator_test,
        iterator_train_varidation,
        loss_names,
        converter=chainer.dataset.convert.concat_examples,
):
    # type: (TrainConfig, str, any, typing.Dict, any, any, any, any, any) -> any
    """Build a trainer with model snapshots, a graph dump, two evaluators and logging.

    Args:
        config: Training configuration; reads ``gpu``, ``log_iteration`` and
            ``save_iteration``.
        project_path: Output directory of the trainer.
        updater: Updater that drives training.
        model: Mapping of models. ``model['main']`` is snapshotted; the whole
            mapping is passed as the evaluators' ``target``.
        eval_func: Evaluation function forwarded to ``NoVariableEvaluator``.
        iterator_test: Iterator for the evaluator named ``'eval/test'``.
        iterator_train_varidation: Iterator for the evaluator named
            ``'eval/train'``.
        loss_names: Reported loss names. The first one selects the variable
            for the graph dump, so this must not be empty.
        converter: Batch converter forwarded to both evaluators.

    Returns:
        The configured trainer. NOTE(review): no stop trigger is passed to
        ``Trainer`` here, so stopping presumably happens elsewhere — confirm.
    """
    def _make_evaluator(iterator):
        # Both evaluators share every setting except their iterator.
        return utility.chainer_utility.NoVariableEvaluator(
            iterator,
            target=model,
            converter=converter,
            eval_func=eval_func,
            device=config.gpu,
        )

    trainer = chainer.training.Trainer(updater, out=project_path)

    log_trigger = (config.log_iteration, 'iteration')
    save_trigger = (config.save_iteration, 'iteration')
    eval_test_name = 'eval/test'
    eval_train_name = 'eval/train'

    trainer.extend(
        extensions.snapshot_object(model['main'], '{.updater.iteration}.model'),
        trigger=save_trigger,
    )
    trainer.extend(
        extensions.dump_graph('main/' + loss_names[0], out_name='main.dot'))

    evaluator_specs = (
        (eval_test_name, iterator_test),
        (eval_train_name, iterator_train_varidation),
    )
    for evaluator_name, evaluator_iter in evaluator_specs:
        trainer.extend(_make_evaluator(evaluator_iter),
                       name=evaluator_name, trigger=log_trigger)

    # Columns: training losses first, then each evaluator's copy of them.
    prefixes = ('', eval_test_name + '/', eval_train_name + '/')
    report_target = [
        prefix + 'main/' + loss_name
        for prefix in prefixes
        for loss_name in loss_names
    ]

    trainer.extend(extensions.LogReport(trigger=log_trigger, log_name="log.txt"))
    trainer.extend(extensions.PrintReport(report_target))
    return trainer