当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.symbol.contrib.while_loop用法及代码示例


用法:

mxnet.symbol.contrib.while_loop(cond, func, loop_vars, max_iterations=None, name='while_loop')

参数

  • cond(a Python function.) - 循环条件。
  • func(a Python function.) - 循环体。
  • loop_vars(a Symbol or nested lists of Symbol.) - 循环变量的初始值。
  • max_iterations(a python int.) - 最大迭代次数。

返回

  • outputs(a Symbol or nested lists of Symbols) - 每一步的堆叠输出
  • states(a Symbol or nested lists of Symbols) - 最终状态

使用用户定义的计算和循环条件运行 while 循环。

该运算符模拟一个 while 循环,只要满足条件,它就会迭代地进行自定义计算。

loop_vars 是一个符号或计算使用的符号的嵌套列表。

cond 是用户自定义函数,用作循环条件。它消耗 loop_vars ,并产生一个标量 MXNet 符号,指示循环的终止。当cond 返回假(零)时,循环结束。 cond 是可变参数,它的签名应该是 cond(*loop_vars) => Symbol

func 是用户定义的函数,用作循环体。它还消耗 loop_vars ,并在每一步生成 step_outputnew_loop_vars 。在每个步骤中,step_output 应该包含相同的数字元素。通过所有步骤,step_output 的 i-th 元素应该具有相同的形状和 dtype。此外,new_loop_vars 应包含与 loop_vars 相同数量的元素,并且相应的元素应具有相同的形状和 dtype。 func 是可变参数,它的签名应该是 func(*loop_vars) => (Symbol or nested List[Symbol] step_output, Symbol or nested List[Symbol] new_loop_vars)

max_iterations 是一个标量,用于定义允许的最大迭代次数。

此函数返回两个列表。第一个列表的长度为 |step_output| ,其中 i-th 元素是所有步骤中 step_output 的所有 i-th 元素,沿轴 0 堆叠。第二个列表的长度为 |loop_vars| ,表示循环变量的最终状态。

警告

目前,由于缺乏动态形状推断,第一个列表中所有符号的轴 0 为 max_iterations

警告

即使 cond 永远不会满足,while_loop 也会返回一个输出列表,其中包含推断的 dtype 和 shape。这与 Symbol 版本不同,在这种情况下 step_outputs 被假定为一个空列表。

例子

>>> cond = lambda i, s: i <= 5
>>> func = lambda i, s: ([i + s], [i + 1, s + i])
>>> loop_vars = (mx.sym.var('i'), mx.sym.var('s'))
>>> outputs, states = mx.sym.contrib.while_loop(cond, func, loop_vars, max_iterations=10)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.contrib.while_loop。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。