當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch cached用法及代碼示例


本文簡要介紹python語言中 torch.nn.utils.parametrize.cached 的用法。

用法:

torch.nn.utils.parametrize.cached()

register_parametrization() 注冊的參數化中啟用緩存係統的上下文管理器。

當上下文管理器處於活動狀態時,參數化對象的值在第一次需要時被計算和緩存。離開上下文管理器時,緩存的值將被丟棄。

這在前向傳遞中多次使用參數化參數時很有用。這方麵的一個例子是在參數化 RNN 的循環內核或共享權重時。

激活緩存的最簡單方法是包裝神經網絡的前向傳遞

import torch.nn.utils.parametrize as P
...
with P.cached():
    output = model(inputs)

在訓練和評估中。也可以包裝使用數倍參數化張量的模塊部分。例如,帶有參數化循環內核的 RNN 循環:

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.parametrize.cached。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。