當前位置: 首頁>>代碼示例>>Python>>正文


Python training.make_extension方法代碼示例

本文整理匯總了Python中chainer.training.make_extension方法的典型用法代碼示例。如果您正苦於以下問題：Python training.make_extension方法的具體用法？Python training.make_extension方法怎麼用？Python training.make_extension方法使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊chainer.training的用法示例。


在下文中一共展示了training.make_extension方法的15個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: get_model

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def get_model(args, loss_func, vocab, vocab_ngram_tokens, current_utils=utils.word):
    """Build the embedding model selected by ``args``.

    Args:
        args: Parsed options; ``args.subword``, ``args.model`` and
            ``args.dimensions`` are read here.
        loss_func: Loss object forwarded to the model constructor.
        vocab: Vocabulary. ``vocab.cnt_words`` sizes the word-level models;
            the whole object is forwarded to the subword model.
        vocab_ngram_tokens: N-gram token vocabulary, forwarded only to the
            subword model.
        current_utils: Module providing the word-level ``SkipGram`` and
            ``ContinuousBoW`` classes (defaults to ``utils.word``).

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If the ``args.model`` / ``args.subword`` combination is
            not supported. ``ValueError`` subclasses ``Exception``, so
            existing ``except Exception`` handlers still catch it.
    """
    model = None
    if args.subword == 'none':
        if args.model == 'skipgram':
            model = current_utils.SkipGram(
                vocab.cnt_words, args.dimensions, loss_func)
        elif args.model == 'cbow':
            model = current_utils.ContinuousBoW(
                vocab.cnt_words, args.dimensions, loss_func)
    elif args.model == 'skipgram':
        # Only skip-gram has a subword-aware implementation here; any other
        # model with subwords enabled falls through to the error below.
        model = utils.subword.SkipGram(
            args.subword, vocab, vocab_ngram_tokens, args.dimensions, loss_func)

    if model is None:
        raise ValueError(
            'Unknown model and word/subword type: "{}" and "{}"'.format(
                args.model, args.subword))
    return model


#@training.make_extension(trigger=(1, 'epoch'))
#def dump_embs(trainer):
#    print("dumping embeddings") 
開發者ID:vecto-ai,項目名稱:vecto,代碼行數:22,代碼來源:train_word2vec.py

示例2: test_on_error

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_on_error(self):
        """A ``snapshot_on_error=True`` snapshot is written when training aborts.

        The snapshot file must be absent before ``run()`` and present after
        ``run()`` has propagated the extension's exception.
        """

        class TheOnlyError(Exception):
            pass

        # Extension that raises as soon as it is triggered (every iteration).
        def exception_raiser(trainer):
            raise TheOnlyError()
        exception_raiser = training.make_extension(
            trigger=(1, 'iteration'), priority=100)(exception_raiser)

        trainer = self.trainer
        target_path = self.filename

        trainer.extend(exception_raiser)
        snapshot_ext = extensions.snapshot_object(
            trainer, target_path, snapshot_on_error=True)
        trainer.extend(snapshot_ext)

        self.assertFalse(os.path.exists(target_path))

        with self.assertRaises(TheOnlyError):
            trainer.run()

        self.assertTrue(os.path.exists(target_path))
開發者ID:chainer,項目名稱:chainer,代碼行數:22,代碼來源:test_snapshot.py

示例3: test_exception_in_exception_handler

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_exception_in_exception_handler(self):
        """An error raised by an ``on_error`` handler must not mask the original.

        Training must still surface ``TheOnlyError``. The other extension's
        error handler must have run, and ``DummyExtension`` must have been
        finalized.
        """
        trainer = self.trainer

        handling_ext = ErrorHandlingExtension()
        trainer.extend(handling_ext, trigger=(1, 'iteration'), priority=1)
        self.assertFalse(handling_ext.is_error_handled)

        # This handler itself fails; the test checks that its ValueError is
        # not what escapes from run().
        def exception_handler(trainer, exp, tb):
            raise ValueError('hogehoge from exception handler')

        def exception_raiser(trainer):
            raise TheOnlyError()
        exception_raiser = training.make_extension(
            trigger=(1, 'iteration'), priority=100,
            on_error=exception_handler)(exception_raiser)
        trainer.extend(exception_raiser)

        finalize_probe = DummyExtension(self)
        trainer.extend(finalize_probe)

        with self.assertRaises(TheOnlyError):
            trainer.run()

        self.assertTrue(handling_ext.is_error_handled)
        self.assertTrue(finalize_probe.is_finalized)
開發者ID:chainer,項目名稱:chainer,代碼行數:25,代碼來源:test_trainer.py

示例4: adadelta_eps_decay

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def adadelta_eps_decay(eps_decay):
    """Create an epoch-triggered extension that decays Adadelta's eps.

    Args:
        eps_decay (float): Decay rate of eps, forwarded unchanged to
            ``_adadelta_eps_decay`` each time the extension fires.

    Returns:
        An extension function triggered once per epoch.

    """
    def adadelta_eps_decay(trainer):
        _adadelta_eps_decay(trainer, eps_decay)

    # Wrapping explicitly is equivalent to using make_extension as a decorator;
    # the inner function's name still becomes the extension's default name.
    return training.make_extension(trigger=(1, "epoch"))(adadelta_eps_decay)
開發者ID:espnet,項目名稱:espnet,代碼行數:18,代碼來源:asr_utils.py

示例5: adam_lr_decay

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def adam_lr_decay(eps_decay):
    """Create an epoch-triggered extension that decays Adam's learning rate.

    Args:
        eps_decay (float): Decay rate of lr. Despite its name, this is the
            learning-rate decay, forwarded unchanged to ``_adam_lr_decay``.

    Returns:
        An extension function triggered once per epoch.

    """
    def adam_lr_decay(trainer):
        _adam_lr_decay(trainer, eps_decay)

    # Explicit wrap, equivalent to the decorator form.
    return training.make_extension(trigger=(1, "epoch"))(adam_lr_decay)
開發者ID:espnet,項目名稱:espnet,代碼行數:18,代碼來源:asr_utils.py

示例6: snapshot_object

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def snapshot_object(target, filename):
    """Return a trainer extension that saves ``target`` once per epoch.

    Args:
        target: Object to serialize; passed as-is to ``torch_save``.
        filename (str): Name of the file, relative to ``trainer.out``, into
            which the object is serialized. It can be a format string, in
            which case the trainer object is passed to the
            :meth:`str.format` method. For example,
            ``'snapshot_{.updater.iteration}'`` is converted to
            ``'snapshot_10000'`` at the 10,000th iteration.

    Returns:
        An extension function with trigger ``(1, "epoch")`` and priority -100.

    """

    @extension.make_extension(trigger=(1, "epoch"), priority=-100)
    def snapshot_object(trainer):
        # The trainer is the sole positional argument to format(), which is
        # what lets placeholders like ``{.updater.iteration}`` resolve.
        torch_save(os.path.join(trainer.out, filename.format(trainer)), target)

    return snapshot_object
開發者ID:espnet,項目名稱:espnet,代碼行數:23,代碼來源:asr_utils.py

示例7: test_add_make_extension

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_add_make_extension(self):
        """A plain function wrapped by ``make_extension()`` is invoked by ``run()``."""
        self.is_called = False

        def dummy_extension(trainer):
            self.is_called = True
        dummy_extension = training.make_extension()(dummy_extension)

        trainer = self.trainer
        trainer.extend(dummy_extension)
        trainer.run()
        self.assertTrue(self.is_called)
開發者ID:chainer,項目名稱:chainer,代碼行數:12,代碼來源:test_trainer.py

示例8: test_add_make_extension_with_initializer

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_add_make_extension_with_initializer(self):
        """The extension expects its ``initializer`` to have run before it is called."""
        self.is_called = False

        def initializer(trainer):
            # Leave a marker on the trainer that the extension body checks.
            trainer.is_initialized = True

        def dummy_extension(trainer):
            self.assertTrue(trainer.is_initialized)
            self.is_called = True
        dummy_extension = training.make_extension(
            initializer=initializer)(dummy_extension)

        trainer = self.trainer
        trainer.extend(dummy_extension)
        trainer.run()
        self.assertTrue(self.is_called)
開發者ID:chainer,項目名稱:chainer,代碼行數:16,代碼來源:test_trainer.py

示例9: test_add_two_extensions_default_priority

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_add_two_extensions_default_priority(self):
        """With equal (default) priority, extensions run in registration order."""
        self.called_order = []
        every_epoch = (1, 'epoch')

        @training.make_extension(trigger=every_epoch)
        def dummy_extension_1(trainer):
            self.called_order.append(1)

        @training.make_extension(trigger=every_epoch)
        def dummy_extension_2(trainer):
            self.called_order.append(2)

        for ext in (dummy_extension_1, dummy_extension_2):
            self.trainer.extend(ext)
        self.trainer.run()
        self.assertEqual(self.called_order, [1, 2])
開發者ID:chainer,項目名稱:chainer,代碼行數:17,代碼來源:test_trainer.py

示例10: test_add_two_extensions_specific_priority

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_add_two_extensions_specific_priority(self):
        """The higher-priority extension runs first, regardless of registration order."""
        self.called_order = []
        every_epoch = (1, 'epoch')

        @training.make_extension(trigger=every_epoch, priority=50)
        def dummy_extension_1(trainer):
            self.called_order.append(1)

        @training.make_extension(trigger=every_epoch, priority=100)
        def dummy_extension_2(trainer):
            self.called_order.append(2)

        # Registered low-priority first; the expected order is reversed.
        for ext in (dummy_extension_1, dummy_extension_2):
            self.trainer.extend(ext)
        self.trainer.run()
        self.assertEqual(self.called_order, [2, 1])
開發者ID:chainer,項目名稱:chainer,代碼行數:17,代碼來源:test_trainer.py

示例11: test_make_extension

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_make_extension(self):
        """``make_extension`` copies its arguments onto the wrapped function."""
        def init_fn(trainer):
            pass

        @training.make_extension(trigger=(2, 'epoch'), default_name='my_ext',
                                 priority=50, initializer=init_fn)
        def my_extension(trainer):
            pass

        expected_attrs = (
            ('trigger', (2, 'epoch')),
            ('default_name', 'my_ext'),
            ('priority', 50),
        )
        for attr_name, expected in expected_attrs:
            self.assertEqual(getattr(my_extension, attr_name), expected)
        # ``initializer=`` is exposed as the ``initialize`` attribute.
        self.assertIs(my_extension.initialize, init_fn)
開發者ID:chainer,項目名稱:chainer,代碼行數:15,代碼來源:test_extension.py

示例12: test_make_extension_default_values

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_make_extension_default_values(self):
        """Check the attributes ``make_extension()`` sets when called with no arguments."""
        @training.make_extension()
        def my_extension(trainer):
            pass

        # default_name is expected to equal the wrapped function's name.
        for attr_name, expected in (('trigger', (1, 'iteration')),
                                    ('default_name', 'my_extension'),
                                    ('priority', training.PRIORITY_READER)):
            self.assertEqual(getattr(my_extension, attr_name), expected)
        self.assertIsNone(my_extension.initialize)
開發者ID:chainer,項目名稱:chainer,代碼行數:11,代碼來源:test_extension.py

示例13: test_make_extension_deleted_argument

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_make_extension_deleted_argument(self):
        """Passing the removed ``invoke_before_training`` argument raises ValueError."""
        def my_extension(_):
            pass

        with self.assertRaises(ValueError):
            training.make_extension(invoke_before_training=False)(my_extension)
開發者ID:chainer,項目名稱:chainer,代碼行數:7,代碼來源:test_extension.py

示例14: test_make_extension_unexpected_kwargs

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def test_make_extension_unexpected_kwargs(self):
        """An unknown keyword argument to ``make_extension`` raises TypeError."""
        def my_extension(_):
            pass

        with self.assertRaises(TypeError):
            training.make_extension(foo=1)(my_extension)
開發者ID:chainer,項目名稱:chainer,代碼行數:7,代碼來源:test_extension.py

示例15: restore_snapshot

# 需要導入模塊: from chainer import training [as 別名]
# 或者: from chainer.training import make_extension [as 別名]
def restore_snapshot(model, snapshot, load_fn=chainer.serializers.load_npz):
    """Create an epoch-triggered extension that calls ``_restore_snapshot``.

    Args:
        model: Forwarded to ``_restore_snapshot`` (presumably the object the
            snapshot is loaded into; confirm against that helper).
        snapshot: Snapshot to restore, forwarded unchanged.
        load_fn (callable): Loader forwarded to ``_restore_snapshot``;
            defaults to ``chainer.serializers.load_npz``.

    Returns:
        An extension function triggered once per epoch.

    """
    def restore_snapshot(trainer):
        # The trainer itself is unused; only the closed-over values matter.
        _restore_snapshot(model, snapshot, load_fn)

    return training.make_extension(trigger=(1, "epoch"))(restore_snapshot)
開發者ID:espnet,項目名稱:espnet,代碼行數:15,代碼來源:asr_utils.py


注：本文中的chainer.training.make_extension方法示例由純淨天空整理自GitHub/MSDocs等開源代碼及文檔管理平台，相關代碼片段篩選自各路編程大神貢獻的開源項目，源碼版權歸原作者所有，傳播和使用請參考對應項目的License；未經允許，請勿轉載。