當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch checkpoint_sequential用法及代碼示例


本文簡要介紹python語言中 torch.utils.checkpoint.checkpoint_sequential 的用法。

用法:

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)

參數

  • functions-A torch.nn.Sequential 或模塊或函數列表(包括模型)以順序運行。

  • segments-在模型中創建的塊數

  • input-輸入到functions 的張量

  • preserve_rng_state(bool,可選的,默認=真) -在每個檢查點期間省略存儲和恢複 RNG 狀態。

返回

*inputs 上按順序運行 functions 的輸出

用於檢查點順序模型的輔助函數。

順序模型按順序(按順序)執行模塊/函數列表。因此,我們可以將這樣的模型劃分為各個段,並對每個段進行檢查點。除最後一個段外,所有段都將以 torch.no_grad() 方式運行,即不存儲中間激活。每個檢查點段的輸入將被保存,以便在反向傳遞中重新運行該段。

請參閱 checkpoint() 了解檢查點的工作原理。

警告

檢查點當前僅支持 torch.autograd.backward() ,並且僅在其inputs 參數未通過時才支持。不支持 torch.autograd.grad()

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.utils.checkpoint.checkpoint_sequential。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。