当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch SGD用法及代码示例


本文简要介绍python语言中 torch.optim.SGD 的用法。

用法:

class torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

参数

  • params(可迭代的) -可迭代的参数以优化或 dicts 定义参数组

  • lr(float) -学习率

  • momentum(float,可选的) -动量因子(默认值:0)

  • weight_decay(float,可选的) -权重衰减(L2 惩罚)(默认值:0)

  • dampening(float,可选的) -动量阻尼(默认值:0)

  • nesterov(bool,可选的) -启用 Nesterov 动量(默认值:False)

实现随机梯度下降(可选用动量)。

Nesterov 动量基于公式关于初始化和动量在深度学习中的重要性.

示例

>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

注意

使用 Momentum/Nesterov 实现的 SGD 与 Sutskever 等人略有不同。人。以及其他一些框架中的实现。

考虑到 Momentum 的具体情况,更新可以写为

其中 分别表示参数、梯度、速度和动量。

这与 Sutskever 等人形成鲜明对比。人。和其他使用表格更新的框架

Nesterov 版本进行了类似的修改。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.optim.SGD。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。