当前位置: 首页>>代码示例>>Python>>正文


Python training.make_extension方法代码示例

本文整理汇总了Python中chainer.training.make_extension方法的典型用法代码示例。如果您正苦于以下问题:Python training.make_extension方法的具体用法?Python training.make_extension怎么用?Python training.make_extension使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在chainer.training的用法示例。


在下文中一共展示了training.make_extension方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: get_model

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def get_model(args, loss_func, vocab, vocab_ngram_tokens, current_utils=utils.word):
    """Construct the embedding model selected by the command-line options.

    Args:
        args: Options object; ``subword``, ``model`` and ``dimensions`` are read.
        loss_func: Forwarded unchanged to the model constructor.
        vocab: Vocabulary object. Word-level models receive ``vocab.cnt_words``;
            the subword model receives the object itself.
        vocab_ngram_tokens: Forwarded to the subword model only.
        current_utils: Namespace that provides ``SkipGram`` and
            ``ContinuousBoW`` for the word-level case.

    Returns:
        The constructed model.

    Raises:
        Exception: If no model was built for the requested combination.
    """
    built = None
    if args.subword != 'none':
        # Subword mode: skip-gram is the only variant wired up in this branch.
        if args.model == 'skipgram':
            built = utils.subword.SkipGram(
                args.subword, vocab, vocab_ngram_tokens, args.dimensions, loss_func)
    elif args.model == 'skipgram':
        built = current_utils.SkipGram(vocab.cnt_words, args.dimensions, loss_func)
    elif args.model == 'cbow':
        # NOTE(review): an earlier TODO here said "only skipgram supported" --
        # confirm whether the CBOW path is actually usable.
        built = current_utils.ContinuousBoW(vocab.cnt_words, args.dimensions, loss_func)

    if built is None:
        raise Exception('Unknown model and word/subword type: {} "and" {}'.format(args.model, args.subword))
    return built


#@training.make_extension(trigger=(1, 'epoch'))
#def dump_embs(trainer):
#    print("dumping embeddings") 
开发者ID:vecto-ai,项目名称:vecto,代码行数:22,代码来源:train_word2vec.py

示例2: test_on_error

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_on_error(self):
    """With ``snapshot_on_error=True`` the snapshot file appears after a failed run.

    The file must not exist before ``run()``; an extension raises on the
    run, and afterwards the file must exist.
    """

    class TheOnlyError(Exception):
        pass

    @training.make_extension(trigger=(1, 'iteration'), priority=100)
    def exception_raiser(trainer):
        raise TheOnlyError()

    self.trainer.extend(exception_raiser)
    self.trainer.extend(
        extensions.snapshot_object(
            self.trainer, self.filename, snapshot_on_error=True))

    # Nothing has been written before the trainer runs.
    self.assertFalse(os.path.exists(self.filename))

    with self.assertRaises(TheOnlyError):
        self.trainer.run()

    # The failed run must have left a snapshot behind.
    self.assertTrue(os.path.exists(self.filename))
开发者ID:chainer,项目名称:chainer,代码行数:22,代码来源:test_snapshot.py

示例3: test_exception_in_exception_handler

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_exception_in_exception_handler(self):
    """An ``on_error`` handler that itself raises must not hide the original error.

    ``run()`` is still expected to raise ``TheOnlyError``; the separately
    registered error-handling extension must report the error as handled
    and the dummy extension must be finalized.
    """
    error_ext = ErrorHandlingExtension()
    self.trainer.extend(error_ext, trigger=(1, 'iteration'), priority=1)
    self.assertFalse(error_ext.is_error_handled)

    def exception_handler(trainer, exp, tb):
        raise ValueError('hogehoge from exception handler')

    def exception_raiser(trainer):
        raise TheOnlyError()

    # Same effect as decorating ``exception_raiser`` with make_extension.
    as_extension = training.make_extension(
        trigger=(1, 'iteration'), priority=100, on_error=exception_handler)
    self.trainer.extend(as_extension(exception_raiser))

    finalize_probe = DummyExtension(self)
    self.trainer.extend(finalize_probe)

    with self.assertRaises(TheOnlyError):
        self.trainer.run()

    self.assertTrue(error_ext.is_error_handled)
    self.assertTrue(finalize_probe.is_finalized)
开发者ID:chainer,项目名称:chainer,代码行数:25,代码来源:test_trainer.py

示例4: adadelta_eps_decay

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def adadelta_eps_decay(eps_decay):
    """Create a trainer extension that applies adadelta eps decay.

    The returned extension has a default trigger of one epoch and forwards
    the trainer together with ``eps_decay`` to ``_adadelta_eps_decay``.

    Args:
        eps_decay (float): Decay rate of eps.

    Returns:
        An extension function.

    """

    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    # Equivalent to decorating the inner function; its name is kept because
    # make_extension uses it as the extension's default name.
    every_epoch = training.make_extension(trigger=(1, "epoch"))
    return every_epoch(adadelta_eps_decay)
开发者ID:espnet,项目名称:espnet,代码行数:18,代码来源:asr_utils.py

示例5: adam_lr_decay

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def adam_lr_decay(eps_decay):
    """Create a trainer extension that applies adam lr decay.

    The returned extension has a default trigger of one epoch and forwards
    the trainer together with ``eps_decay`` to ``_adam_lr_decay``.

    Args:
        eps_decay (float): Decay rate of lr (the parameter keeps the
            ``eps_decay`` name).

    Returns:
        An extension function.

    """

    def adam_lr_decay(trainer):
        _adam_lr_decay(trainer, eps_decay)

    # Equivalent to decorating the inner function; its name is kept because
    # make_extension uses it as the extension's default name.
    every_epoch = training.make_extension(trigger=(1, "epoch"))
    return every_epoch(adam_lr_decay)
开发者ID:espnet,项目名称:espnet,代码行数:18,代码来源:asr_utils.py

示例6: snapshot_object

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def snapshot_object(target, filename):
    """Build a trainer extension that saves a given object with ``torch_save``.

    The extension is created with a default trigger of one epoch and a
    priority of -100.

    Args:
        target (model): Object to serialize.
        filename (str): Name of the file into which the object is serialized.
            It can be a format string, where the trainer object is passed to
            the :meth:`str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration.

    Returns:
        An extension function.

    """

    def snapshot_object(trainer):
        # Resolve the output directory, expand the filename template against
        # the trainer, then write the target there.
        out_dir = trainer.out
        out_name = filename.format(trainer)
        torch_save(os.path.join(out_dir, out_name), target)

    as_extension = extension.make_extension(trigger=(1, "epoch"), priority=-100)
    return as_extension(snapshot_object)
开发者ID:espnet,项目名称:espnet,代码行数:23,代码来源:asr_utils.py

示例7: test_add_make_extension

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_make_extension(self):
    """An extension made with argument-less ``make_extension()`` is called by ``run()``."""
    self.is_called = False

    def dummy_extension(trainer):
        # Record that the trainer actually invoked this extension.
        self.is_called = True

    self.trainer.extend(training.make_extension()(dummy_extension))
    self.trainer.run()

    self.assertTrue(self.is_called)
开发者ID:chainer,项目名称:chainer,代码行数:12,代码来源:test_trainer.py

示例8: test_add_make_extension_with_initializer

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_make_extension_with_initializer(self):
    """The ``initializer`` passed to ``make_extension`` runs before the extension.

    The extension body asserts that the flag set by the initializer is
    already present on the trainer when it is called.
    """
    self.is_called = False

    def initializer(trainer):
        trainer.is_initialized = True

    def dummy_extension(trainer):
        self.assertTrue(trainer.is_initialized)
        self.is_called = True

    with_initializer = training.make_extension(initializer=initializer)
    self.trainer.extend(with_initializer(dummy_extension))
    self.trainer.run()

    self.assertTrue(self.is_called)
开发者ID:chainer,项目名称:chainer,代码行数:16,代码来源:test_trainer.py

示例9: test_add_two_extensions_default_priority

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_two_extensions_default_priority(self):
    """Two extensions with the default priority are expected to run as [1, 2]."""
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    # Registered first.
    self.trainer.extend(dummy_extension_1)

    @training.make_extension(trigger=(1, 'epoch'))
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    # Registered second.
    self.trainer.extend(dummy_extension_2)

    self.trainer.run()
    self.assertEqual(self.called_order, [1, 2])
开发者ID:chainer,项目名称:chainer,代码行数:17,代码来源:test_trainer.py

示例10: test_add_two_extensions_specific_priority

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_add_two_extensions_specific_priority(self):
    """The extension with priority 100 is expected to run before the one with 50.

    Extension 1 (priority 50) is registered first and extension 2
    (priority 100) second; the expected call order is [2, 1].
    """
    self.called_order = []

    @training.make_extension(trigger=(1, 'epoch'), priority=50)
    def dummy_extension_1(trainer):
        self.called_order.append(1)

    self.trainer.extend(dummy_extension_1)

    @training.make_extension(trigger=(1, 'epoch'), priority=100)
    def dummy_extension_2(trainer):
        self.called_order.append(2)

    self.trainer.extend(dummy_extension_2)

    self.trainer.run()
    self.assertEqual(self.called_order, [2, 1])
开发者ID:chainer,项目名称:chainer,代码行数:17,代码来源:test_trainer.py

示例11: test_make_extension

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension(self):
    """``make_extension`` stores each explicitly passed option on the function."""

    def initialize(trainer):
        pass

    def my_extension(trainer):
        pass

    # Apply the decorator explicitly; the result is rebound to the same name.
    configure = training.make_extension(
        trigger=(2, 'epoch'), default_name='my_ext',
        priority=50, initializer=initialize)
    my_extension = configure(my_extension)

    self.assertEqual(my_extension.trigger, (2, 'epoch'))
    self.assertEqual(my_extension.default_name, 'my_ext')
    self.assertEqual(my_extension.priority, 50)
    self.assertIs(my_extension.initialize, initialize)
开发者ID:chainer,项目名称:chainer,代码行数:15,代码来源:test_extension.py

示例12: test_make_extension_default_values

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension_default_values(self):
    """Check the attributes ``make_extension()`` sets when given no arguments.

    Expected: trigger ``(1, 'iteration')``, ``default_name`` equal to the
    function name, priority ``training.PRIORITY_READER`` and no initializer.
    """

    def my_extension(trainer):
        pass

    # The function name is significant: it is the expected default_name.
    my_extension = training.make_extension()(my_extension)

    self.assertEqual(my_extension.trigger, (1, 'iteration'))
    self.assertEqual(my_extension.default_name, 'my_extension')
    self.assertEqual(my_extension.priority, training.PRIORITY_READER)
    self.assertIsNone(my_extension.initialize)
开发者ID:chainer,项目名称:chainer,代码行数:11,代码来源:test_extension.py

示例13: test_make_extension_deleted_argument

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension_deleted_argument(self):
    """Passing ``invoke_before_training`` to ``make_extension`` raises ``ValueError``."""

    def my_extension(_):
        pass

    with self.assertRaises(ValueError):
        # Both building the decorator and applying it happen inside the
        # context manager, as with the decorator syntax.
        training.make_extension(invoke_before_training=False)(my_extension)
开发者ID:chainer,项目名称:chainer,代码行数:7,代码来源:test_extension.py

示例14: test_make_extension_unexpected_kwargs

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def test_make_extension_unexpected_kwargs(self):
    """An unknown keyword argument to ``make_extension`` raises ``TypeError``."""

    def my_extension(_):
        pass

    with self.assertRaises(TypeError):
        # Both building the decorator and applying it happen inside the
        # context manager, as with the decorator syntax.
        training.make_extension(foo=1)(my_extension)
开发者ID:chainer,项目名称:chainer,代码行数:7,代码来源:test_extension.py

示例15: restore_snapshot

# 需要导入模块: from chainer import training [as 别名]
# 或者: from chainer.training import make_extension [as 别名]
def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
    """Create a trainer extension that restores a snapshot.

    The returned extension has a default trigger of one epoch and ignores
    the trainer it receives; it forwards ``model``, ``snapshot`` and
    ``load_fn`` to ``_restore_snapshot``.

    Args:
        model: Forwarded to ``_restore_snapshot``.
        snapshot: Forwarded to ``_restore_snapshot``.
        load_fn: Forwarded to ``_restore_snapshot``; defaults to
            ``chainer.serializers.load_npz``.

    Returns:
        An extension function.

    """

    def restore_snapshot(trainer):
        _restore_snapshot(model, snapshot, load_fn)

    # Equivalent to decorating the inner function; its name is kept because
    # make_extension uses it as the extension's default name.
    every_epoch = training.make_extension(trigger=(1, "epoch"))
    return every_epoch(restore_snapshot)
开发者ID:espnet,项目名称:espnet,代码行数:15,代码来源:asr_utils.py


注:本文中的chainer.training.make_extension方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。