当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch register_kl用法及代码示例


本文简要介绍python语言中 torch.distributions.kl.register_kl 的用法。

用法:

torch.distributions.kl.register_kl(type_p, type_q)

参数

  • type_p(type) -Distribution 的子类。

  • type_q(type) -Distribution 的子类。

使用 kl_divergence() 注册成对函数的装饰器。用法:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here

查找返回按子类排序的最具体的(类型,类型)匹配。如果匹配不明确,则会引发 RuntimeWarning。例如解决模棱两可的情况:

@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...
@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...

您应该注册第三个 most-specific 实现,例如:

register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributions.kl.register_kl。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。