本文整理汇总了Python中ignite.handlers.EarlyStopping方法的典型用法代码示例。如果您正苦于以下问题：Python handlers.EarlyStopping方法的具体用法？Python handlers.EarlyStopping怎么用？Python handlers.EarlyStopping使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块ignite.handlers的用法示例。
在下文中一共展示了handlers.EarlyStopping方法的14个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: __call__
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def __call__(self, engine: Engine) -> None:
    """Evaluate the current score and update the early-stopping state.

    The first call only records the score as the best one. Afterwards, a score
    that is not above ``best_score + min_delta`` counts as "no improvement":
    the counter is incremented and, once it reaches ``patience``, the trainer
    is asked to terminate. A score above that threshold becomes the new best
    score and resets the counter.

    Args:
        engine: engine passed to ``score_function`` to compute the score.
    """
    score = self.score_function(engine)

    if self.best_score is None:
        # First evaluation: nothing to compare against yet.
        self.best_score = score
    elif score <= self.best_score + self.min_delta:
        # Not enough improvement. Without cumulative_delta the reference score
        # still follows small gains, so the next comparison is made against the
        # latest (slightly better) score rather than the original best.
        if not self.cumulative_delta and score > self.best_score:
            self.best_score = score
        self.counter += 1
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        self.logger.debug("EarlyStopping: %i / %i", self.counter, self.patience)
        if self.counter >= self.patience:
            self.logger.info("EarlyStopping: Stop training")
            self.trainer.terminate()
    else:
        # Significant improvement: record it and reset the patience counter.
        self.best_score = score
        self.counter = 0
示例2: add_early_stopping_by_val_score
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name):
    """Attach an early-stopping handler driven by an evaluator metric.

    The score function is built by ``get_default_score_fn(metric_name)``; the
    handler runs every time ``evaluator`` completes and stops ``trainer`` after
    ``patience`` evaluations without improvement.

    Args:
        patience (int): number of events to wait without improvement before stopping the training.
        evaluator (Engine): evaluation engine that provides the score.
        trainer (Engine): trainer engine to stop when there is no improvement.
        metric_name (str): name of the metric used as the score. This metric should be present in
            `evaluator.state.metrics`.

    Returns:
        A :class:`~ignite.handlers.early_stopping.EarlyStopping` handler.
    """
    score_fn = get_default_score_fn(metric_name)
    handler = EarlyStopping(patience=patience, score_function=score_fn, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, handler)
    return handler
示例3: test_simple_early_stopping
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_simple_early_stopping():
    # Scores drop after the first evaluation and never recover.
    score_stream = iter([1.0, 0.8, 0.88])
    trainer = Engine(do_nothing_update_fn)
    handler = EarlyStopping(
        patience=2,
        score_function=lambda _engine: next(score_stream),
        trainer=trainer,
    )

    # The first two calls must not stop training; the third exhausts patience.
    assert not trainer.should_terminate
    for _ in range(2):
        handler(None)
        assert not trainer.should_terminate
    handler(None)
    assert trainer.should_terminate
示例4: test_state_dict
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_state_dict():
    scores = iter([1.0, 0.8, 0.88])

    def next_score(_engine):
        return next(scores)

    trainer = Engine(do_nothing_update_fn)
    first = EarlyStopping(patience=2, score_function=next_score, trainer=trainer)

    # One call on the first handler: the best score is recorded, no stop yet.
    assert not trainer.should_terminate
    first(None)
    assert not trainer.should_terminate

    # A fresh handler restored from the first one's state dict continues from
    # that state: two more non-improving scores reach patience=2.
    second = EarlyStopping(patience=2, score_function=next_score, trainer=trainer)
    second.load_state_dict(first.state_dict())
    second(None)
    assert not trainer.should_terminate
    second(None)
    assert trainer.should_terminate
示例5: test_early_stopping_on_last_event_delta
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_early_stopping_on_last_event_delta():
    # With cumulative_delta=False each gain is measured against the previous
    # score, so steps of 0.3 never reach min_delta=0.4.
    scores = iter([0.0, 0.3, 0.6])
    trainer = Engine(do_nothing_update_fn)
    h = EarlyStopping(
        patience=2,
        min_delta=0.4,
        cumulative_delta=False,
        score_function=lambda _: next(scores),
        trainer=trainer,
    )

    assert not trainer.should_terminate
    # Counter after each call: 0, 1 (delta 0.3), 2 (delta 0.3) -> stop.
    for expect_stop in (False, False, True):
        h(None)
        assert bool(trainer.should_terminate) == expect_stop
示例6: test_early_stopping_on_cumulative_delta
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_early_stopping_on_cumulative_delta():
    # With cumulative_delta=True the reference score is kept until a gain
    # exceeds min_delta, so 0.0 -> 0.6 (delta 0.6 > 0.4) resets the counter.
    scores = iter([0.0, 0.3, 0.6])
    trainer = Engine(do_nothing_update_fn)
    h = EarlyStopping(
        patience=2,
        min_delta=0.4,
        cumulative_delta=True,
        score_function=lambda _: next(scores),
        trainer=trainer,
    )

    assert not trainer.should_terminate
    # Counter after each call: 0, 1 (delta 0.3), 0 (delta 0.6) -> never stops.
    for _ in range(3):
        h(None)
        assert not trainer.should_terminate
示例7: test_simple_no_early_stopping
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_simple_no_early_stopping():
    # One miss (0.8) is followed by an improvement (1.2), which resets the
    # counter before patience=2 is reached.
    scores = iter([1.0, 0.8, 1.2])
    trainer = Engine(do_nothing_update_fn)
    h = EarlyStopping(patience=2, score_function=lambda engine: next(scores), trainer=trainer)

    assert not trainer.should_terminate
    for _ in range(3):
        h(None)
    assert not trainer.should_terminate
示例8: test_with_engine_early_stopping
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_with_engine_early_stopping():
    # The best score 1.5 is reached at epoch 4; epochs 5-7 (0.9, 1.0, 0.99) are
    # three misses in a row, so patience=3 stops training after epoch 7 of 10.
    epochs_done = 0
    scores = iter([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9])

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(
        patience=3,
        score_function=lambda engine: next(scores),
        trainer=trainer,
    )

    def run_evaluation(engine):
        nonlocal epochs_done
        evaluator.run([0])
        epochs_done += 1

    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluation)
    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    assert epochs_done == 7
    assert trainer.state.epoch == 7
示例9: test_with_engine_no_early_stopping
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_with_engine_no_early_stopping():
    # Improvements at epochs 3, 4, 8 and 9 keep resetting the counter; at most
    # three misses occur in a row, so patience=5 is never hit and all 10 epochs run.
    epochs_done = 0
    scores = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(
        patience=5,
        score_function=lambda engine: next(scores),
        trainer=trainer,
    )

    def run_evaluation(engine):
        nonlocal epochs_done
        evaluator.run([0])
        epochs_done += 1

    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluation)
    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    assert epochs_done == 10
    assert trainer.state.epoch == 10
示例10: __call__
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def __call__(self, engine):
    """Evaluate the current score and update the early-stopping state.

    The first call records the score as the best one. A later score that is
    not strictly better than ``best_score`` (including an equal score, i.e. a
    plateau) increments the counter; once the counter reaches ``patience`` the
    trainer is asked to terminate. A strictly better score becomes the new best
    score and resets the counter.

    Args:
        engine: engine passed to ``score_function`` to compute the score.
    """
    score = self.score_function(engine)

    if self.best_score is None:
        # First evaluation: nothing to compare against yet.
        self.best_score = score
    elif score <= self.best_score:
        # ``<=`` rather than ``<``: a plateau (score equal to the best) is not an
        # improvement and must count towards patience; otherwise a constant
        # score would reset the counter forever and training would never stop.
        self.counter += 1
        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        self._logger.debug("EarlyStopping: %i / %i", self.counter, self.patience)
        if self.counter >= self.patience:
            self._logger.info("EarlyStopping: Stop training")
            self.trainer.terminate()
    else:
        self.best_score = score
        self.counter = 0
示例11: test_args_validation
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_args_validation():
    trainer = Engine(do_nothing_update_fn)

    def zero_score(engine):
        return 0

    # (expected exception, message regex, constructor kwargs)
    invalid_configs = [
        (
            ValueError,
            r"Argument patience should be positive integer.",
            {"patience": -1, "score_function": zero_score, "trainer": trainer},
        ),
        (
            ValueError,
            r"Argument min_delta should not be a negative number.",
            {"patience": 2, "min_delta": -0.1, "score_function": zero_score, "trainer": trainer},
        ),
        (
            TypeError,
            r"Argument score_function should be a function.",
            {"patience": 2, "score_function": 12345, "trainer": trainer},
        ),
        (
            TypeError,
            r"Argument trainer should be an instance of Engine.",
            {"patience": 2, "score_function": zero_score, "trainer": None},
        ),
    ]
    for expected_error, message_pattern, kwargs in invalid_configs:
        with pytest.raises(expected_error, match=message_pattern):
            EarlyStopping(**kwargs)
示例12: test_simple_early_stopping_on_plateau
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def test_simple_early_stopping_on_plateau():
    # A constant score is a plateau and must count as "no improvement":
    # the first call records the best score, the second exhausts patience=1.
    trainer = Engine(do_nothing_update_fn)
    h = EarlyStopping(patience=1, score_function=lambda engine: 42, trainer=trainer)

    assert not trainer.should_terminate
    for expect_stop in (False, True):
        h(None)
        assert bool(trainer.should_terminate) == expect_stop
示例13: _test_distrib_with_engine_early_stopping
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def _test_distrib_with_engine_early_stopping(device):
    import torch.distributed as dist

    torch.manual_seed(12)

    epochs_done = 0
    # Same literal on every rank, so the all-reduced mean equals the local value;
    # the expected stopping epoch (7) matches the single-process test.
    scores = torch.tensor([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9], requires_grad=False).to(device)

    def score_function(engine):
        # Indexing yields a 0-d view into `scores`; all_reduce and /= update it
        # in place, which is fine because each epoch reads a distinct entry.
        value = scores[trainer.state.epoch - 1]
        dist.all_reduce(value)
        value /= dist.get_world_size()
        return value.item()

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    def run_evaluation(engine):
        nonlocal epochs_done
        evaluator.run([0])
        epochs_done += 1

    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluation)
    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)

    assert trainer.state.epoch == 7
    assert epochs_done == 7
示例14: _test_distrib_integration_engine_early_stopping
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import EarlyStopping [as 别名]
def _test_distrib_integration_engine_early_stopping(device):
    """Check early stopping driven by an ``Accuracy`` metric in a distributed run.

    Each rank feeds its own column of per-epoch prediction/target tensors to an
    evaluator with ``Accuracy`` attached, and the evaluator's ``"acc"`` metric is
    the early-stopping score. The second epoch uses all-ones predictions and
    targets, so its accuracy is the maximum. The test expects the trainer to
    stop at epoch 5 with ``patience=3``.

    Args:
        device: device the tensors and the metric are placed on.
    """
    import torch.distributed as dist
    from ignite.metrics import Accuracy

    rank = dist.get_rank()
    ws = dist.get_world_size()
    # Fixed seed: the order of the torch.randint calls below determines the data,
    # and the expected stopping epoch depends on it -- do not reorder them.
    torch.manual_seed(12)
    n_epochs = 10
    n_iters = 20
    # One (n_iters, ws) tensor per epoch; each rank later reads column `rank`.
    # Epoch index 1 (the 2nd epoch) is all ones in both predictions and targets,
    # so every prediction is correct in that epoch.
    y_preds = (
        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
        + [torch.ones(n_iters, ws).to(device)]
        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
    )
    y_true = (
        [torch.randint(0, 2, size=(n_iters, ws)).to(device)]
        + [torch.ones(n_iters, ws).to(device)]
        + [torch.randint(0, 2, size=(n_iters, ws)).to(device) for _ in range(n_epochs - 2)]
    )

    def update(engine, _):
        # `trainer` is assigned further down; the closure resolves it at call time.
        # The trainer's epoch selects the tensor, the evaluator's iteration the row.
        e = trainer.state.epoch - 1
        i = engine.state.iteration - 1
        return y_preds[e][i, rank], y_true[e][i, rank]

    evaluator = Engine(update)
    # NOTE(review): presumably Accuracy aggregates its counts across ranks -- confirm.
    acc = Accuracy(device=device)
    acc.attach(evaluator, "acc")

    def score_function(engine):
        return engine.state.metrics["acc"]

    trainer = Engine(lambda e, b: None)
    early_stopping = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        # One evaluator pass of n_iters iterations after every training epoch.
        data = list(range(n_iters))
        evaluator.run(data=data)

    evaluator.add_event_handler(Events.COMPLETED, early_stopping)
    trainer.run([0], max_epochs=10)
    # Epoch 2 reaches the maximum accuracy (assuming Accuracy is a fraction in
    # [0, 1]); epochs 3-5 cannot beat it, so patience=3 stops after epoch 5.
    assert trainer.state.epoch == 5