当前位置: 首页>>技术问答>>正文


在PyTorch中保存训练模型的最佳方法?

qingchuanTR 技术问答 , , , 去评论

我正在寻找在PyTorch中保存训练模型的不同方法。到目前为止,发现了两种选择。

  1. torch.save()用于保存模型,torch.load()用于加载模型。

  2. model.state_dict()用于保存训练模型,model.load_state_dict()用于加载保存的模型。

我已经遇到过这个discussion,其中推荐方法2优于方法1。

我的问题是,为什么第二种方法更受欢迎?是否因为torch.nn模块具有这两个功能所以鼓励使用它们?

最佳解决思路

我在github repo上找到了this page,我只是在这里原样粘贴复制内容过来:


保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(推荐)保存并仅加载模型参数:

torch.save(the_model.state_dict(), PATH)

然后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

torch.save(the_model, PATH)

然后:

the_model = torch.load(PATH)

但是在这种情况下,序列化数据绑定到特定的类和使用的确切目录结构,因此当在其他项目中使用时,或者在一些严重的重构之后,它可能以各种奇怪的方式中断活出问题。

次佳解决思路

这取决于你想做什么。

案例#1:保存模型以自行使用它进行预测:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为通常有BatchNormDropout图层,默认情况下在构造时处于训练模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例#2:保存模型以便以后恢复训练:如果您需要继续训练您将要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器,迭代轮次(epochs),分数等相关的状态。您可以这样做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢复训练,您可以执行以下操作:state = torch.load(filepath),然后,恢复每个对象的状态,如下所示:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复训练,因此在加载时恢复状态后请勿调用model.eval()

案例#3:其他人无法访问您的代码使用的模型:在Tensorflow中,您可以创建一个.pb文件,该文件定义了模型的体系结构和权重。这非常方便,特别是在使用Tensorflow serve时。在Pytorch中执行此操作的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式不太健壮,因为pytorch已经经历了很多变化,我们不太推荐这种方式。

pytorch

参考资料

本文由《纯净的天空》出品。文章地址: https://vimsky.com/article/3884.html,未经允许,请勿转载。