本文整理汇总了Python中maskrcnn_benchmark.utils.checkpoint.Checkpointer类的典型用法代码示例。如果您正苦于以下问题：Python checkpoint.Checkpointer类的具体用法？Python checkpoint.Checkpointer怎么用？Python checkpoint.Checkpointer的使用示例？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块maskrcnn_benchmark.utils.checkpoint的用法示例。
在下文中一共展示了checkpoint.Checkpointer类的2个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: test_from_last_checkpoint_model
# 需要导入模块: from maskrcnn_benchmark.utils import checkpoint [as 别名]
# 或者: from maskrcnn_benchmark.utils.checkpoint import Checkpointer [as 别名]
def test_from_last_checkpoint_model(self):
    """Check that a fresh model can load the checkpoint found in its save_dir.

    A trained model is saved with ``Checkpointer.save``. A second model is
    then loaded through a ``Checkpointer`` pointed at the same directory,
    without naming the file explicitly.

    Every combination of plain and ``nn.DataParallel``-wrapped models is
    exercised. ``nn.DataParallel`` keeps the wrapped model under ``.module``,
    so the two sides' parameter names differ by a prefix that loading has to
    tolerate.
    """
    model_pairs = [
        (self.create_model(), self.create_model()),
        (nn.DataParallel(self.create_model()), self.create_model()),
        (self.create_model(), nn.DataParallel(self.create_model())),
        (
            nn.DataParallel(self.create_model()),
            nn.DataParallel(self.create_model()),
        ),
    ]
    for trained_model, fresh_model in model_pairs:
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(
                trained_model, save_dir=f, save_to_disk=True
            )
            checkpointer.save("checkpoint_file")

            # Same folder: the fresh checkpointer must discover the saved
            # file on its own and report its full ".pth" path.
            fresh_checkpointer = Checkpointer(fresh_model, save_dir=f)
            self.assertTrue(fresh_checkpointer.has_checkpoint())
            self.assertEqual(
                fresh_checkpointer.get_checkpoint_file(),
                os.path.join(f, "checkpoint_file.pth"),
            )
            # No path argument: presumably loads the file reported by
            # get_checkpoint_file() above.
            fresh_checkpointer.load()

        trained_params = list(trained_model.parameters())
        loaded_params = list(fresh_model.parameters())
        # zip() stops at the shorter input, so rule out a vacuous pass
        # (no parameters at all, or mismatched counts) before comparing.
        self.assertTrue(trained_params)
        self.assertEqual(len(trained_params), len(loaded_params))
        for trained_p, loaded_p in zip(trained_params, loaded_params):
            # The fresh model must keep its own tensor objects...
            self.assertIsNot(trained_p, loaded_p)
            # ...while holding exactly the trained values.
            self.assertTrue(trained_p.equal(loaded_p))
示例2: test_from_name_file_model
# 需要导入模块: from maskrcnn_benchmark.utils import checkpoint [as 别名]
# 或者: from maskrcnn_benchmark.utils.checkpoint import Checkpointer [as 别名]
def test_from_name_file_model(self):
    """Check loading a checkpoint by explicit file path from another folder.

    A trained model is saved into one temporary directory. A second model's
    ``Checkpointer`` uses a different, empty directory, so it must not find a
    checkpoint of its own. It is then loaded by passing the saved file's full
    path to ``load``.

    Every combination of plain and ``nn.DataParallel``-wrapped models is
    exercised. ``nn.DataParallel`` keeps the wrapped model under ``.module``,
    so the two sides' parameter names differ by a prefix that loading has to
    tolerate.
    """
    model_pairs = [
        (self.create_model(), self.create_model()),
        (nn.DataParallel(self.create_model()), self.create_model()),
        (self.create_model(), nn.DataParallel(self.create_model())),
        (
            nn.DataParallel(self.create_model()),
            nn.DataParallel(self.create_model()),
        ),
    ]
    for trained_model, fresh_model in model_pairs:
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(
                trained_model, save_dir=f, save_to_disk=True
            )
            checkpointer.save("checkpoint_file")

            # Different folder: g holds no checkpoint, so nothing should be
            # discovered there and the file path has to be passed explicitly.
            # This stays nested inside `f` so the saved file still exists.
            with TemporaryDirectory() as g:
                fresh_checkpointer = Checkpointer(fresh_model, save_dir=g)
                self.assertFalse(fresh_checkpointer.has_checkpoint())
                self.assertEqual(fresh_checkpointer.get_checkpoint_file(), "")
                fresh_checkpointer.load(os.path.join(f, "checkpoint_file.pth"))

        trained_params = list(trained_model.parameters())
        loaded_params = list(fresh_model.parameters())
        # zip() stops at the shorter input, so rule out a vacuous pass
        # (no parameters at all, or mismatched counts) before comparing.
        self.assertTrue(trained_params)
        self.assertEqual(len(trained_params), len(loaded_params))
        for trained_p, loaded_p in zip(trained_params, loaded_params):
            # The fresh model must keep its own tensor objects...
            self.assertIsNot(trained_p, loaded_p)
            # ...while holding exactly the trained values.
            self.assertTrue(trained_p.equal(loaded_p))