当前位置: 首页>>代码示例>>Python>>正文


Python Events.EPOCH_STARTED属性代码示例

本文整理汇总了Python中ignite.engine.Events.EPOCH_STARTED属性的典型用法代码示例。如果您正苦于以下问题：Python Events.EPOCH_STARTED属性的具体用法？Python Events.EPOCH_STARTED怎么用？Python Events.EPOCH_STARTED使用的例子？那么，这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在类ignite.engine.Events的用法示例。


在下文中一共展示了Events.EPOCH_STARTED属性的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: attach

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def attach(self, engine, name):
    """Hook this object's callbacks into ``engine`` (used to drive a learning-rate search).

    Three handlers are registered, in this order:

    * ``self.started`` on ``Events.EPOCH_STARTED``;
    * ``self.iteration_completed`` on ``Events.ITERATION_COMPLETED``;
    * ``self.completed`` on ``Events.ITERATION_COMPLETED``, with ``name`` passed
      as an extra positional argument to ``add_event_handler``.

    Args:
        engine (ignite.engine.Engine):
            Engine that this handler will be attached to.
        name:
            Extra argument forwarded with the ``self.completed`` registration.

    Returns:
        self, so the call can be chained. The engine is also kept on ``self.engine``.
    """
    register = engine.add_event_handler

    register(Events.EPOCH_STARTED, self.started)
    # Both handlers below fire on the same event; they are registered in this order.
    register(Events.ITERATION_COMPLETED, self.iteration_completed)
    register(Events.ITERATION_COMPLETED, self.completed, name)

    self.engine = engine
    return self
开发者ID:leokarlin,项目名称:LaSO,代码行数:21,代码来源:find_learning_rate.py

示例2: test_has_handler_on_callable_events

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_has_handler_on_callable_events():
    """``has_event_handler`` finds handlers registered on plain and on filtered events."""
    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    # Nothing is registered yet.
    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)

    def bar(e):
        pass

    # Register on a filtered event, then look it up with no event, with the
    # base event, and with the same filtered event.
    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED)
    # Must be asserted: a bare call discards the result and can never fail.
    assert engine.has_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_custom_events.py

示例3: test_state_custom_attrs_init

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_state_custom_attrs_init():
    """User-defined ``engine.state`` attributes remain set throughout a run.

    The attributes are checked at STARTED, EPOCH_STARTED, EPOCH_COMPLETED and
    COMPLETED. This is done twice: once on a fresh engine, and once on an engine
    whose state was first populated through ``load_state_dict``.
    """

    def _run_and_check(load_state):
        engine = Engine(lambda e, b: None)
        engine.state.alpha = 0.0
        engine.state.beta = 1.0

        if load_state:
            engine.load_state_dict({"iteration": 3, "max_epochs": 5, "epoch_length": 5})

        @engine.on(Events.STARTED | Events.EPOCH_STARTED | Events.EPOCH_COMPLETED | Events.COMPLETED)
        def check_custom_attr():
            state = engine.state
            assert hasattr(state, "alpha") and state.alpha == 0.0
            assert hasattr(state, "beta") and state.beta == 1.0

        engine.run([0, 1, 2, 3, 4], max_epochs=5)

    for load_state in (False, True):
        _run_and_check(load_state)
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_engine_state_dict.py

示例4: test_has_event_handler

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_has_event_handler():
    """``has_event_handler`` reports a handler only under the event it was registered for."""
    engine = DummyEngine()
    on_started = [MagicMock(spec_set=True) for _ in range(2)]
    on_completed = MagicMock(spec_set=True)

    for handler in on_started:
        engine.add_event_handler(Events.STARTED, handler)
    engine.add_event_handler(Events.COMPLETED, on_completed)

    for handler in on_started:
        assert engine.has_event_handler(handler, Events.STARTED)
        assert engine.has_event_handler(handler)
        for other_event in (Events.COMPLETED, Events.EPOCH_STARTED):
            assert not engine.has_event_handler(handler, other_event)

    assert not engine.has_event_handler(on_completed, Events.STARTED)
    assert engine.has_event_handler(on_completed, Events.COMPLETED)
    assert engine.has_event_handler(on_completed)
    assert not engine.has_event_handler(on_completed, Events.EPOCH_STARTED)
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_event_handlers.py

示例5: test_output_handler_both

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_output_handler_both():
    """OutputHandler writes both the selected metrics and the transformed output.

    All three scalars are written at step 5, the engine's epoch.
    """
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    expected = [call("tag/a", 12.23, 5), call("tag/b", 23.45, 5), call("tag/loss", 12345, 5)]
    assert add_scalar.call_count == len(expected)
    add_scalar.assert_has_calls(expected, any_order=True)
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_tensorboard_logger.py

示例6: test_output_handler_with_global_step_transform

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_output_handler_with_global_step_transform():
    """OutputHandler logs at the step returned by ``global_step_transform`` (10), not the epoch."""

    def fixed_step(*args, **kwargs):
        return 10

    handler = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=fixed_step)
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    assert add_scalar.call_count == 1
    add_scalar.assert_has_calls([call("tag/loss", 12345, 10)])
开发者ID:pytorch,项目名称:ignite,代码行数:18,代码来源:test_tensorboard_logger.py

示例7: test_weights_scalar_handler_frozen_layers

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_weights_scalar_handler_frozen_layers(dummy_model_factory):
    """WeightsScalarHandler logs fc2 weight norms only, for a model with a frozen layer.

    The model is built with ``with_frozen_layer=True``. Exactly two scalars are
    expected (fc2 weight and bias) and no fc1 scalars.
    """
    model = dummy_model_factory(with_grads=True, with_frozen_layer=True)

    handler = WeightsScalarHandler(model)
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    fc2_calls = [call("weights_norm/fc2/weight", 12.0, 5), call("weights_norm/fc2/bias", math.sqrt(12.0), 5)]
    fc1_calls = [call("weights_norm/fc1/weight", 12.0, 5), call("weights_norm/fc1/bias", math.sqrt(12.0), 5)]

    add_scalar.assert_has_calls(fc2_calls, any_order=True)

    # The fc1 scalars must not all be present, so this lookup has to fail.
    with pytest.raises(AssertionError):
        add_scalar.assert_has_calls(fc1_calls, any_order=True)

    assert add_scalar.call_count == 2
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_tensorboard_logger.py

示例8: test_grads_scalar_handler_frozen_layers

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_grads_scalar_handler_frozen_layers(dummy_model_factory, norm_mock):
    """GradsScalarHandler reduces and logs fc2 gradients only, for a model with a frozen layer.

    The model is built with ``with_frozen_layer=True``. Exactly two scalars and
    two reductions through ``norm_mock`` are expected.
    """
    model = dummy_model_factory(with_grads=True, with_frozen_layer=True)

    handler = GradsScalarHandler(model, reduction=norm_mock)
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    # Count only the reductions made by the handler call below.
    norm_mock.reset_mock()

    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    fc2_calls = [call("grads_norm/fc2/weight", ANY, 5), call("grads_norm/fc2/bias", ANY, 5)]
    fc1_calls = [call("grads_norm/fc1/weight", ANY, 5), call("grads_norm/fc1/bias", ANY, 5)]

    add_scalar.assert_has_calls(fc2_calls, any_order=True)

    with pytest.raises(AssertionError):
        add_scalar.assert_has_calls(fc1_calls, any_order=True)
    assert add_scalar.call_count == 2
    assert norm_mock.call_count == 2
开发者ID:pytorch,项目名称:ignite,代码行数:26,代码来源:test_tensorboard_logger.py

示例9: test_output_handler_with_wrong_global_step_transform_output

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_output_handler_with_wrong_global_step_transform_output():
    """A ``global_step_transform`` returning a non-int makes OutputHandler raise TypeError."""

    def non_int_step(*args, **kwargs):
        return "a"

    handler = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=non_int_step)
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    with pytest.raises(TypeError, match="global_step must be int"):
        handler(engine, logger, Events.EPOCH_STARTED)
开发者ID:pytorch,项目名称:ignite,代码行数:18,代码来源:test_visdom_logger.py

示例10: test_output_handler_both

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_output_handler_both(dirname):
    """OutputHandler reports the selected metrics and the transformed output to Trains.

    Every scalar is reported with ``iteration=5``, the engine's epoch.
    """
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    logger = MagicMock(spec=TrainsLogger)
    logger.trains_logger = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    report_scalar = logger.trains_logger.report_scalar
    expected = [
        call(title="tag", series=series, iteration=5, value=value)
        for series, value in (("a", 12.23), ("b", 23.45), ("loss", 12345))
    ]
    assert report_scalar.call_count == 3
    report_scalar.assert_has_calls(expected, any_order=True)
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_trains_logger.py

示例11: test_output_handler_with_global_step_transform

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_output_handler_with_global_step_transform():
    """OutputHandler reports to Trains at the iteration returned by ``global_step_transform``."""

    def fixed_step(*args, **kwargs):
        return 10

    handler = OutputHandler("tag", output_transform=lambda x: {"loss": x}, global_step_transform=fixed_step)
    logger = MagicMock(spec=TrainsLogger)
    logger.trains_logger = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    report_scalar = logger.trains_logger.report_scalar
    assert report_scalar.call_count == 1
    report_scalar.assert_has_calls([call(title="tag", series="loss", iteration=10, value=12345)])
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_trains_logger.py

示例12: test_pbar_on_epochs

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_pbar_on_epochs(capsys):
    """A ProgressBar attached to EPOCH_STARTED shows epoch-level progress.

    The last rendered bar is expected to read 9/10 epochs.
    """
    n_epochs = 10
    loader = [1, 2, 3, 4, 5]
    engine = Engine(update_fn)

    pbar = ProgressBar()
    pbar.attach(engine, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED)
    engine.run(loader, max_epochs=n_epochs)

    # Split captured stderr on carriage returns; the last non-empty segment is the final frame.
    segments = [segment.strip() for segment in capsys.readouterr().err.split("\r")]
    last_frame = [segment for segment in segments if segment][-1]
    assert last_frame == "Epoch: [9/10]  90%|█████████  [00:00<00:00]"
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_tqdm_logger.py

示例13: test_weights_scalar_handler_frozen_layers

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_weights_scalar_handler_frozen_layers(dummy_model_factory):
    """WeightsScalarHandler sends only fc2 weight norms to Neptune for a model with a frozen layer.

    The model is built with ``with_frozen_layer=True``. Exactly two metrics are
    expected (fc2 weight and bias) and no fc1 metrics.
    """
    model = dummy_model_factory(with_grads=True, with_frozen_layer=True)

    handler = WeightsScalarHandler(model)
    logger = MagicMock(spec=NeptuneLogger)
    logger.log_metric = MagicMock()
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    log_metric = logger.log_metric
    fc2_calls = [call("weights_norm/fc2/weight", y=12.0, x=5), call("weights_norm/fc2/bias", y=math.sqrt(12.0), x=5)]
    fc1_calls = [call("weights_norm/fc1/weight", y=12.0, x=5), call("weights_norm/fc1/bias", y=math.sqrt(12.0), x=5)]

    log_metric.assert_has_calls(fc2_calls, any_order=True)

    # The fc1 metrics must not all be present, so this lookup has to fail.
    with pytest.raises(AssertionError):
        log_metric.assert_has_calls(fc1_calls, any_order=True)

    assert log_metric.call_count == 2
开发者ID:pytorch,项目名称:ignite,代码行数:26,代码来源:test_neptune_logger.py

示例14: test_grads_scalar_handler_frozen_layers

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_grads_scalar_handler_frozen_layers(dummy_model_factory, norm_mock):
    """GradsScalarHandler reduces and sends only fc2 gradients to Neptune for a model with a frozen layer.

    The model is built with ``with_frozen_layer=True``. Exactly two metrics and
    two reductions through ``norm_mock`` are expected.
    """
    model = dummy_model_factory(with_grads=True, with_frozen_layer=True)

    handler = GradsScalarHandler(model, reduction=norm_mock)
    logger = MagicMock(spec=NeptuneLogger)
    logger.log_metric = MagicMock()
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    # Count only the reductions made by the handler call below.
    norm_mock.reset_mock()

    handler(engine, logger, Events.EPOCH_STARTED)

    log_metric = logger.log_metric
    fc2_calls = [call("grads_norm/fc2/weight", y=ANY, x=5), call("grads_norm/fc2/bias", y=ANY, x=5)]
    fc1_calls = [call("grads_norm/fc1/weight", y=ANY, x=5), call("grads_norm/fc1/bias", y=ANY, x=5)]

    log_metric.assert_has_calls(fc2_calls, any_order=True)

    with pytest.raises(AssertionError):
        log_metric.assert_has_calls(fc1_calls, any_order=True)
    assert log_metric.call_count == 2
    assert norm_mock.call_count == 2
开发者ID:pytorch,项目名称:ignite,代码行数:25,代码来源:test_neptune_logger.py

示例15: test_event_handler_epoch_started

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import EPOCH_STARTED [as 别名]
def test_event_handler_epoch_started():
    """BasicTimeProfiler records the time spent in an EPOCH_STARTED handler.

    The handler sleeps a fixed delay on each of the two epochs. The profiler's
    min/max/mean/std/total stats are compared against that delay with an
    absolute tolerance of 0.1 s.
    """
    handler_delay = 0.1
    max_epochs = 2
    num_iters = 1

    profiler = BasicTimeProfiler()
    trainer = Engine(_do_nothing_update_fn)
    profiler.attach(trainer)

    @trainer.on(Events.EPOCH_STARTED)
    def delay_epoch_start(engine):
        time.sleep(handler_delay)

    trainer.run(range(num_iters), max_epochs=max_epochs)
    stats = profiler.get_results()["event_handlers_stats"]["EPOCH_STARTED"]

    tolerance = 1e-1
    for key in ("min/index", "max/index"):
        assert stats[key][0] == approx(handler_delay, abs=tolerance)
    assert stats["mean"] == approx(handler_delay, abs=tolerance)
    assert stats["std"] == approx(0.0, abs=tolerance)
    assert stats["total"] == approx(max_epochs * handler_delay, abs=tolerance)
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_time_profilers.py


注：本文中的ignite.engine.Events.EPOCH_STARTED属性示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。