当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch lazy_apply用法及代码示例


本文简要介绍python语言中 torchrec.modules.lazy_extension.lazy_apply 的用法。

用法:

torchrec.modules.lazy_extension.lazy_apply(module: torch.nn.modules.module.Module, fn: Callable[[torch.nn.modules.module.Module], None]) → torch.nn.modules.module.Module

参数

  • module(火炬.nn.模块) -递归应用fn 的模块。

  • fn(可调用[[火炬.nn.模块],None]) - 要附加的函数module然后应用于每个子模块modulemodule本身。

返回

module 附有 fn

返回类型

火炬.nn.模块

将函数附加到模块,该函数将递归地应用于模块的每个子模块(由 .children() 返回)以及模块本身在第一次前向传递之后(即在所有子模块和参数都已初始化之后)。

典型用途包括初始化惰性模块(即从 LazyModuleMixin 继承的模块)的参数的数值。

注意

lazy_apply() 可用于惰性和非惰性模块。

例子:

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == torch.nn.LazyLinear:
        m.weight.fill_(1.0)
        print(m.weight)

linear = torch.nn.LazyLinear(2)
lazy_apply(linear, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
linear(input)  # runs `init_weights` only once, right after first forward pass

seq = torch.nn.Sequential(torch.nn.LazyLinear(2), torch.nn.LazyLinear(2))
lazy_apply(seq, init_weights)  # doesn't run `init_weights` immediately
input = torch.randn(2, 10)
seq(input)  # runs `init_weights` only once, right after first forward pass

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.modules.lazy_extension.lazy_apply。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。