当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.ndarray.contrib.while_loop用法及代码示例


用法:

mxnet.ndarray.contrib.while_loop(cond, func, loop_vars, max_iterations=None)

参数

  • cond(a Python function.) - 循环条件。
  • func(a Python function.) - 循环体。
  • loop_vars(an NDArray or nested lists of NDArrays.) - 循环变量的初始值。
  • max_iterations(a python int.) - 最大迭代次数。

返回

  • outputs(an NDArray or nested lists of NDArrays) - 每一步的堆叠输出
  • states(an NDArray or nested lists of NDArrays) - 最终状态

使用用户定义的计算和循环条件运行 while 循环。

该运算符模拟一个 while 循环,只要满足条件,它就会迭代地进行自定义计算。

loop_vars 是计算使用的 NDArray 列表。

cond 是用户自定义函数,用作循环条件。它消耗 loop_vars ,并产生一个标量 MXNet NDArray,指示循环的终止。当cond 返回假(零)时,循环结束。 cond 是可变参数,它的签名应该是 cond(*loop_vars) => NDArray

func 是用户定义的函数,用作循环体。它还消耗 loop_vars ,并在每一步生成 step_outputnew_loop_vars 。在每个步骤中,step_output 应该包含相同的数字元素。通过所有步骤,step_output 的 i-th 元素应该具有相同的形状和 dtype。此外,new_loop_vars 应包含与 loop_vars 相同数量的元素,并且相应的元素应具有相同的形状和 dtype。 func 是可变参数,它的签名应该是 func(*loop_vars) => (NDArray or nested List[NDArray] step_output, NDArray or nested List[NDArray] new_loop_vars)

max_iterations 是一个标量,用于定义允许的最大迭代次数。

此函数返回两个列表。第一个列表的长度为 |step_output| ,其中 i-th 元素是所有步骤中 step_output 的所有 i-th 元素,沿轴 0 堆叠。第二个列表的长度为 |loop_vars| ,表示循环变量的最终状态。

警告

目前,由于缺乏动态形状推断,第一个列表中所有 NDArray 的轴 0 为 max_iterations

警告

cond 永远不满足时,我们假设step_output 为空,因为它无法推断。这与符号版本不同。

例子

>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.nd.array([0], dtype="int64"), mx.nd.array([1], dtype="int64"))
>>> outputs, states = mx.nd.contrib.while_loop(cond, func, loop_vars, max_iterations=10)
>>> outputs
[
[[ 1]
[ 2]
[ 4]
[ 7]
[11]
[16]
[...]  # undefined value
[...]
[...]
[...]]
<NDArray 6x1 @cpu(0)>]
>>> states
[
[6]
<NDArray 1 @cpu(0)>,
[16]
<NDArray 1 @cpu(0)>]

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.ndarray.contrib.while_loop。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。