本文整理汇总了Python中chainer.training.make_extension方法的典型用法代码示例。如果您正苦于以下问题：Python training.make_extension方法的具体用法？Python training.make_extension怎么用？Python training.make_extension使用的例子？那么，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块chainer.training的用法示例。
在下文中一共展示了training.make_extension方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: get_model
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def get_model(args, loss_func, vocab, vocab_ngram_tokens, current_utils=utils.word):
    """Instantiate the embedding model selected by ``args``.

    Args:
        args: Parsed options; ``args.subword``, ``args.model`` and
            ``args.dimensions`` are read here.
        loss_func: Loss object forwarded unchanged to the model constructor.
        vocab: Vocabulary. Word-level models receive ``vocab.cnt_words``;
            the subword model receives the whole object.
        vocab_ngram_tokens: Forwarded only to the subword ``SkipGram``.
        current_utils: Module providing the word-level ``SkipGram`` and
            ``ContinuousBoW`` classes (defaults to ``utils.word``).

    Returns:
        The constructed model.

    Raises:
        Exception: If the (model, subword) combination is unsupported.
            Word-level (``args.subword == 'none'``) accepts 'skipgram' and
            'cbow'; subword accepts only 'skipgram'.
    """
    word_level = args.subword == 'none'
    kind = args.model
    model = None
    if word_level and kind == 'skipgram':
        model = current_utils.SkipGram(vocab.cnt_words, args.dimensions, loss_func)
    elif word_level and kind == 'cbow':
        # NOTE(review): the original carried "todo only skipgram supported"
        # here, yet this branch still builds a CBOW model.
        model = current_utils.ContinuousBoW(vocab.cnt_words, args.dimensions, loss_func)
    elif not word_level and kind == 'skipgram':
        # Subword models always come from utils.subword, not current_utils.
        model = utils.subword.SkipGram(args.subword, vocab, vocab_ngram_tokens,
                                       args.dimensions, loss_func)
    if model is None:
        raise Exception('Unknown model and word/subword type: {} "and" {}'.format(args.model, args.subword))
    return model
#@training.make_extension(trigger=(1, 'epoch'))
#def dump_embs(trainer):
# print("dumping embeddings")
示例2: test_on_error
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_on_error(self):
    """With snapshot_on_error=True the snapshot file exists once the run fails."""
    class TheOnlyError(Exception):
        pass

    @training.make_extension(trigger=(1, 'iteration'), priority=100)
    def exception_raiser(trainer):
        raise TheOnlyError()

    self.trainer.extend(exception_raiser)
    path = self.filename
    snapshot_ext = extensions.snapshot_object(
        self.trainer, path, snapshot_on_error=True)
    self.trainer.extend(snapshot_ext)

    # Nothing has been written before training starts.
    self.assertFalse(os.path.exists(path))
    with self.assertRaises(TheOnlyError):
        self.trainer.run()
    # The failing run must have left a snapshot behind.
    self.assertTrue(os.path.exists(path))
示例3: test_exception_in_exception_handler
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_exception_in_exception_handler(self):
    """An on_error handler that raises must not replace the original error."""
    error_ext = ErrorHandlingExtension()
    self.trainer.extend(error_ext, trigger=(1, 'iteration'), priority=1)
    self.assertFalse(error_ext.is_error_handled)

    def exception_handler(trainer, exp, tb):
        # Deliberately fails inside the on_error hook.
        raise ValueError('hogehoge from exception handler')

    @training.make_extension(trigger=(1, 'iteration'), priority=100,
                             on_error=exception_handler)
    def exception_raiser(trainer):
        raise TheOnlyError()

    self.trainer.extend(exception_raiser)
    finalize_probe = DummyExtension(self)
    self.trainer.extend(finalize_probe)

    # TheOnlyError, not the handler's ValueError, must surface ...
    with self.assertRaises(TheOnlyError):
        self.trainer.run()
    # ... and the other extensions must still report handled / finalized.
    self.assertTrue(error_ext.is_error_handled)
    self.assertTrue(finalize_probe.is_finalized)
示例4: adadelta_eps_decay
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def adadelta_eps_decay(eps_decay):
    """Build an extension that applies AdaDelta eps decay once per epoch.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function, triggered every epoch, that calls
        ``_adadelta_eps_decay(trainer, eps_decay)``.
    """
    def adadelta_eps_decay(trainer):
        # Same name as the factory on purpose: make_extension takes the
        # extension's default name from the wrapped function's __name__.
        _adadelta_eps_decay(trainer, eps_decay)

    return training.make_extension(trigger=(1, "epoch"))(adadelta_eps_decay)
示例5: adam_lr_decay
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def adam_lr_decay(eps_decay):
    """Build an extension that applies Adam learning-rate decay once per epoch.

    Args:
        eps_decay (float): Decay rate of lr. Despite its name, this is the
            learning-rate decay rate; it is forwarded unchanged to
            ``_adam_lr_decay``.

    Returns:
        An extension function, triggered every epoch, that calls
        ``_adam_lr_decay(trainer, eps_decay)``.
    """
    def adam_lr_decay(trainer):
        # Same name as the factory on purpose: make_extension takes the
        # extension's default name from the wrapped function's __name__.
        _adam_lr_decay(trainer, eps_decay)

    return training.make_extension(trigger=(1, "epoch"))(adam_lr_decay)
示例6: snapshot_object
# 需要导入模块: from chainer.training import extension [as 别名]（本例通过 extension.make_extension 调用）
# 或者: from chainer.training import make_extension [as 别名]
def snapshot_object(target, filename):
    """Return a trainer extension that saves ``target`` once per epoch.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file, relative to ``trainer.out``, that
            the object is serialized into. It may be a format string: the
            trainer object is passed to :meth:`str.format`, so for example
            ``'snapshot_{.updater.iteration}'`` becomes ``'snapshot_10000'``
            at the 10,000th iteration.

    Returns:
        An extension function with priority -100, triggered every epoch,
        that writes ``target`` via ``torch_save``.
    """
    def snapshot_object(trainer):
        out_path = os.path.join(trainer.out, filename.format(trainer))
        torch_save(out_path, target)

    return extension.make_extension(trigger=(1, "epoch"), priority=-100)(snapshot_object)
示例7: test_add_make_extension
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_make_extension(self):
    """A function wrapped by make_extension() is invoked by trainer.run()."""
    self.is_called = False

    def dummy_extension(trainer):
        self.is_called = True

    self.trainer.extend(training.make_extension()(dummy_extension))
    self.trainer.run()
    self.assertTrue(self.is_called)
示例8: test_add_make_extension_with_initializer
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_make_extension_with_initializer(self):
    """The initializer given to make_extension runs before the extension body."""
    self.is_called = False

    def initializer(trainer):
        trainer.is_initialized = True

    def dummy_extension(trainer):
        # Fails unless the initializer has already touched this trainer.
        self.assertTrue(trainer.is_initialized)
        self.is_called = True

    wrapped = training.make_extension(initializer=initializer)(dummy_extension)
    self.trainer.extend(wrapped)
    self.trainer.run()
    self.assertTrue(self.is_called)
示例9: test_add_two_extensions_default_priority
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_two_extensions_default_priority(self):
    """Extensions sharing the default priority run in registration order."""
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    # Registration order is the only thing that distinguishes the two.
    for ext in (dummy_extension_1, dummy_extension_2):
        self.trainer.extend(ext)
    self.trainer.run()
    self.assertListEqual(self.called_order, [1, 2])
示例10: test_add_two_extensions_specific_priority
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_two_extensions_specific_priority(self):
    """The higher-priority extension runs first despite being registered last."""
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'), priority=50)
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    @training.make_extension(trigger=(1, 'epoch'), priority=100)
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    for ext in (dummy_extension_1, dummy_extension_2):
        self.trainer.extend(ext)
    self.trainer.run()
    # priority=100 beats priority=50, so 2 is recorded before 1.
    self.assertListEqual(self.called_order, [2, 1])
示例11: test_make_extension
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension(self):
    """make_extension stores every explicitly given option on the function."""
    def initialize(trainer):
        pass

    decorate = training.make_extension(
        trigger=(2, 'epoch'), default_name='my_ext', priority=50,
        initializer=initialize)

    @decorate
    def my_extension(trainer):
        pass

    expectations = (
        ('trigger', (2, 'epoch')),
        ('default_name', 'my_ext'),
        ('priority', 50),
    )
    for attr_name, wanted in expectations:
        self.assertEqual(getattr(my_extension, attr_name), wanted)
    self.assertIs(my_extension.initialize, initialize)
示例12: test_make_extension_default_values
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension_default_values(self):
    """With no arguments make_extension applies its fallback attributes."""
    def my_extension(trainer):
        pass

    my_extension = training.make_extension()(my_extension)

    self.assertEqual(my_extension.trigger, (1, 'iteration'))
    # The default name comes from the wrapped function's __name__.
    self.assertEqual(my_extension.default_name, 'my_extension')
    self.assertEqual(my_extension.priority, training.PRIORITY_READER)
    self.assertIsNone(my_extension.initialize)
示例13: test_make_extension_deleted_argument
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension_deleted_argument(self):
    """Passing the deleted ``invoke_before_training`` keyword raises ValueError."""
    def my_extension(_):
        pass

    with self.assertRaises(ValueError):
        training.make_extension(invoke_before_training=False)(my_extension)
示例14: test_make_extension_unexpected_kwargs
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension_unexpected_kwargs(self):
    """An unknown keyword argument makes make_extension raise TypeError."""
    def my_extension(_):
        pass

    with self.assertRaises(TypeError):
        training.make_extension(foo=1)(my_extension)
示例15: restore_snapshot
# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
    """Build an extension that calls ``_restore_snapshot`` once per epoch.

    Args:
        model: Forwarded to ``_restore_snapshot`` (presumably the object the
            snapshot is loaded into — confirm against that helper).
        snapshot: Forwarded to ``_restore_snapshot`` (presumably a snapshot
            path — confirm).
        load_fn: Loader forwarded to ``_restore_snapshot``; defaults to
            ``chainer.serializers.load_npz``.

    Returns:
        An extension function triggered every epoch. The trainer argument
        it receives is not used.
    """
    def restore_snapshot(trainer):
        _restore_snapshot(model, snapshot, load_fn)

    return training.make_extension(trigger=(1, "epoch"))(restore_snapshot)