当前位置: 首页>>代码示例>>Python>>正文


Python handlers.ModelCheckpoint方法代码示例

本文整理汇总了Python中ignite.handlers.ModelCheckpoint方法的典型用法代码示例。如果您正苦于以下问题:Python handlers.ModelCheckpoint方法的具体用法?Python handlers.ModelCheckpoint怎么用?Python handlers.ModelCheckpoint使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在ignite.handlers的用法示例。


在下文中一共展示了handlers.ModelCheckpoint方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: test_last_k

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_last_k(dirname):
    """With ``n_saved=2`` only the two most recent iteration checkpoints survive."""
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    objects = {"model": DummyModel()}

    # One save at iteration 0, then one save per iteration 1..8.
    checkpointer(engine, objects)
    for step in range(1, 9):
        engine.state.iteration = step
        checkpointer(engine, objects)

    expected = ["{}_{}_{}.pt".format(_PREFIX, "model", step) for step in (7, 8)]
    saved = sorted(os.listdir(dirname))
    assert saved == expected, "{} vs {}".format(saved, expected)
开发者ID:pytorch,项目名称:ignite,代码行数:19,代码来源:test_checkpoint.py

示例2: test_disabled_n_saved

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_disabled_n_saved(dirname):
    """``n_saved=None`` disables pruning, so every written checkpoint is kept."""
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=None)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    objects = {"model": DummyModel()}

    num_iters = 100
    for step in range(num_iters):
        engine.state.iteration = step
        checkpointer(engine, objects)

    saved_files = sorted(os.listdir(dirname))
    assert len(saved_files) == num_iters, "{}".format(saved_files)

    expected = sorted("{}_{}_{}.pt".format(_PREFIX, "model", step) for step in range(num_iters))
    assert saved_files == expected, "{} vs {}".format(saved_files, expected)
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_checkpoint.py

示例3: test_best_k

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_best_k(dirname):
    """With a score function and ``n_saved=2`` the two highest-scoring checkpoints are kept."""
    score_stream = iter([1.2, -2.0, 3.1, -4.0])

    def score_function(_):
        # Each handler call consumes the next score in the sequence.
        return next(score_stream)

    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    objects = {"model": DummyModel()}
    for _ in range(4):
        checkpointer(engine, objects)

    best_scores = (1.2, 3.1)
    expected = ["{}_{}_{:.4f}.pt".format(_PREFIX, "model", score) for score in best_scores]
    assert sorted(os.listdir(dirname)) == expected
开发者ID:pytorch,项目名称:ignite,代码行数:21,代码来源:test_checkpoint.py

示例4: test_best_k_with_suffix

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_best_k_with_suffix(dirname):
    """When ``score_name`` is given, kept filenames embed ``val_loss=<score>``."""
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    scores_iter = iter(scores)

    def score_function(engine):
        return next(scores_iter)

    checkpointer = ModelCheckpoint(
        dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function, score_name="val_loss"
    )

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    objects = {"model": DummyModel()}
    for _ in scores:
        engine.state.epoch += 1
        checkpointer(engine, objects)

    # Epochs 1 and 3 carry the two highest scores (0.3456789 and 0.4567).
    expected = ["{}_{}_val_loss={:.4}.pt".format(_PREFIX, "model", scores[epoch - 1]) for epoch in (1, 3)]
    assert sorted(os.listdir(dirname)) == expected
开发者ID:pytorch,项目名称:ignite,代码行数:25,代码来源:test_checkpoint.py

示例5: test_removes_each_score_at_most_once

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_removes_each_score_at_most_once(dirname):
    """Duplicate scores must not cause the same checkpoint to be removed twice.

    If a score were removed more than once, one of the handler calls below would
    raise ``FileNotFoundError``. The test therefore passes simply by completing,
    without further assertions.
    """
    scores = [0, 1, 1, 2, 3]
    remaining = iter(scores)

    def score_function(_):
        return next(remaining)

    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    objects = {"model": DummyModel()}
    for _ in scores:
        checkpointer(engine, objects)
开发者ID:pytorch,项目名称:ignite,代码行数:22,代码来源:test_checkpoint.py

示例6: test_load_checkpoint_with_different_num_classes

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_load_checkpoint_with_different_num_classes(dirname):
    """A full-model checkpoint fails to load strictly into a sub-module but loads with ``strict=False``."""
    model = DummyPretrainedModel()

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, {"model": model})

    checkpoint = torch.load(handler.last_checkpoint)
    to_load = {"pretrained_features": model.features}

    # Strict loading into the `features` sub-module is expected to raise.
    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load, checkpoint)

    # Non-strict load; any UserWarning emitted by this call is silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(to_load, checkpoint, strict=False, blah="blah")

    loaded_weights = to_load["pretrained_features"].state_dict()["weight"]
    assert torch.all(model.state_dict()["features.weight"].eq(loaded_weights))
开发者ID:pytorch,项目名称:ignite,代码行数:27,代码来源:test_checkpoint.py

示例7: _test_tpu_saves_to_cpu

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def _test_tpu_saves_to_cpu(device, dirname):
    """Checkpointing a model placed on ``device`` yields a file equal to its CPU state_dict."""
    torch.manual_seed(0)

    checkpointer = ModelCheckpoint(dirname, _PREFIX)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)

    model = DummyModel().to(device)
    checkpointer(engine, {"model": model})

    # Wait for all participating processes before inspecting the file system.
    idist.barrier()

    path = checkpointer.last_checkpoint
    assert isinstance(path, str)
    assert os.path.join(dirname, _PREFIX) in path
    assert os.path.exists(path)
    assert torch.load(path) == model.cpu().state_dict()
开发者ID:pytorch,项目名称:ignite,代码行数:22,代码来源:test_checkpoint.py

示例8: _create_checkpoint_handler

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def _create_checkpoint_handler(self):
    """Build a ModelCheckpoint that retains a single checkpoint file.

    Files go to ``self._model_save_location`` under the prefix
    ``self._running_model_prefix``. When ``self._score_function`` is set,
    checkpoints are ranked by it. The directory is created if missing and
    may already contain files.
    """
    location = self._model_save_location
    prefix = self._running_model_prefix
    return ModelCheckpoint(
        location,
        prefix,
        score_function=self._score_function,
        n_saved=1,
        create_dir=True,
        save_as_state_dict=True,
        require_empty=False,
    )
开发者ID:microsoft,项目名称:seismic-deeplearning,代码行数:12,代码来源:__init__.py

示例9: setup_checkpoint

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def setup_checkpoint(self, base_model, classifier, setops_model, evaluator):
    """Save checkpoints of the models.

    Two ModelCheckpoint handlers are attached to ``evaluator`` on
    ``Events.EPOCH_COMPLETED``. Both write into ``self.results_path`` with
    prefix ``CKPT_PREFIX`` and keep 2 files:

    * a score-ranked handler labelled ``val_acc``, scored by the mean of the
      four class accuracies in ``eng.state.metrics``, rounded to 3 decimals;
    * a second handler configured with ``save_interval=2``.

    NOTE(review): newer ignite releases appear to reject both ``save_interval``
    (deprecated, must be None) and plain ``state_dict()`` dicts as ``to_save``
    values (objects must have a ``state_dict`` method). This looks tied to an
    older ignite; confirm the pinned version before upgrading.
    """

    def mean_accuracy(eng):
        metrics = eng.state.metrics
        total = (metrics["fake class acc"] + metrics["S class acc"] +
                 metrics["I class acc"] + metrics["U class acc"])
        return round(total / 4, 3)

    checkpoint_handler_acc = ModelCheckpoint(
        self.results_path,
        CKPT_PREFIX,
        score_function=mean_accuracy,
        score_name="val_acc",
        n_saved=2,
        require_empty=False,
    )
    checkpoint_handler_last = ModelCheckpoint(
        self.results_path,
        CKPT_PREFIX,
        save_interval=2,
        n_saved=2,
        require_empty=False,
    )

    def model_state_dicts():
        # A fresh mapping is built for each registration.
        return {
            'base_model': base_model.state_dict(),
            'classifier': classifier.state_dict(),
            'setops_model': setops_model.state_dict(),
        }

    for handler in (checkpoint_handler_acc, checkpoint_handler_last):
        evaluator.add_event_handler(
            event_name=Events.EPOCH_COMPLETED,
            handler=handler,
            to_save=model_state_dicts(),
        )
开发者ID:leokarlin,项目名称:LaSO,代码行数:42,代码来源:train_setops_stripped.py

示例10: __call__

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def __call__(self, engine: Engine, to_save: Mapping) -> None:
    """Validate ``to_save``, store it on the handler, then run the parent save logic.

    Raises:
        RuntimeError: if ``to_save`` is empty.
        Any error raised by ``self._check_objects`` for objects lacking
        ``state_dict`` (presumably TypeError — confirm).
    """
    if not len(to_save):
        raise RuntimeError("No objects to checkpoint found.")

    self._check_objects(to_save, "state_dict")
    # The parent __call__ receives only the engine, so the objects are kept on the instance
    # (presumably read from `self.to_save` there).
    self.to_save = to_save
    super(ModelCheckpoint, self).__call__(engine)
开发者ID:pytorch,项目名称:ignite,代码行数:10,代码来源:checkpoint.py

示例11: test_model_checkpoint_args_validation

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_model_checkpoint_args_validation(dirname):
    """ModelCheckpoint rejects invalid or deprecated arguments and an empty ``to_save``."""
    existing = os.path.join(dirname, "existing_dir")
    nonempty = os.path.join(dirname, "nonempty")

    for path in (existing, nonempty):
        os.makedirs(path)

    # Touch a checkpoint-like file so that `nonempty` is not empty.
    with open(os.path.join(nonempty, "{}_name_0.pt".format(_PREFIX)), "w"):
        pass

    # (expected exception, message pattern, directory, extra constructor kwargs), checked in order.
    invalid_cases = [
        (ValueError, r"with extension '.pt' are already present ", nonempty, {}),
        (ValueError, r"Argument save_interval is deprecated and should be None", existing, {"save_interval": 42}),
        (
            ValueError,
            r"Directory path '\S+' is not found",
            os.path.join(dirname, "non_existing_dir"),
            {"create_dir": False},
        ),
        (
            ValueError,
            r"Argument save_as_state_dict is deprecated and should be True",
            existing,
            {"create_dir": False, "save_as_state_dict": False},
        ),
        (
            ValueError,
            r"If `score_name` is provided, then `score_function` ",
            existing,
            {"create_dir": False, "score_name": "test"},
        ),
        (
            TypeError,
            r"global_step_transform should be a function",
            existing,
            {"create_dir": False, "global_step_transform": 1234},
        ),
    ]
    for exc_type, pattern, path, kwargs in invalid_cases:
        with pytest.raises(exc_type, match=pattern):
            ModelCheckpoint(path, _PREFIX, **kwargs)

    with pytest.warns(UserWarning, match=r"Argument archived is deprecated"):
        ModelCheckpoint(existing, _PREFIX, create_dir=False, archived=True)

    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False)
    assert checkpointer.last_checkpoint is None
    with pytest.raises(RuntimeError, match=r"No objects to checkpoint found."):
        checkpointer(None, [])
开发者ID:pytorch,项目名称:ignite,代码行数:37,代码来源:test_checkpoint.py

示例12: test_model_checkpoint_simple_recovery

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_model_checkpoint_simple_recovery(dirname):
    """The file reported by ``last_checkpoint`` exists and reloads as the model's state_dict."""
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)

    model = DummyModel()
    checkpointer(engine, {"model": model})

    path = checkpointer.last_checkpoint
    assert isinstance(path, str)
    assert os.path.join(dirname, _PREFIX) in path
    assert os.path.exists(path)

    restored = torch.load(path)
    assert restored == model.state_dict()
开发者ID:pytorch,项目名称:ignite,代码行数:17,代码来源:test_checkpoint.py

示例13: test_with_state_dict

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_with_state_dict(dirname):
    """A ModelCheckpoint file holds the model's state_dict, not the module object.

    Runs an engine for 4 epochs with ``n_saved=1``, reloads the single remaining
    checkpoint and compares it entry by entry with ``model.state_dict()``.
    """

    def update_fn(_1, _2):
        pass

    engine = Engine(update_fn)
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)

    model = DummyModel()
    to_save = {"model": model}
    engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
    engine.run([0], max_epochs=4)

    # n_saved=1 prunes older checkpoints, so exactly one file should remain.
    saved_files = os.listdir(dirname)
    assert len(saved_files) == 1, "{}".format(saved_files)
    load_model = torch.load(os.path.join(dirname, saved_files[0]))

    assert not isinstance(load_model, DummyModel)
    assert isinstance(load_model, dict)

    for key, model_value in model.state_dict().items():
        assert key in load_model
        # torch.equal checks shape and every element. Comparing `.numpy()` arrays with `==`
        # only works for single-element tensors; with more elements the truth value is ambiguous.
        assert torch.equal(model_value, load_model[key])
开发者ID:pytorch,项目名称:ignite,代码行数:29,代码来源:test_checkpoint.py

示例14: test_valid_state_dict_save

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_valid_state_dict_save(dirname):
    """Objects without ``state_dict`` are rejected with TypeError; a real model is accepted."""
    model = DummyModel()
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    # A plain int has no `state_dict`, so the handler must refuse it.
    with pytest.raises(TypeError, match=r"should have `state_dict` method"):
        checkpointer(engine, {"name": 42})

    try:
        checkpointer(engine, {"name": model})
    except ValueError:
        pytest.fail("Unexpected ValueError")
开发者ID:pytorch,项目名称:ignite,代码行数:17,代码来源:test_checkpoint.py

示例15: __init__

# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def __init__(
    self,
    to_save: Mapping,
    save_handler: Union[Callable, BaseSaveHandler],
    filename_prefix: str = "",
    score_function: Optional[Callable] = None,
    score_name: Optional[str] = None,
    n_saved: Optional[int] = 1,
    global_step_transform: Optional[Callable] = None,
    archived: bool = False,
    include_self: bool = False,
):
    """Validate the checkpointing arguments and store them on the instance.

    Args:
        to_save: mapping of name -> object to checkpoint. Each object is checked
            with ``self._check_objects(to_save, "state_dict")``. ``None`` is
            accepted (kept for compatibility with ModelCheckpoint), in which case
            this validation is skipped.
        save_handler: callable or ``BaseSaveHandler`` instance that writes checkpoints.
        filename_prefix: filename prefix; ``"_"`` is appended when it is non-empty.
        score_function: optional scoring callable, stored as ``_score_function``.
        score_name: optional score label; requires ``score_function``.
        n_saved: stored as ``_n_saved`` (presumably the number of checkpoints kept).
        global_step_transform: optional callable; rejected if not callable.
        archived: deprecated; only emits a warning.
        include_self: if True, ``to_save`` must be a mutable mapping without a
            ``"checkpointer"`` key.

    Raises:
        TypeError: if ``to_save`` is not a mapping (or is not mutable while
            ``include_self`` is True), if ``save_handler`` is neither callable
            nor a ``BaseSaveHandler``, or if ``global_step_transform`` is not callable.
        ValueError: if ``to_save`` is empty, contains ``"checkpointer"`` while
            ``include_self`` is True, or if ``score_name`` is given without
            ``score_function``.
    """
    # `collections.Mapping` / `collections.MutableMapping` were deprecated aliases and were
    # removed in Python 3.10; the ABCs live in `collections.abc` on every Python 3 version.
    import collections.abc

    if to_save is not None:  # for compatibility with ModelCheckpoint
        if not isinstance(to_save, collections.abc.Mapping):
            raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save)))

        if len(to_save) < 1:
            raise ValueError("No objects to checkpoint.")

        self._check_objects(to_save, "state_dict")

        if include_self:
            if not isinstance(to_save, collections.abc.MutableMapping):
                raise TypeError(
                    "If `include_self` is True, then `to_save` must be mutable, but given {}.".format(type(to_save))
                )

            if "checkpointer" in to_save:
                raise ValueError("Cannot have key 'checkpointer' if `include_self` is True: {}".format(to_save))

    if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)):
        raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler")

    if score_function is None and score_name is not None:
        raise ValueError("If `score_name` is provided, then `score_function` should be also provided.")

    if global_step_transform is not None and not callable(global_step_transform):
        raise TypeError(
            "global_step_transform should be a function, got {} instead.".format(type(global_step_transform))
        )
    if archived:
        warnings.warn("Argument archived is deprecated and will be removed in 0.5.0")

    self.to_save = to_save
    self._fname_prefix = filename_prefix + "_" if len(filename_prefix) > 0 else filename_prefix
    self.save_handler = save_handler
    self._score_function = score_function
    self._score_name = score_name
    self._n_saved = n_saved
    self._saved = []
    self._ext = ".pt"
    self.global_step_transform = global_step_transform
    self.include_self = include_self
开发者ID:pytorch,项目名称:ignite,代码行数:56,代码来源:checkpoint.py


注:本文中的ignite.handlers.ModelCheckpoint方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。