当前位置: 首页>>代码示例>>Python>>正文


Python Events.EPOCH_COMPLETED属性代码示例

本文整理汇总了Python中ignite.engine.Events.EPOCH_COMPLETED属性的典型用法代码示例。如果您正苦于以下问题:Python Events.EPOCH_COMPLETED属性的具体用法?Python Events.EPOCH_COMPLETED怎么用?Python Events.EPOCH_COMPLETED使用的例子?那么, 这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在ignite.engine.Events的用法示例。


在下文中一共展示了Events.EPOCH_COMPLETED属性的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_metrics_print

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_metrics_print(self):
    """Run a small engine with a TensorBoardStatsHandler attached and check the log directory exists afterwards."""
    # Reserve a unique path, then delete it again: the directory must be
    # created anew during the test (presumably by the handler) for the final
    # existence check to pass.
    log_path = tempfile.mkdtemp()
    shutil.rmtree(log_path, ignore_errors=True)

    def _train_func(engine, batch):
        # Trivial "training" step: echo the batch shifted by one.
        return batch + 1.0

    engine = Engine(_train_func)

    @engine.on(Events.EPOCH_COMPLETED)
    def _update_metric(engine):
        # Dummy metric: grows by 0.1 at the end of every epoch (0.1 when unset).
        metrics = engine.state.metrics
        metrics["acc"] = metrics.get("acc", 0.1) + 0.1

    # Handler under test.
    handler = TensorBoardStatsHandler(log_dir=log_path)
    handler.attach(engine)
    engine.run(range(3), max_epochs=2)

    # The log directory must be present after the run; clean it up afterwards.
    self.assertTrue(os.path.exists(log_path))
    shutil.rmtree(log_path)
开发者ID:Project-MONAI,项目名称:MONAI,代码行数:26,代码来源:test_handler_tb_stats.py

示例2: attach

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def attach(self, engine, metric_names=None, output_transform=None):
    """
    Hook this progress bar into an engine.

    Args:
        engine (Engine): engine on which the handlers are registered.
        metric_names (list, optional): names of the metrics to log as the bar progresses.
        output_transform (callable, optional): selects what to print from the engine's
            output. It may return either a dictionary with entries in the format of
            ``{name: value}``, or a single scalar, which is displayed under the default
            name `output`.

    Raises:
        TypeError: if ``metric_names`` is given but is not a list, or if
            ``output_transform`` is given but is not callable.
    """
    # Validate the optional arguments before touching the engine.
    if not (metric_names is None or isinstance(metric_names, list)):
        names_type = type(metric_names)
        raise TypeError("metric_names should be a list, got {} instead.".format(names_type))

    if not (output_transform is None or callable(output_transform)):
        message = "output_transform should be a function, got {} instead."
        raise TypeError(message.format(type(output_transform)))

    # Refresh the bar after every iteration, close it when the epoch ends.
    engine.add_event_handler(
        Events.ITERATION_COMPLETED,
        self._update,
        metric_names,
        output_transform,
    )
    engine.add_event_handler(Events.EPOCH_COMPLETED, self._close)
开发者ID:leokarlin,项目名称:LaSO,代码行数:22,代码来源:tqdm_logger.py

示例3: test_terminate_at_end_of_epoch_stops_run

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_terminate_at_end_of_epoch_stops_run():
    """Calling ``terminate()`` from an EPOCH_COMPLETED handler stops the run at that epoch."""
    max_epochs = 5
    last_epoch_to_run = 3

    engine = Engine(MagicMock(return_value=1))

    @engine.on(Events.EPOCH_COMPLETED)
    def _stop_at_target_epoch(eng):
        # Only the target epoch requests termination.
        if eng.state.epoch != last_epoch_to_run:
            return
        eng.terminate()

    # The flag must be clear before the run starts.
    assert not engine.should_terminate

    final_state = engine.run([1], max_epochs=max_epochs)

    assert final_state.epoch == last_epoch_to_run
    assert engine.should_terminate
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_engine.py

示例4: test_time_stored_in_state

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_time_stored_in_state():
    """Check the durations recorded in ``engine.state.times`` against a lower bound derived from the sleep time."""

    def _test(data, max_epochs, epoch_length):
        sleep_time = 0.01

        def _sleepy_step(e, b):
            # Each iteration just sleeps, giving a known minimum duration.
            return time.sleep(sleep_time)

        engine = Engine(_sleepy_step)

        def check_epoch_time(engine):
            # One epoch takes at least epoch_length sleeps.
            min_epoch_time = sleep_time * epoch_length
            assert engine.state.times[Events.EPOCH_COMPLETED.name] >= min_epoch_time

        def check_completed_time(engine):
            # The whole run takes at least max_epochs epochs' worth of sleeps.
            min_total_time = sleep_time * epoch_length * max_epochs
            assert engine.state.times[Events.COMPLETED.name] >= min_total_time

        engine.add_event_handler(Events.EPOCH_COMPLETED, check_epoch_time)
        engine.add_event_handler(Events.COMPLETED, check_completed_time)

        engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)

    # (dataset size, number of epochs); epoch length is fixed at 100.
    for data_size, n_epochs in ((100, 2), (200, 2), (200, 5)):
        _test(list(range(data_size)), max_epochs=n_epochs, epoch_length=100)
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_engine.py

示例5: test_has_handler_on_callable_events

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_has_handler_on_callable_events():
    """Check ``Engine.has_event_handler`` for handlers registered on plain and filtered events.

    ``foo`` is attached to a plain event and ``bar`` to a filtered one
    (``Events.EPOCH_COMPLETED(every=3)``). ``bar`` must be found when no event
    is given, when the unfiltered event is given, and when the same filtered
    event is given.
    """
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    # Nothing is registered yet.
    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED)

    # The original called has_event_handler here but discarded the result, so
    # the filtered-event lookup was never actually verified.
    assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_custom_events.py

示例6: test_state_custom_attrs_init

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_state_custom_attrs_init():
    """Custom attributes set on ``engine.state`` before ``run`` stay visible, with their values, at every checked event."""

    def _test(with_load_state_dict=False):
        engine = Engine(lambda e, b: None)

        # Custom attributes set on the state before the run starts.
        engine.state.alpha = 0.0
        engine.state.beta = 1.0

        if with_load_state_dict:
            saved = {"iteration": 3, "max_epochs": 5, "epoch_length": 5}
            engine.load_state_dict(saved)

        @engine.on(Events.STARTED | Events.EPOCH_STARTED | Events.EPOCH_COMPLETED | Events.COMPLETED)
        def check_custom_attr():
            # Both attributes must still exist and hold their initial values.
            for attr_name, expected in (("alpha", 0.0), ("beta", 1.0)):
                assert hasattr(engine.state, attr_name) and getattr(engine.state, attr_name) == expected

        engine.run([0, 1, 2, 3, 4], max_epochs=5)

    for use_state_dict in (False, True):
        _test(with_load_state_dict=use_state_dict)
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_engine_state_dict.py

示例7: test_with_engine_early_stopping

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_with_engine_early_stopping():
    """EarlyStopping (patience=3) on the evaluator stops the trainer after 7 of 10 epochs for the given scores."""
    epochs_seen = 0

    # One score is consumed per evaluator run.
    score_stream = iter([1.0, 0.8, 1.2, 1.5, 0.9, 1.0, 0.99, 1.1, 0.9])

    def score_function(engine):
        return next(score_stream)

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=3, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        # Validate after every training epoch, then count the epoch.
        nonlocal epochs_seen
        evaluator.run([0])
        epochs_seen += 1

    evaluator.add_event_handler(Events.COMPLETED, stopper)
    trainer.run([0], max_epochs=10)

    assert epochs_seen == 7
    assert trainer.state.epoch == 7
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_early_stopping.py

示例8: test_with_engine_early_stopping_on_plateau

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_with_engine_early_stopping_on_plateau():
    """With a constant score and patience=4, the trainer is stopped after 5 of 10 epochs."""
    epochs_seen = 0

    def score_function(engine):
        # The score never changes from one evaluation to the next.
        return 0.047

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=4, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        # Validate after every training epoch, then count the epoch.
        nonlocal epochs_seen
        evaluator.run([0])
        epochs_seen += 1

    evaluator.add_event_handler(Events.COMPLETED, stopper)
    trainer.run([0], max_epochs=10)

    assert epochs_seen == 5
    assert trainer.state.epoch == 5
开发者ID:pytorch,项目名称:ignite,代码行数:25,代码来源:test_early_stopping.py

示例9: test_with_engine_no_early_stopping

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_with_engine_no_early_stopping():
    """With patience=5 and the given scores, the trainer runs all 10 epochs without being stopped."""
    epochs_seen = 0

    # One score is consumed per evaluator run.
    score_stream = iter([1.0, 0.8, 1.2, 1.23, 0.9, 1.0, 1.1, 1.253, 1.26, 1.2])

    def score_function(engine):
        return next(score_stream)

    trainer = Engine(do_nothing_update_fn)
    evaluator = Engine(do_nothing_update_fn)
    stopper = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluation(engine):
        # Validate after every training epoch, then count the epoch.
        nonlocal epochs_seen
        evaluator.run([0])
        epochs_seen += 1

    evaluator.add_event_handler(Events.COMPLETED, stopper)
    trainer.run([0], max_epochs=10)

    assert epochs_seen == 10
    assert trainer.state.epoch == 10
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_early_stopping.py

示例10: test_save_best_model_by_val_score

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_save_best_model_by_val_score(dirname):
    """With ``n_saved=2``, only the checkpoints for the two highest ``acc`` values remain in ``dirname``."""
    trainer = Engine(lambda e, b: None)
    evaluator = Engine(lambda e, b: None)
    model = DummyModel()

    # Validation accuracy reported for each training epoch, in order.
    acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        evaluator.run([0, 1])

    @evaluator.on(Events.EPOCH_COMPLETED)
    def set_eval_metric(engine):
        # Trainer epochs are 1-based, the score list is 0-based.
        epoch_index = trainer.state.epoch - 1
        engine.state.metrics = {"acc": acc_scores[epoch_index]}

    save_best_model_by_val_score(dirname, evaluator, model, metric_name="acc", n_saved=2, trainer=trainer)

    trainer.run([0, 1], max_epochs=len(acc_scores))

    expected_files = {"best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"}
    assert set(os.listdir(dirname)) == expected_files
开发者ID:pytorch,项目名称:ignite,代码行数:23,代码来源:test_common.py

示例11: test_integration

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_integration(dirname):
    """Run a trainer with a TrainsLogger in bypass mode, reporting one scalar at every EPOCH_COMPLETED."""
    n_epochs = 5
    data = list(range(50))

    # One random "loss" per iteration of the whole run.
    losses = torch.rand(n_epochs * len(data))
    loss_stream = iter(losses)

    def update_fn(engine, batch):
        return next(loss_stream)

    trainer = Engine(update_fn)

    # The bypass-mode warning is expected somewhere inside this block.
    with pytest.warns(UserWarning, match="TrainsSaver: running in bypass mode"):
        TrainsLogger.set_bypass_mode(True)
        logger = TrainsLogger(output_uri=dirname)

        def dummy_handler(engine, logger, event_name):
            # Use the state's counter for the triggering event as the step.
            step = engine.state.get_event_attrib_value(event_name)
            logger.trains_logger.report_scalar(title="", series="", value="test_value", iteration=step)

        logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)
        logger.close()
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_trains_logger.py

示例12: test_integration

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_integration():
    """Run a trainer with a PolyaxonLogger attached, logging one metric at every EPOCH_COMPLETED."""
    n_epochs = 5
    data = list(range(50))

    # One random "loss" per iteration of the whole run.
    losses = torch.rand(n_epochs * len(data))
    loss_stream = iter(losses)

    def update_fn(engine, batch):
        return next(loss_stream)

    trainer = Engine(update_fn)

    plx_logger = PolyaxonLogger()

    def dummy_handler(engine, logger, event_name):
        # Use the state's counter for the triggering event as step and value.
        step = engine.state.get_event_attrib_value(event_name)
        logger.log_metrics(step=step, test_value=step)

    plx_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    trainer.run(data, max_epochs=n_epochs)
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_polyaxon_logger.py

示例13: test_integration_as_context_manager

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_integration_as_context_manager():
    """Same as the plain Polyaxon integration run, but with the logger used as a context manager."""
    n_epochs = 5
    data = list(range(50))

    # One random "loss" per iteration of the whole run.
    losses = torch.rand(n_epochs * len(data))
    loss_stream = iter(losses)

    def update_fn(engine, batch):
        return next(loss_stream)

    with PolyaxonLogger() as plx_logger:

        trainer = Engine(update_fn)

        def dummy_handler(engine, logger, event_name):
            # Use the state's counter for the triggering event as step and value.
            step = engine.state.get_event_attrib_value(event_name)
            logger.log_metrics(step=step, test_value=step)

        plx_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_polyaxon_logger.py

示例14: test_pbar_wrong_events_order

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_pbar_wrong_events_order():
    """``ProgressBar.attach`` raises ValueError for mis-ordered event pairs and for a filtered closing event."""
    engine = Engine(update_fn)
    pbar = ProgressBar()

    # (event_name, closing_event_name) pairs that must each be rejected with
    # the "before closing event" error.
    misordered_pairs = (
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    )
    for event, closing_event in misordered_pairs:
        with pytest.raises(ValueError, match="should be called before closing event"):
            pbar.attach(engine, event_name=event, closing_event_name=closing_event)

    # A filtered event is rejected as the closing event.
    filtered_closing = Events.EPOCH_COMPLETED(every=10)
    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=filtered_closing)
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_tqdm_logger.py

示例15: test_integration

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_COMPLETED [as 别名]
def test_integration():
    """Run a trainer with an offline NeptuneLogger attached, logging one metric at every EPOCH_COMPLETED."""
    n_epochs = 5
    data = list(range(50))

    # One random "loss" per iteration of the whole run.
    losses = torch.rand(n_epochs * len(data))
    loss_stream = iter(losses)

    def update_fn(engine, batch):
        return next(loss_stream)

    trainer = Engine(update_fn)

    npt_logger = NeptuneLogger(offline_mode=True)

    def dummy_handler(engine, logger, event_name):
        # Use the state's counter for the triggering event as both trailing arguments.
        step = engine.state.get_event_attrib_value(event_name)
        logger.log_metric("test_value", step, step)

    npt_logger.attach(trainer, log_handler=dummy_handler, event_name=Events.EPOCH_COMPLETED)

    trainer.run(data, max_epochs=n_epochs)
    npt_logger.close()
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_neptune_logger.py


注:本文中的ignite.engine.Events.EPOCH_COMPLETED属性示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。