当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch cached用法及代码示例


本文简要介绍python语言中 torch.nn.utils.parametrize.cached 的用法。

用法:

torch.nn.utils.parametrize.cached()

register_parametrization() 注册的参数化中启用缓存系统的上下文管理器。

当上下文管理器处于活动状态时,参数化对象的值在第一次需要时被计算和缓存。离开上下文管理器时,缓存的值将被丢弃。

这在前向传递中多次使用参数化参数时很有用。这方面的一个例子是在参数化 RNN 的循环内核或共享权重时。

激活缓存的最简单方法是包装神经网络的前向传递

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

在训练和评估中。也可以包装使用数倍参数化张量的模块部分。例如,带有参数化循环内核的 RNN 循环:

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.parametrize.cached。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。