本文整理汇总了Python中ignite.handlers.ModelCheckpoint方法的典型用法代码示例。如果您正苦于以下问题:Python handlers.ModelCheckpoint方法的具体用法?Python handlers.ModelCheckpoint怎么用?Python handlers.ModelCheckpoint使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类ignite.handlers的用法示例。
在下文中一共展示了handlers.ModelCheckpoint方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_last_k
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_last_k(dirname):
    """Only the two most recent iteration checkpoints survive when ``n_saved=2``."""
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    to_save = {"model": DummyModel()}

    # One save at iteration 0, then one per iteration 1..8; only 7 and 8 may remain.
    checkpointer(engine, to_save)
    for iteration in range(1, 9):
        engine.state.iteration = iteration
        checkpointer(engine, to_save)

    expected = ["{}_{}_{}.pt".format(_PREFIX, "model", it) for it in (7, 8)]
    actual = sorted(os.listdir(dirname))
    assert actual == expected, "{} vs {}".format(actual, expected)
示例2: test_disabled_n_saved
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_disabled_n_saved(dirname):
    """With ``n_saved=None`` no checkpoint is ever removed: one file per call remains."""
    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=None)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    to_save = {"model": DummyModel()}

    num_iters = 100
    for iteration in range(num_iters):
        engine.state.iteration = iteration
        checkpointer(engine, to_save)

    saved_files = sorted(os.listdir(dirname))
    assert len(saved_files) == num_iters, "{}".format(saved_files)
    expected = sorted("{}_{}_{}.pt".format(_PREFIX, "model", it) for it in range(num_iters))
    assert saved_files == expected, "{} vs {}".format(saved_files, expected)
示例3: test_best_k
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_best_k(dirname):
    """Only the ``n_saved`` highest-scoring checkpoints are kept, named by score."""
    scores = iter([1.2, -2.0, 3.1, -4.0])

    def score_function(_):
        # Each handler call consumes the next value of the fixed score sequence.
        return next(scores)

    checkpointer = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    to_save = {"model": DummyModel()}

    for _ in range(4):
        checkpointer(engine, to_save)

    best_scores = [1.2, 3.1]
    expected = ["{}_{}_{:.4f}.pt".format(_PREFIX, "model", score) for score in best_scores]
    assert sorted(os.listdir(dirname)) == expected
示例4: test_best_k_with_suffix
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_best_k_with_suffix(dirname):
    """Best-k checkpoints embed ``<score_name>=<score>`` in their file names.

    One score is consumed per call (epochs 1..4). With ``n_saved=2`` the two
    highest scores, from epochs 1 and 3, must be the files left on disk.
    """
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    scores_iter = iter(scores)

    def score_function(engine):
        return next(scores_iter)

    h = ModelCheckpoint(
        dirname, _PREFIX, create_dir=False, n_saved=2, score_function=score_function, score_name="val_loss"
    )
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    model = DummyModel()
    to_save = {"model": model}

    for _ in range(4):
        engine.state.epoch += 1
        h(engine, to_save)

    # Use fixed 4-decimal formatting ("{:.4f}"), the same as test_best_k. The previous
    # "{:.4}" means 4 *significant* digits and only matched these particular scores.
    # The list is sorted so the comparison does not depend on the score order.
    expected = sorted("{}_{}_val_loss={:.4f}.pt".format(_PREFIX, "model", scores[e - 1]) for e in [1, 3])
    assert sorted(os.listdir(dirname)) == expected
示例5: test_removes_each_score_at_most_once
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_removes_each_score_at_most_once(dirname):
    """Tied scores must not make the handler delete the same checkpoint file twice."""
    scores = [0, 1, 1, 2, 3]
    remaining = iter(scores)

    checkpointer = ModelCheckpoint(
        dirname, _PREFIX, create_dir=False, n_saved=2, score_function=lambda _: next(remaining)
    )
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)
    to_save = {"model": DummyModel()}

    for _ in scores:
        checkpointer(engine, to_save)
    # If a score were removed more than once, one of the calls above would have
    # raised FileNotFoundError. Reaching this point is therefore the whole check,
    # and no further assertions are needed.
示例6: test_load_checkpoint_with_different_num_classes
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_load_checkpoint_with_different_num_classes(dirname):
    """A whole-model checkpoint loads into a sub-module only with ``strict=False``."""
    model = DummyPretrainedModel()
    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, {"model": model})
    loaded_checkpoint = torch.load(handler.last_checkpoint)

    # Load into the feature extractor only. Its keys ("weight", ...) presumably differ
    # from the saved model's ("features.weight", ...), so strict loading must fail.
    to_load = {"pretrained_features": model.features}
    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load, loaded_checkpoint)

    # Non-strict loading must succeed. UserWarnings are silenced (presumably
    # including one about the unrecognised `blah` keyword).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(to_load, loaded_checkpoint, strict=False, blah="blah")

    loaded_weights = to_load["pretrained_features"].state_dict()["weight"]
    assert torch.all(model.state_dict()["features.weight"].eq(loaded_weights))
示例7: _test_tpu_saves_to_cpu
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def _test_tpu_saves_to_cpu(device, dirname):
    """Checkpoint a model placed on ``device`` and check it reloads as equal CPU tensors.

    ``idist.barrier()`` synchronises the participating processes before the
    checkpoint file is inspected.
    """
    torch.manual_seed(0)
    h = ModelCheckpoint(dirname, _PREFIX)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)
    model = DummyModel().to(device)
    to_save = {"model": model}
    h(engine, to_save)
    idist.barrier()

    fname = h.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)

    loaded_objects = torch.load(fname)
    expected = model.cpu().state_dict()
    # Compare tensor-by-tensor. Plain dict ``==`` calls ``bool()`` on each element-wise
    # tensor comparison, which raises for tensors with more than one element.
    assert loaded_objects.keys() == expected.keys()
    for key, value in expected.items():
        # The point of this test: saved tensors must come back on the CPU.
        assert loaded_objects[key].device.type == "cpu", key
        assert torch.equal(loaded_objects[key], value), key
示例8: _create_checkpoint_handler
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def _create_checkpoint_handler(self):
    """Build the ``ModelCheckpoint`` used for the running model.

    The handler keeps a single checkpoint (``n_saved=1``) scored by
    ``self._score_function``. Files go to ``self._model_save_location`` with
    prefix ``self._running_model_prefix``. The directory is created if missing
    and is allowed to be non-empty (``require_empty=False``).
    """
    options = {
        "score_function": self._score_function,
        "n_saved": 1,
        "create_dir": True,
        # NOTE(review): some ignite releases treat save_as_state_dict as deprecated
        # (only True accepted); confirm the pinned version still needs it.
        "save_as_state_dict": True,
        "require_empty": False,
    }
    return ModelCheckpoint(self._model_save_location, self._running_model_prefix, **options)
示例9: setup_checkpoint
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def setup_checkpoint(self, base_model, classifier, setops_model, evaluator):
    """Save checkpoints of the models.

    Attaches two ``ModelCheckpoint`` handlers to ``evaluator``'s
    ``EPOCH_COMPLETED`` event. Both write under ``self.results_path`` with
    prefix ``CKPT_PREFIX``:

    * one keeps the 2 best checkpoints, ranked by the mean of the four per-class
      accuracy metrics rounded to 3 decimals (suffix ``val_acc``);
    * one is configured with ``save_interval=2`` and keeps 2 checkpoints.
    """

    def _mean_val_acc(eng):
        metrics = eng.state.metrics
        # Same left-to-right summation order as before, so float results are identical.
        return round(
            (metrics["fake class acc"] + metrics["S class acc"] + metrics["I class acc"] + metrics["U class acc"]) / 4,
            3,
        )

    def _state_dicts():
        # NOTE(review): these are state_dict() snapshots, not the modules themselves.
        # Newer ignite versions require every value to expose `state_dict` and reject
        # `save_interval`; confirm which ignite version this project pins.
        return {
            'base_model': base_model.state_dict(),
            'classifier': classifier.state_dict(),
            'setops_model': setops_model.state_dict(),
        }

    best_acc_handler = ModelCheckpoint(
        self.results_path,
        CKPT_PREFIX,
        score_function=_mean_val_acc,
        score_name="val_acc",
        n_saved=2,
        require_empty=False,
    )
    periodic_handler = ModelCheckpoint(
        self.results_path,
        CKPT_PREFIX,
        save_interval=2,
        n_saved=2,
        require_empty=False,
    )

    # A fresh snapshot dict is built for each registration, as in the original code.
    for handler in (best_acc_handler, periodic_handler):
        evaluator.add_event_handler(
            event_name=Events.EPOCH_COMPLETED,
            handler=handler,
            to_save=_state_dicts(),
        )
示例10: __call__
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def __call__(self, engine: Engine, to_save: Mapping) -> None:
    """Validate ``to_save``, store it on the handler and run the base checkpoint logic.

    Raises:
        RuntimeError: if ``to_save`` is empty.
    """
    object_count = len(to_save)
    if object_count == 0:
        raise RuntimeError("No objects to checkpoint found.")
    # Every value must expose a `state_dict` method; the helper raises otherwise.
    self._check_objects(to_save, "state_dict")
    # Presumably read by the base-class __call__ below; verify against Checkpoint.
    self.to_save = to_save
    super(ModelCheckpoint, self).__call__(engine)
示例11: test_model_checkpoint_args_validation
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_model_checkpoint_args_validation(dirname):
    """Invalid or deprecated constructor arguments are rejected with specific messages."""
    existing = os.path.join(dirname, "existing_dir")
    nonempty = os.path.join(dirname, "nonempty")
    for path in (existing, nonempty):
        os.makedirs(path)
    # An empty file named like a checkpoint makes `nonempty` fail the first case below.
    stale_checkpoint = os.path.join(nonempty, "{}_name_0.pt".format(_PREFIX))
    with open(stale_checkpoint, "w"):
        pass

    # (expected exception, message regex, positional args, keyword args)
    bad_calls = [
        (ValueError, r"with extension '.pt' are already present ", (nonempty, _PREFIX), {}),
        (
            ValueError,
            r"Argument save_interval is deprecated and should be None",
            (existing, _PREFIX),
            {"save_interval": 42},
        ),
        (
            ValueError,
            r"Directory path '\S+' is not found",
            (os.path.join(dirname, "non_existing_dir"), _PREFIX),
            {"create_dir": False},
        ),
        (
            ValueError,
            r"Argument save_as_state_dict is deprecated and should be True",
            (existing, _PREFIX),
            {"create_dir": False, "save_as_state_dict": False},
        ),
        (
            ValueError,
            r"If `score_name` is provided, then `score_function` ",
            (existing, _PREFIX),
            {"create_dir": False, "score_name": "test"},
        ),
        (
            TypeError,
            r"global_step_transform should be a function",
            (existing, _PREFIX),
            {"create_dir": False, "global_step_transform": 1234},
        ),
    ]
    for exc_type, pattern, args, kwargs in bad_calls:
        with pytest.raises(exc_type, match=pattern):
            ModelCheckpoint(*args, **kwargs)

    with pytest.warns(UserWarning, match=r"Argument archived is deprecated"):
        ModelCheckpoint(existing, _PREFIX, create_dir=False, archived=True)

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False)
    assert h.last_checkpoint is None
    with pytest.raises(RuntimeError, match=r"No objects to checkpoint found."):
        h(None, [])
示例12: test_model_checkpoint_simple_recovery
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_model_checkpoint_simple_recovery(dirname):
    """A saved checkpoint reloads to exactly the model's ``state_dict``."""
    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=1)
    model = DummyModel()
    to_save = {"model": model}
    h(engine, to_save)

    fname = h.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)

    loaded_objects = torch.load(fname)
    expected = model.state_dict()
    # Compare tensor-by-tensor. Dict ``==`` would call ``bool()`` on element-wise
    # tensor comparisons, which raises for tensors with more than one element.
    assert loaded_objects.keys() == expected.keys()
    for key, value in expected.items():
        assert torch.equal(loaded_objects[key], value), key
示例13: test_with_state_dict
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_with_state_dict(dirname):
    """A handler attached to an engine saves the model's ``state_dict``, not the model."""

    def update_fn(_1, _2):
        pass

    engine = Engine(update_fn)
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    model = DummyModel()
    to_save = {"model": model}
    engine.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)
    engine.run([0], max_epochs=4)

    # With n_saved=1 a single checkpoint file is expected in the directory.
    saved_model = os.path.join(dirname, os.listdir(dirname)[0])
    loaded_state_dict = torch.load(saved_model)
    assert not isinstance(loaded_state_dict, DummyModel)
    assert isinstance(loaded_state_dict, dict)

    model_state_dict = model.state_dict()
    for key, model_value in model_state_dict.items():
        assert key in loaded_state_dict
        # torch.equal yields one bool. Comparing `.numpy()` arrays with `==` yields an
        # array whose truth value is ambiguous for multi-element tensors.
        assert torch.equal(model_value, loaded_state_dict[key]), key
示例14: test_valid_state_dict_save
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def test_valid_state_dict_save(dirname):
    """Values without ``state_dict`` are rejected; a real model is saved without error."""
    model = DummyModel()
    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    # A plain int has no `state_dict` method.
    with pytest.raises(TypeError, match=r"should have `state_dict` method"):
        h(engine, {"name": 42})

    # A model must go through without a ValueError.
    try:
        h(engine, {"name": model})
    except ValueError:
        pytest.fail("Unexpected ValueError")
示例15: __init__
# 需要导入模块: from ignite import handlers [as 别名]
# 或者: from ignite.handlers import ModelCheckpoint [as 别名]
def __init__(
    self,
    to_save: Optional[Mapping],
    save_handler: Union[Callable, BaseSaveHandler],
    filename_prefix: str = "",
    score_function: Optional[Callable] = None,
    score_name: Optional[str] = None,
    n_saved: Optional[int] = 1,
    global_step_transform: Optional[Callable] = None,
    archived: bool = False,
    include_self: bool = False,
):
    """Validate the checkpointing configuration and store it on the instance.

    Args:
        to_save: mapping of names to objects to checkpoint. Every value must pass
            ``self._check_objects(..., "state_dict")``. ``None`` is accepted for
            compatibility with ``ModelCheckpoint``, which sets ``self.to_save`` later.
        save_handler: a callable, or a ``BaseSaveHandler`` instance, that performs the save.
        filename_prefix: prefix for file names; ``"_"`` is appended when it is non-empty.
        score_function: optional callable used to score checkpoints.
        score_name: optional score label; only valid together with ``score_function``.
        n_saved: number of checkpoints to keep (stored as ``self._n_saved``).
        global_step_transform: optional callable; must be callable if given.
        archived: deprecated; only triggers a warning here.
        include_self: if True, ``to_save`` must be mutable and must not already
            contain the key ``"checkpointer"``.

    Raises:
        TypeError: if ``to_save`` is not a mapping (or not mutable when
            ``include_self`` is set), if ``save_handler`` is invalid, or if
            ``global_step_transform`` is not callable.
        ValueError: if ``to_save`` is empty, already contains ``"checkpointer"``
            while ``include_self`` is set, or ``score_name`` is given without
            ``score_function``.
    """
    # The ``collections.Mapping`` / ``collections.MutableMapping`` aliases were
    # removed in Python 3.10; the ABCs live in ``collections.abc``.
    import collections.abc

    if to_save is not None:  # for compatibility with ModelCheckpoint
        if not isinstance(to_save, collections.abc.Mapping):
            raise TypeError("Argument `to_save` should be a dictionary, but given {}".format(type(to_save)))
        if len(to_save) < 1:
            raise ValueError("No objects to checkpoint.")
        self._check_objects(to_save, "state_dict")
        if include_self:
            if not isinstance(to_save, collections.abc.MutableMapping):
                raise TypeError(
                    "If `include_self` is True, then `to_save` must be mutable, but given {}.".format(type(to_save))
                )
            if "checkpointer" in to_save:
                raise ValueError("Cannot have key 'checkpointer' if `include_self` is True: {}".format(to_save))

    if not (callable(save_handler) or isinstance(save_handler, BaseSaveHandler)):
        raise TypeError("Argument `save_handler` should be callable or inherit from BaseSaveHandler")
    if score_function is None and score_name is not None:
        raise ValueError("If `score_name` is provided, then `score_function` should be also provided.")
    if global_step_transform is not None and not callable(global_step_transform):
        raise TypeError(
            "global_step_transform should be a function, got {} instead.".format(type(global_step_transform))
        )
    if archived:
        warnings.warn("Argument archived is deprecated and will be removed in 0.5.0")

    self.to_save = to_save
    self._fname_prefix = filename_prefix + "_" if len(filename_prefix) > 0 else filename_prefix
    self.save_handler = save_handler
    self._score_function = score_function
    self._score_name = score_name
    self._n_saved = n_saved
    self._saved = []  # saved-checkpoint bookkeeping, starts empty
    self._ext = ".pt"
    self.global_step_transform = global_step_transform
    self.include_self = include_self