当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch Module.register_buffer用法及代码示例


本文简要介绍python语言中 torch.nn.Module.register_buffer 的用法。

用法:

register_buffer(name, tensor, persistent=True)

参数

  • name(string) -缓冲区的名称。可以使用给定名称从此模块访问缓冲区

  • tensor(Tensor或者None) -要注册的缓冲区。如果None,然后在缓冲区上运行的操作,例如cuda, 被忽略。如果None,缓冲区是不是包含在模块的state_dict.

  • persistent(bool) -缓冲区是否是该模块state_dict的一部分。

向模块添加缓冲区。

这通常用于注册不应被视为模型参数的缓冲区。例如,BatchNorm 的 running_mean 不是参数,而是模块状态的一部分。默认情况下,缓冲区是持久的,并且将与参数一起保存。可以通过将 persistent 设置为 False 来更改此行为。持久缓冲区和非持久缓冲区之间的唯一区别是后者不会成为该模块 state_dict 的一部分。

可以使用给定名称将缓冲区作为属性访问。

例子:

>>> self.register_buffer('running_mean', torch.zeros(num_features))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.Module.register_buffer。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。