当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.scan用法及代码示例


扫描从维度 0 的 elems 解压缩的张量列表。(不推荐使用的参数值)

用法

tf.scan(
    fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
    swap_memory=False, infer_shape=True, reverse=False, name=None
)

参数

  • fn 要执行的可调用对象。它接受两个参数。如果提供了一个,第一个将具有与 initializer 相同的结构,否则它将具有与 elems 相同的结构。第二个将具有与 elems 相同的(可能是嵌套的)结构。如果提供了一个,它的输出必须与 initializer 具有相同的结构,否则它必须与 elems 具有相同的结构。
  • elems 张量或(可能是嵌套的)张量序列,每个张量都将沿其第一维展开。结果切片的嵌套序列将是 fn 的第一个参数。
  • initializer (可选)张量或(可能是嵌套的)张量序列、累加器的初始值以及 fn 的预期输出类型。
  • parallel_iterations (可选)允许并行运行的迭代次数。
  • back_prop (可选)已弃用。 False 禁用对反向传播的支持。更喜欢使用tf.stop_gradient
  • swap_memory (可选)True 启用GPU-CPU 内存交换。
  • infer_shape (可选)False 禁用一致输出形状的测试。
  • reverse (可选)True 从后到前扫描张量(而不是从前到后)。
  • name (可选)返回张量的名称前缀。

返回

  • 张量或(可能是嵌套的)张量序列。每个张量将应用 fn 的结果沿第一个维度从 elems 解压缩的张量打包,以及前一个累加器值,从第一个到最后一个(或从最后一个到第一个,如果 reverse=True )。

抛出

  • TypeError 如果fn 不可调用或fninitializer 的输出结构不匹配。
  • ValueError 如果fninitializer 的输出长度不匹配。

警告:不推荐使用某些参数值:(back_prop=False)。它们将在未来的版本中被删除。更新说明:back_prop=False 已弃用。考虑改用 tf.stop_gradient。代替:results = tf.scan(fn, elems, back_prop=False) 使用:results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))

scan 的最简单版本重复地将可调用的fn 应用于从第一个到最后一个元素的序列。元素由从维度 0 的 elems 解压缩的张量组成。可调用的 fn 将两个张量作为参数。第一个参数是从前面的 fn 调用计算的累加值,第二个参数是 elems 的当前位置的值。如果initializer 为None,则elems 必须至少包含一个元素,并且它的第一个元素用作初始值设定项。

假设 elems 被解压缩到 values 中,这是一个张量列表。结果张量的形状是 [len(values)] + fn(initializer, values[0]).shape 。如果 reverse=True,则为 fn(initializer, values[-1]).shape。

此方法还允许multi-arity elems 和累加器。如果elems 是一个(可能是嵌套的)张量列表或元组,那么这些张量中的每一个都必须具有匹配的第一个(解包)维度。 fn 的第二个参数必须与 elems 的结构匹配。

如果没有提供initializer,则假设fn的输出结构和dtypes与其输入相同;在这种情况下,fn 的第一个参数必须与 elems 的结构匹配。

如果提供了 initializer,则 fn 的输出必须与 initializer 具有相同的结构;并且fn 的第一个参数必须匹配这个结构。

例如,如果 elems(t1, [t2, t3]) 并且 initializer[i1, i2] 那么 python2fn 的适当签名是:fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):fn 必须返回一个列表 [acc_n1, acc_n2]fn 的另一种正确签名以及在 python3 中工作的签名是: fn = lambda a, t: ,其中 at 对应于输入元组。

例子:

elems = np.array([1, 2, 3, 4, 5, 6])
sum = scan(lambda a, x:a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
sum = scan(lambda a, x:a + x, elems, reverse=True)
# sum == [21, 20, 18, 15, 11, 6]
elems = np.array([1, 2, 3, 4, 5, 6])
initializer = np.array(0)
sum_one = scan(
    lambda a, x:x[0] - x[1] + a, (elems + 1, elems), initializer)
# sum_one == [1, 2, 3, 4, 5, 6]
elems = np.array([1, 0, 0, 0, 0, 0])
initializer = (np.array(0), np.array(1))
fibonaccis = scan(lambda a, _:(a[1], a[0] + a[1]), elems, initializer)
# fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.scan。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。