當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch register_kl用法及代碼示例

本文簡要介紹python語言中 torch.distributions.kl.register_kl 的用法。

用法:

torch.distributions.kl.register_kl(type_p, type_q)

參數

  • type_p(type) -Distribution 的子類。

  • type_q(type) -Distribution 的子類。

使用 kl_divergence() 注冊成對函數的裝飾器。用法:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here

查找返回按子類排序的最具體的(類型,類型)匹配。如果匹配不明確,則會引發 RuntimeWarning。例如解決模棱兩可的情況:

@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...
@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...

您應該注冊第三個 most-specific 實現,例如:

register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributions.kl.register_kl。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。