当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python PyTorch weight_norm用法及代码示例

本文简要介绍python语言中 torch.nn.utils.weight_norm 的用法。

用法:

torch.nn.utils.weight_norm(module, name='weight', dim=0)

参数

  • module(torch.nn.Module) -包含模块

  • name(str,可选的) -权重参数名称

  • dim(int,可选的) -计算范数的维度

返回

带有重量标准挂钩的原始模块

将权重归一化应用于给定模块中的参数。

权重归一化是将权重张量的大小与其方向解耦的重新参数化。这将 name 指定的参数(例如 'weight' )替换为两个参数:一个指定幅度(例如 'weight_g' ),另一个指定方向(例如 'weight_v' )。权重归一化是通过一个钩子实现的,该钩子在每次 forward() 调用之前从大小和方向重新计算权重张量。

默认情况下,使用 dim=0 ,每个输出通道/平面独立计算范数。要计算整个权重张量的范数,请使用 dim=None

https://arxiv.org/abs/1602.07868

例子:

>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.weight_norm。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。