我正在寻找在PyTorch中保存训练模型的不同方法。到目前为止,发现了两种选择。
-
torch.save()用于保存模型,torch.load()用于加载模型。
-
model.state_dict()用于保存训练模型,model.load_state_dict()用于加载保存的模型。
我已经遇到过这个discussion,其中推荐方法2优于方法1。
我的问题是,为什么第二种方法更受欢迎?是否因为torch.nn模块具有这两个功能所以鼓励使用它们?
最佳解决思路
我在github repo上找到了this page,我只是在这里原样粘贴复制内容过来:
保存模型的推荐方法
序列化和恢复模型有两种主要方法。
第一个(推荐)保存并仅加载模型参数:
torch.save(the_model.state_dict(), PATH)
然后:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二个保存并加载整个模型:
torch.save(the_model, PATH)
然后:
the_model = torch.load(PATH)
但是在这种情况下,序列化数据绑定到特定的类和使用的确切目录结构,因此当在其他项目中使用时,或者在一些严重的重构之后,它可能以各种奇怪的方式中断活出问题。
次佳解决思路
这取决于你想做什么。
案例#1:保存模型以自行使用它进行预测:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为通常有BatchNorm
和Dropout
图层,默认情况下在构造时处于训练模式:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
案例#2:保存模型以便以后恢复训练:如果您需要继续训练您将要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器,迭代轮次(epochs),分数等相关的状态。您可以这样做:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
要恢复训练,您可以执行以下操作:state = torch.load(filepath)
,然后,恢复每个对象的状态,如下所示:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
由于您正在恢复训练,因此在加载时恢复状态后请勿调用model.eval()
。
案例#3:其他人无法访问您的代码使用的模型:在Tensorflow中,您可以创建一个.pb
文件,该文件定义了模型的体系结构和权重。这非常方便,特别是在使用Tensorflow serve
时。在Pytorch中执行此操作的等效方法是:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
这种方式不太健壮,因为pytorch已经经历了很多变化,我们不太推荐这种方式。