當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python mxnet.ndarray.contrib.while_loop用法及代碼示例

用法:

mxnet.ndarray.contrib.while_loop(cond, func, loop_vars, max_iterations=None)

參數

  • cond(a Python function.) - 循環條件。
  • func(a Python function.) - 循環體。
  • loop_vars(an NDArray or nested lists of NDArrays.) - 循環變量的初始值。
  • max_iterations(a python int.) - 最大迭代次數。

返回

  • outputs(an NDArray or nested lists of NDArrays) - 每一步的堆疊輸出
  • states(an NDArray or nested lists of NDArrays) - 最終狀態

使用用戶定義的計算和循環條件運行 while 循環。

該運算符模擬一個 while 循環,隻要滿足條件,它就會迭代地進行自定義計算。

loop_vars 是計算使用的 NDArray 列表。

cond 是用戶自定義函數,用作循環條件。它消耗 loop_vars ,並產生一個標量 MXNet NDArray,指示循環的終止。當cond 返回假(零)時,循環結束。 cond 是可變參數,它的簽名應該是 cond(*loop_vars) => NDArray

func 是用戶定義的函數,用作循環體。它還消耗 loop_vars ,並在每一步生成 step_outputnew_loop_vars 。在每個步驟中,step_output 應該包含相同的數字元素。通過所有步驟,step_output 的 i-th 元素應該具有相同的形狀和 dtype。此外,new_loop_vars 應包含與 loop_vars 相同數量的元素,並且相應的元素應具有相同的形狀和 dtype。 func 是可變參數,它的簽名應該是 func(*loop_vars) => (NDArray or nested List[NDArray] step_output, NDArray or nested List[NDArray] new_loop_vars)

max_iterations 是一個標量,用於定義允許的最大迭代次數。

此函數返回兩個列表。第一個列表的長度為 |step_output| ,其中 i-th 元素是所有步驟中 step_output 的所有 i-th 元素,沿軸 0 堆疊。第二個列表的長度為 |loop_vars| ,表示循環變量的最終狀態。

警告

目前,由於缺乏動態形狀推斷,第一個列表中所有 NDArray 的軸 0 為 max_iterations

警告

cond 永遠不滿足時,我們假設step_output 為空,因為它無法推斷。這與符號版本不同。

例子

>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
>>> outputs
[
[[ 1]
[ 2]
[ 4]
[ 7]
[11]
[16]
[...]  # undefined value
[...]
[...]
[...]]
<NDArray 6x1 @cpu(0)>]
>>> states
[
[6]
<NDArray 1 @cpu(0)>,
[16]
<NDArray 1 @cpu(0)>]

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.ndarray.contrib.while_loop。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。