当前位置: 首页>>代码示例>>Python>>正文


Python Events.ITERATION_STARTED属性代码示例

本文整理汇总了Python中ignite.engine.Events.ITERATION_STARTED属性的典型用法代码示例。如果您正苦于以下问题:Python Events.ITERATION_STARTED属性的具体用法?Python Events.ITERATION_STARTED怎么用?Python Events.ITERATION_STARTED使用的例子?那么,这里精选的属性代码示例或许可以为您提供帮助。您也可以进一步了解该属性所在类ignite.engine.Events的用法示例。


在下文中一共展示了Events.ITERATION_STARTED属性的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_terminate_epoch_stops_mid_epoch

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_terminate_epoch_stops_mid_epoch():
    """Verify the final iteration count after ``terminate_epoch`` is requested mid-epoch.

    The engine runs 3 epochs over 10 items, and a handler attached to
    ``Events.ITERATION_STARTED`` calls ``engine.terminate_epoch()`` when
    ``engine.state.iteration`` equals 14. The returned state is expected to
    report ``10 * (3 - 1) + 14 % 10`` iterations.
    """
    epoch_length = 10
    stop_iteration = epoch_length + 4
    n_epochs = 3

    trainer = Engine(MagicMock(return_value=1))

    def _terminate_at_stop_iteration(engine):
        # Guard clause: only the chosen iteration triggers the termination.
        if engine.state.iteration != stop_iteration:
            return
        engine.terminate_epoch()

    trainer.add_event_handler(Events.ITERATION_STARTED, _terminate_at_stop_iteration)
    final_state = trainer.run(data=[None] * epoch_length, max_epochs=n_epochs)

    # completes the iteration but doesn't increment counter (this happens just before a new iteration starts)
    iterations_in_stopped_epoch = stop_iteration % epoch_length
    expected_iterations = epoch_length * (n_epochs - 1) + iterations_in_stopped_epoch
    assert final_state.iteration == expected_iterations
开发者ID:pytorch,项目名称:ignite,代码行数:18,代码来源:test_engine.py

示例2: test_callable_events_with_wrong_inputs

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_callable_events_with_wrong_inputs():
    """Check the errors raised when ``Events.ITERATION_STARTED`` is called with bad filters.

    Each case is ``(expected exception type, regex matched against the
    message, keyword arguments passed to the event)``; the cases run in the
    listed order.
    """
    invalid_calls = (
        (ValueError, r"Only one of the input arguments should be specified", {}),
        (
            ValueError,
            r"Only one of the input arguments should be specified",
            {"event_filter": "123", "every": 12},
        ),
        (TypeError, r"Argument event_filter should be a callable", {"event_filter": "123"}),
        (ValueError, r"Argument every should be integer and greater than zero", {"every": -1}),
        (ValueError, r"but will be called with", {"event_filter": lambda x: x}),
    )

    for error_type, message_pattern, filter_kwargs in invalid_calls:
        with pytest.raises(error_type, match=message_pattern):
            Events.ITERATION_STARTED(**filter_kwargs)
开发者ID:pytorch,项目名称:ignite,代码行数:18,代码来源:test_custom_events.py

示例3: test_list_of_events

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_list_of_events():
    """Check at which iterations a handler fires for ``|``-combined filtered events.

    For every case a fresh engine runs 5 epochs over ``range(3)`` with one
    handler registered on the combined events; the iterations recorded by the
    handler and the number of calls must match the expected list.
    """

    def _check(combined_events, expected_iterations):
        trainer = Engine(lambda e, b: b)
        seen_iterations = []
        call_count = 0

        @trainer.on(combined_events)
        def _record_iteration(e):
            nonlocal call_count
            seen_iterations.append(e.state.iteration)
            call_count += 1

        trainer.run(range(3), max_epochs=5)

        assert seen_iterations == expected_iterations
        assert call_count == len(expected_iterations)

    # (filter kwargs of the second event, iterations expected to be recorded)
    cases = (
        ({"once": 1}, [1, 1]),
        ({"once": 10}, [1, 10]),
        ({"every": 3}, [1, 3, 6, 9, 12, 15]),
    )
    for second_filter, expected in cases:
        # Both event objects are built anew for each case, right before use.
        combined = Events.ITERATION_STARTED(once=1) | Events.ITERATION_STARTED(**second_filter)
        _check(combined, expected)
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_custom_events.py

示例4: test_optimizer_params

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_optimizer_params():
    """Check what ``OptimizerParamsHandler`` writes through a mocked ``TensorboardLogger``.

    With the engine state at iteration 123 and an SGD optimizer whose ``lr``
    is 0.01, the handler must call ``writer.add_scalar`` exactly once, first
    under ``"lr/group_0"`` and, when ``tag="generator"`` is given, under
    ``"generator/lr/group_0"``.
    """

    def _make_logger():
        # A fresh mock per scenario keeps assert_called_once_with meaningful.
        logger_stub = MagicMock(spec=TensorboardLogger)
        logger_stub.writer = MagicMock()
        return logger_stub

    sgd = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    handler = OptimizerParamsHandler(optimizer=sgd, param_name="lr")
    logger_stub = _make_logger()
    engine_stub = MagicMock()
    engine_stub.state = State()
    engine_stub.state.iteration = 123

    handler(engine_stub, logger_stub, Events.ITERATION_STARTED)
    logger_stub.writer.add_scalar.assert_called_once_with("lr/group_0", 0.01, 123)

    tagged_handler = OptimizerParamsHandler(sgd, param_name="lr", tag="generator")
    tagged_logger = _make_logger()

    tagged_handler(engine_stub, tagged_logger, Events.ITERATION_STARTED)
    tagged_logger.writer.add_scalar.assert_called_once_with("generator/lr/group_0", 0.01, 123)
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_tensorboard_logger.py

示例5: test_weights_scalar_handler_wrong_setup

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_weights_scalar_handler_wrong_setup():
    """Check the errors raised for invalid ``WeightsScalarHandler`` setups.

    Covers a ``None`` model, a non-callable reduction, a reduction that
    returns its input unchanged, and calling the handler with a logger mock
    that is not spec'd as ``TensorboardLogger``.
    """
    with pytest.raises(TypeError, match="Argument model should be of type torch.nn.Module"):
        WeightsScalarHandler(None)

    module_stub = MagicMock(spec=torch.nn.Module)

    # (expected exception type, message pattern, reduction argument)
    bad_reductions = (
        (TypeError, "Argument reduction should be callable", 123),
        (ValueError, "Output of the reduction function should be a scalar", lambda x: x),
    )
    for error_type, message_pattern, reduction in bad_reductions:
        with pytest.raises(error_type, match=message_pattern):
            WeightsScalarHandler(module_stub, reduction=reduction)

    handler = WeightsScalarHandler(module_stub)
    plain_logger = MagicMock()
    engine_stub = MagicMock()
    with pytest.raises(RuntimeError, match="Handler 'WeightsScalarHandler' works only with TensorboardLogger"):
        handler(engine_stub, plain_logger, Events.ITERATION_STARTED)
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_tensorboard_logger.py

示例6: test_weights_scalar_handler_wrong_setup

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_weights_scalar_handler_wrong_setup():
    """Check the errors raised for invalid ``WeightsScalarHandler`` setups.

    The constructor must reject a ``None`` model, a non-callable reduction and
    a reduction that returns its input unchanged; calling a valid handler with
    a plain ``MagicMock`` logger must raise the "works only with VisdomLogger"
    error.
    """
    model_error = "Argument model should be of type torch.nn.Module"
    with pytest.raises(TypeError, match=model_error):
        WeightsScalarHandler(None)

    module_stub = MagicMock(spec=torch.nn.Module)

    callable_error = "Argument reduction should be callable"
    with pytest.raises(TypeError, match=callable_error):
        WeightsScalarHandler(module_stub, reduction=123)

    scalar_error = "Output of the reduction function should be a scalar"
    with pytest.raises(ValueError, match=scalar_error):
        WeightsScalarHandler(module_stub, reduction=lambda x: x)

    handler = WeightsScalarHandler(module_stub)
    engine_stub = MagicMock()
    plain_logger = MagicMock()
    logger_error = "Handler 'WeightsScalarHandler' works only with VisdomLogger"
    with pytest.raises(RuntimeError, match=logger_error):
        handler(engine_stub, plain_logger, Events.ITERATION_STARTED)
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_visdom_logger.py

示例7: test_optimizer_params

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_optimizer_params():
    """Check what ``OptimizerParamsHandler`` reports through a mocked ``TrainsLogger``.

    With the engine state at iteration 123 and an SGD optimizer whose ``lr``
    is 0.01, ``trains_logger.report_scalar`` must be called exactly once with
    series ``"0"`` and title ``"lr"``, or title ``"generator/lr"`` when the
    handler is created with ``tag="generator"``.
    """
    sgd = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine_stub = MagicMock()
    engine_stub.state = State()
    engine_stub.state.iteration = 123

    # (extra handler kwargs, title expected in the reported scalar)
    scenarios = (
        ({}, "lr"),
        ({"tag": "generator"}, "generator/lr"),
    )
    for extra_kwargs, expected_title in scenarios:
        handler = OptimizerParamsHandler(optimizer=sgd, param_name="lr", **extra_kwargs)
        # A fresh logger mock per scenario keeps assert_called_once_with meaningful.
        logger_stub = MagicMock(spec=TrainsLogger)
        logger_stub.trains_logger = MagicMock()

        handler(engine_stub, logger_stub, Events.ITERATION_STARTED)
        logger_stub.trains_logger.report_scalar.assert_called_once_with(
            iteration=123, series="0", title=expected_title, value=0.01
        )
开发者ID:pytorch,项目名称:ignite,代码行数:23,代码来源:test_trains_logger.py

示例8: test_output_handler_output_transform

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_output_handler_output_transform(dirname):
    """Check how ``OutputHandler`` reports transformed output through a mocked ``TrainsLogger``.

    The engine state holds ``output = 12345`` at iteration 123. An identity
    transform must be reported under series ``"output"``; a transform that
    wraps the value in ``{"loss": x}`` must be reported under series
    ``"loss"``. ``dirname`` is accepted but not used in the body.
    """
    engine_stub = MagicMock()
    engine_stub.state = State()
    engine_stub.state.output = 12345
    engine_stub.state.iteration = 123

    # Scenario 1: identity transform.
    identity_handler = OutputHandler("tag", output_transform=lambda x: x)
    first_logger = MagicMock(spec=TrainsLogger)
    first_logger.trains_logger = MagicMock()

    identity_handler(engine_stub, first_logger, Events.ITERATION_STARTED)
    first_logger.trains_logger.report_scalar.assert_called_once_with(
        iteration=123, series="output", title="tag", value=12345
    )

    # Scenario 2: transform returning a dict keyed by "loss".
    dict_handler = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
    second_logger = MagicMock(spec=TrainsLogger)
    second_logger.trains_logger = MagicMock()

    dict_handler(engine_stub, second_logger, Events.ITERATION_STARTED)
    second_logger.trains_logger.report_scalar.assert_called_once_with(
        iteration=123, series="loss", title="another_tag", value=12345
    )
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_trains_logger.py

示例9: test_output_handler_output_transform

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_output_handler_output_transform():
    """Check the metrics ``OutputHandler`` logs through a mocked ``PolyaxonLogger``.

    The engine state holds ``output = 12345`` at iteration 123. An identity
    transform must be logged as ``"tag/output"``; a transform that wraps the
    value in ``{"loss": x}`` must be logged as ``"another_tag/loss"``. Both
    calls are expected to carry ``step=123``.
    """
    engine_stub = MagicMock()
    engine_stub.state = State()
    engine_stub.state.output = 12345
    engine_stub.state.iteration = 123

    # (handler tag, output transform, metrics expected in log_metrics)
    scenarios = (
        ("tag", lambda x: x, {"tag/output": 12345}),
        ("another_tag", lambda x: {"loss": x}, {"another_tag/loss": 12345}),
    )
    for tag, transform, expected_metrics in scenarios:
        handler = OutputHandler(tag, output_transform=transform)
        # A fresh logger mock per scenario keeps assert_called_once_with meaningful.
        logger_stub = MagicMock(spec=PolyaxonLogger)
        logger_stub.log_metrics = MagicMock()

        handler(engine_stub, logger_stub, Events.ITERATION_STARTED)
        logger_stub.log_metrics.assert_called_once_with(step=123, **expected_metrics)
开发者ID:pytorch,项目名称:ignite,代码行数:23,代码来源:test_polyaxon_logger.py

示例10: test_optimizer_params

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_optimizer_params():
    """Check the metrics ``OptimizerParamsHandler`` logs through a mocked ``PolyaxonLogger``.

    With the engine state at iteration 123 and an SGD optimizer whose ``lr``
    is 0.01, ``log_metrics`` must be called exactly once with ``step=123`` and
    the key ``"lr/group_0"``, or ``"generator/lr/group_0"`` when the handler
    is created with ``tag="generator"``.
    """
    sgd = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine_stub = MagicMock()
    engine_stub.state = State()
    engine_stub.state.iteration = 123

    # Scenario 1: no tag.
    handler = OptimizerParamsHandler(optimizer=sgd, param_name="lr")
    logger_stub = MagicMock(spec=PolyaxonLogger)
    logger_stub.log_metrics = MagicMock()

    handler(engine_stub, logger_stub, Events.ITERATION_STARTED)
    # The metric name contains "/", so it can only be passed via ** unpacking.
    logger_stub.log_metrics.assert_called_once_with(step=123, **{"lr/group_0": 0.01})

    # Scenario 2: tagged handler with a fresh logger mock.
    tagged_handler = OptimizerParamsHandler(sgd, param_name="lr", tag="generator")
    tagged_logger = MagicMock(spec=PolyaxonLogger)
    tagged_logger.log_metrics = MagicMock()

    tagged_handler(engine_stub, tagged_logger, Events.ITERATION_STARTED)
    tagged_logger.log_metrics.assert_called_once_with(step=123, **{"generator/lr/group_0": 0.01})
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_polyaxon_logger.py

示例11: test_pbar_batch_indeces

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_pbar_batch_indeces(capsys):
    """Check that the progress bar output on stderr covers batch indices 1 to 4.

    An engine whose step sleeps 0.1 s runs one epoch over four items while a
    handler prints the current iteration at ``Events.ITERATION_STARTED``. The
    captured stderr is split on carriage returns, and the last character
    before the first ``"/"`` of each non-empty piece is read as the index.
    """
    trainer = Engine(lambda e, b: time.sleep(0.1))

    @trainer.on(Events.ITERATION_STARTED)
    def print_iter(_):
        print("iteration: ", trainer.state.iteration)

    ProgressBar(persist=True).attach(trainer)
    trainer.run(list(range(4)), max_epochs=1)

    stderr_text = capsys.readouterr().err
    stripped_pieces = [piece.strip() for piece in stderr_text.split("\r")]
    non_empty_pieces = [piece for piece in stripped_pieces if piece]
    printed_indices = {int(piece.split("/")[0][-1]) for piece in non_empty_pieces}

    assert sorted(printed_indices) == list(range(1, 5))
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_tqdm_logger.py

示例12: test_pbar_wrong_events_order

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_pbar_wrong_events_order():
    """Check that ``ProgressBar.attach`` rejects invalid event/closing-event pairs.

    Five ``(event_name, closing_event_name)`` pairs must raise ``ValueError``
    matching "should be called before closing event"; a filtered closing
    event must raise ``ValueError`` matching "should not be a filtered event".
    """
    trainer = Engine(update_fn)
    progress_bar = ProgressBar()

    # (event_name, closing_event_name) pairs rejected for their ordering.
    wrong_order_pairs = (
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    )
    for opening_event, closing_event in wrong_order_pairs:
        with pytest.raises(ValueError, match="should be called before closing event"):
            progress_bar.attach(trainer, event_name=opening_event, closing_event_name=closing_event)

    with pytest.raises(ValueError, match="should not be a filtered event"):
        progress_bar.attach(
            trainer,
            event_name=Events.ITERATION_STARTED,
            closing_event_name=Events.EPOCH_COMPLETED(every=10),
        )
开发者ID:pytorch,项目名称:ignite,代码行数:24,代码来源:test_tqdm_logger.py

示例13: test_pbar_on_callable_events

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_pbar_on_callable_events(capsys):
    """Check the last progress bar line when attached to a filtered event.

    The bar is attached on ``Events.ITERATION_STARTED(every=10)`` and closed
    on ``Events.EPOCH_COMPLETED``. After one epoch over 100 items, the last
    non-empty piece of stderr (split on carriage returns) must equal the
    expected ``[90/100]`` line.
    """
    progress_bar = ProgressBar()
    trainer = Engine(update_fn)
    data = list(range(100))

    progress_bar.attach(
        trainer,
        event_name=Events.ITERATION_STARTED(every=10),
        closing_event_name=Events.EPOCH_COMPLETED,
    )
    trainer.run(data, max_epochs=1)

    stderr_text = capsys.readouterr().err
    stripped_pieces = [piece.strip() for piece in stderr_text.split("\r")]
    non_empty_pieces = [piece for piece in stripped_pieces if piece]

    last_line = non_empty_pieces[-1]
    assert last_line == "Epoch: [90/100]  90%|█████████  [00:00<00:00]"
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_tqdm_logger.py

示例14: test_optimizer_params

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_optimizer_params():
    """Check the metric ``OptimizerParamsHandler`` logs through a mocked ``NeptuneLogger``.

    With the engine state at iteration 123 and an SGD optimizer whose ``lr``
    is 0.01, ``log_metric`` must be called exactly once with ``y=0.01`` and
    ``x=123`` under ``"lr/group_0"``, or ``"generator/lr/group_0"`` when the
    handler is created with ``tag="generator"``.
    """

    def _make_logger():
        # A fresh mock per scenario keeps assert_called_once_with meaningful.
        logger_stub = MagicMock(spec=NeptuneLogger)
        logger_stub.log_metric = MagicMock()
        return logger_stub

    sgd = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    handler = OptimizerParamsHandler(optimizer=sgd, param_name="lr")
    logger_stub = _make_logger()
    engine_stub = MagicMock()
    engine_stub.state = State()
    engine_stub.state.iteration = 123

    handler(engine_stub, logger_stub, Events.ITERATION_STARTED)
    logger_stub.log_metric.assert_called_once_with("lr/group_0", y=0.01, x=123)

    tagged_handler = OptimizerParamsHandler(sgd, param_name="lr", tag="generator")
    tagged_logger = _make_logger()

    tagged_handler(engine_stub, tagged_logger, Events.ITERATION_STARTED)
    tagged_logger.log_metric.assert_called_once_with("generator/lr/group_0", y=0.01, x=123)
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_neptune_logger.py

示例15: test_output_handler_output_transform

# 需要导入模块: from ignite.engine import Events [as 别名]
# 或者: from ignite.engine.Events import ITERATION_STARTED [as 别名]
def test_output_handler_output_transform():
    """Check the metric ``OutputHandler`` logs through a mocked ``NeptuneLogger``.

    The engine state holds ``output = 12345`` at iteration 123. An identity
    transform must be logged as ``"tag/output"``; a transform that wraps the
    value in ``{"loss": x}`` must be logged as ``"another_tag/loss"``. Both
    calls are expected to carry ``y=12345`` and ``x=123``.
    """
    engine_stub = MagicMock()
    engine_stub.state = State()
    engine_stub.state.output = 12345
    engine_stub.state.iteration = 123

    # (handler tag, output transform, metric name expected in log_metric)
    scenarios = (
        ("tag", lambda x: x, "tag/output"),
        ("another_tag", lambda x: {"loss": x}, "another_tag/loss"),
    )
    for tag, transform, expected_name in scenarios:
        handler = OutputHandler(tag, output_transform=transform)
        # A fresh logger mock per scenario keeps assert_called_once_with meaningful.
        logger_stub = MagicMock(spec=NeptuneLogger)
        logger_stub.log_metric = MagicMock()

        handler(engine_stub, logger_stub, Events.ITERATION_STARTED)
        logger_stub.log_metric.assert_called_once_with(expected_name, y=12345, x=123)
开发者ID:pytorch,项目名称:ignite,代码行数:20,代码来源:test_neptune_logger.py


注:本文中的ignite.engine.Events.ITERATION_STARTED属性示例由纯净天空整理自GitHub/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。