当前位置: 首页>>代码示例>>Python>>正文


Python handlers.EarlyStopping方法代码示例

本文整理汇总了Python中ignite.handlers.EarlyStopping类的典型用法代码示例。如果您正苦于以下问题：Python handlers.EarlyStopping类的具体用法是什么？Python handlers.EarlyStopping怎么用？有哪些Python handlers.EarlyStopping的使用例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块ignite.handlers的用法示例。


在下文中一共展示了handlers.EarlyStopping类的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __call__

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def __call__(self, engine: Engine) -> None:
    """Score ``engine`` and terminate ``self.trainer`` once patience runs out.

    The first observed score becomes the baseline (``self.best_score``). A later
    score counts as an improvement only if it is strictly greater than
    ``self.best_score + self.min_delta``; an improvement stores the new best and
    resets ``self.counter`` to 0. Any other score increments ``self.counter`` and,
    once it reaches ``self.patience``, ``self.trainer.terminate()`` is called.

    When ``self.cumulative_delta`` is False, a non-qualifying score that is still
    higher than the current best replaces ``self.best_score`` (without resetting
    the counter), so ``min_delta`` is measured against the latest best value
    instead of accumulating from an older one.

    Args:
        engine: engine handed to ``self.score_function`` to compute the score.
    """
    score = self.score_function(engine)

    if self.best_score is None:
        # First call: nothing to compare against yet.
        self.best_score = score
    elif score <= self.best_score + self.min_delta:
        # Not enough improvement. Without cumulative delta, still track a
        # marginally higher score so the next comparison starts from it.
        if not self.cumulative_delta and score > self.best_score:
            self.best_score = score
        self.counter += 1
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        self.logger.debug("EarlyStopping: %i / %i", self.counter, self.patience)
        if self.counter >= self.patience:
            self.logger.info("EarlyStopping: Stop training")
            self.trainer.terminate()
    else:
        self.best_score = score
        self.counter = 0
开发者ID:pytorch,项目名称:ignite,代码行数:18,代码来源:early_stopping.py

示例2: add_early_stopping_by_val_score

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name):
    """Attach an early-stopping handler to ``evaluator`` driven by one of its metrics.

    The score is read from ``evaluator.state.metrics[metric_name]`` (via
    ``get_default_score_fn``) every time ``evaluator`` fires ``Events.COMPLETED``.
    When the score fails to improve for ``patience`` consecutive evaluations,
    ``trainer`` is told to stop.

    Args:
        patience (int): number of evaluations without improvement to tolerate
            before stopping the training.
        evaluator (Engine): evaluation engine that provides the score.
        trainer (Engine): trainer engine that is stopped when there is no improvement.
        metric_name (str): name of the metric used as the score; it must be present
            in ``evaluator.state.metrics``.

    Returns:
        The :class:`~ignite.handlers.early_stopping.EarlyStopping` handler that was attached.
    """
    metric_score_fn = get_default_score_fn(metric_name)
    handler = EarlyStopping(patience=patience, score_function=metric_score_fn, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, handler)
    return handler
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:common.py

示例3: test_simple_early_stopping

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_simple_early_stopping():
    """The trainer is stopped after ``patience`` consecutive non-improving scores."""
    score_stream = iter([1.0, 0.8, 0.88])
    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(
        patience=2, score_function=lambda engine: next(score_stream), trainer=trainer
    )

    assert not trainer.should_terminate
    # 1.0 sets the baseline; 0.8 is the first score that does not improve on it.
    for _ in range(2):
        handler(None)
        assert not trainer.should_terminate
    # 0.88 is the second non-improving score, which exhausts patience=2.
    handler(None)
    assert trainer.should_terminate
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_early_stopping.py

示例4: test_state_dict

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_state_dict():
    """A handler restored from another handler's state_dict keeps counting patience."""
    score_stream = iter([1.0, 0.8, 0.88])

    def score_function(engine):
        return next(score_stream)

    trainer = Engine(do_nothing_update_fn)

    first = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)
    assert not trainer.should_terminate
    first(None)  # 1.0 becomes the baseline score
    assert not trainer.should_terminate

    # Hand the accumulated state to a fresh handler reading the same score stream.
    second = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)
    second.load_state_dict(first.state_dict())

    second(None)  # 0.8: first non-improving score
    assert not trainer.should_terminate
    second(None)  # 0.88: second non-improving score, patience exhausted
    assert trainer.should_terminate
开发者ID:pytorch,项目名称:ignite,代码行数:25,代码来源:test_early_stopping.py

示例5: test_early_stopping_on_last_event_delta

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_early_stopping_on_last_event_delta():
    """With cumulative_delta=False, min_delta is measured from the latest best score."""
    score_stream = iter([0.0, 0.3, 0.6])
    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(
        patience=2,
        min_delta=0.4,
        cumulative_delta=False,
        score_function=lambda _: next(score_stream),
        trainer=trainer,
    )

    # Expected stop flag after each call:
    #   0.0 -> baseline, counter == 0
    #   0.3 -> delta 0.3 < 0.4, counter == 1 (best becomes 0.3)
    #   0.6 -> delta 0.3 < 0.4, counter == 2 -> stop
    assert not trainer.should_terminate
    for expect_stop in (False, False, True):
        handler(None)
        assert bool(trainer.should_terminate) is expect_stop
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_early_stopping.py

示例6: test_early_stopping_on_cumulative_delta

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_early_stopping_on_cumulative_delta():
    """With cumulative_delta=True, small gains add up until they exceed min_delta."""
    score_stream = iter([0.0, 0.3, 0.6])
    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(
        patience=2,
        min_delta=0.4,
        cumulative_delta=True,
        score_function=lambda _: next(score_stream),
        trainer=trainer,
    )

    # Expected stop flag after each call:
    #   0.0 -> baseline, counter == 0
    #   0.3 -> delta 0.3 < 0.4, counter == 1 (best stays 0.0)
    #   0.6 -> delta 0.6 > 0.4, counter reset to 0
    assert not trainer.should_terminate
    for expect_stop in (False, False, False):
        handler(None)
        assert bool(trainer.should_terminate) is expect_stop
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_early_stopping.py

示例7: test_simple_no_early_stopping

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_simple_no_early_stopping():
    """A recovering score resets the counter, so patience=2 is never exhausted."""
    score_stream = iter([1.0, 0.8, 1.2])
    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(
        patience=2, score_function=lambda engine: next(score_stream), trainer=trainer
    )

    assert not trainer.should_terminate
    for _ in range(3):
        handler(None)
    assert not trainer.should_terminate
开发者ID:pytorch,项目名称:ignite,代码行数:18,代码来源:test_early_stopping.py

示例8: test_with_engine_early_stopping

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_with_engine_early_stopping():
    """Attached to an evaluator, EarlyStopping halts a 10-epoch run after epoch 7."""
    epochs_run = 0
    score_stream = iter([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9])

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(
        patience=3, score_function=lambda engine: next(score_stream), trainer=trainer
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        nonlocal epochs_run
        evaluator.run([0])
        epochs_run += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    # Epochs 5-7 (0.9, 1.0, 0.99) never beat the best score 1.5 -> patience=3 exhausted.
    assert epochs_run == 7
    assert trainer.state.epoch == 7
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_early_stopping.py

示例9: test_with_engine_no_early_stopping

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_with_engine_no_early_stopping():
    """With patience=5 the non-improving streaks are too short, so all 10 epochs run."""
    epochs_run = 0
    score_stream = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(
        patience=5, score_function=lambda engine: next(score_stream), trainer=trainer
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        nonlocal epochs_run
        evaluator.run([0])
        epochs_run += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    assert epochs_run == 10
    assert trainer.state.epoch == 10
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_early_stopping.py

示例10: __call__

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def __call__(self, engine):
    """Score ``engine`` and terminate ``self.trainer`` after ``self.patience`` non-improving calls.

    The first observed score becomes the baseline (``self.best_score``). A score
    strictly greater than the best one is stored as the new best and resets
    ``self.counter`` to 0. Any other score, including one *equal* to the best,
    increments ``self.counter``. Once the counter reaches ``self.patience``,
    ``self.trainer.terminate()`` is called.

    Args:
        engine: engine handed to ``self.score_function`` to compute the score.
    """
    score = self.score_function(engine)

    if self.best_score is None:
        # First call: nothing to compare against yet.
        self.best_score = score
    elif score <= self.best_score:
        # `<=` rather than `<`: a plateau (score equal to the best) is not an
        # improvement. With `<`, a flat metric would reset the counter on every
        # call and training would never be stopped.
        self.counter += 1
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        self._logger.debug("EarlyStopping: %i / %i", self.counter, self.patience)
        if self.counter >= self.patience:
            self._logger.info("EarlyStopping: Stop training")
            self.trainer.terminate()
    else:
        self.best_score = score
        self.counter = 0
开发者ID:hrhodin,项目名称:UnsupervisedGeometryAwareRepresentationLearning,代码行数:16,代码来源:early_stopping.py

示例11: test_args_validation

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_args_validation():
    """The constructor rejects bad patience, min_delta, score_function and trainer."""
    trainer = Engine(do_nothing_update_fn)

    def zero_score(engine):
        return 0

    invalid_cases = [
        (
            ValueError,
            r"Argument patience should be positive integer.",
            {"patience": -1, "score_function": zero_score, "trainer": trainer},
        ),
        (
            ValueError,
            r"Argument min_delta should not be a negative number.",
            {"patience": 2, "min_delta": -0.1, "score_function": zero_score, "trainer": trainer},
        ),
        (
            TypeError,
            r"Argument score_function should be a function.",
            {"patience": 2, "score_function": 12345, "trainer": trainer},
        ),
        (
            TypeError,
            r"Argument trainer should be an instance of Engine.",
            {"patience": 2, "score_function": zero_score, "trainer": None},
        ),
    ]

    for expected_exc, message_pattern, kwargs in invalid_cases:
        with pytest.raises(expected_exc, match=message_pattern):
            EarlyStopping(**kwargs)
开发者ID:pytorch,项目名称:ignite,代码行数:17,代码来源:test_early_stopping.py

示例12: test_simple_early_stopping_on_plateau

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_simple_early_stopping_on_plateau():
    """A constant score is not an improvement, so patience=1 stops on the second call."""
    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(patience=1, score_function=lambda engine: 42, trainer=trainer)

    assert not trainer.should_terminate
    handler(None)  # 42 becomes the baseline
    assert not trainer.should_terminate
    handler(None)  # 42 again: no improvement, counter reaches patience
    assert trainer.should_terminate
开发者ID:pytorch,项目名称:ignite,代码行数:15,代码来源:test_early_stopping.py

示例13: _test_distrib_with_engine_early_stopping

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def _test_distrib_with_engine_early_stopping(device):
    """Distributed variant: each epoch's score is averaged across ranks before use."""
    import torch.distributed as dist

    torch.manual_seed(12)

    epochs_run = 0
    scores = torch.tensor(
        [1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9], requires_grad=False
    ).to(device)

    def score_function(engine):
        # Integer indexing yields a view, so the in-place reduce/divide below
        # also update that element of `scores`; each element is read only once.
        epoch_score = scores[trainer.state.epoch - 1]
        dist.all_reduce(epoch_score)
        epoch_score /= dist.get_world_size()
        return epoch_score.item()

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        nonlocal epochs_run
        evaluator.run([0])
        epochs_run += 1

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    assert trainer.state.epoch == 7
    assert epochs_run == 7
开发者ID:pytorch,项目名称:ignite,代码行数:36,代码来源:test_early_stopping.py

示例14: _test_distrib_integration_engine_early_stopping

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def _test_distrib_integration_engine_early_stopping(device):
    """Distributed integration: EarlyStopping driven by a distributed Accuracy metric.

    Epoch 2 uses all-ones predictions and targets (accuracy 1.0). The following
    epochs use random labels, so three non-improving epochs stop training at epoch 5.
    """
    import torch.distributed as dist
    from ignite.metrics import Accuracy

    rank = dist.get_rank()
    ws = dist.get_world_size()
    torch.manual_seed(12)

    n_epochs = 10
    n_iters = 20

    def _per_epoch_labels():
        # One (n_iters, ws) tensor per epoch: random for epoch 1, all ones for
        # epoch 2, random for the remaining epochs.
        head = [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
        ones = [torch.ones(n_iters, ws).to(device)]
        tail = [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
        return head + ones + tail

    y_preds = _per_epoch_labels()
    y_true = _per_epoch_labels()

    def update(engine, _):
        # Each rank reads its own column for the current (epoch, iteration) pair.
        epoch_idx = trainer.state.epoch - 1
        iter_idx = engine.state.iteration - 1
        return y_preds[epoch_idx][iter_idx, rank], y_true[epoch_idx][iter_idx, rank]

    evaluator = Engine(update)
    acc = Accuracy(device=device)
    acc.attach(evaluator, "acc")

    trainer = Engine(lambda e, b: None)
    early_stopping = EarlyStopping(
        patience=3, score_function=lambda engine: engine.state.metrics["acc"], trainer=trainer
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        evaluator.run(data=list(range(n_iters)))

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=n_epochs)
    assert trainer.state.epoch == 5
开发者ID:pytorch,项目名称:ignite,代码行数:49,代码来源:test_early_stopping.py


注:本文中的ignite.handlers.EarlyStopping方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。