扫描从维度 0 的 elems
解压缩的张量列表。(不推荐使用的参数值)
用法
tf.scan(
fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
swap_memory=False, infer_shape=True, reverse=False, name=None
)
参数
-
fn
要执行的可调用对象。它接受两个参数。如果提供了一个,第一个将具有与initializer
相同的结构,否则它将具有与elems
相同的结构。第二个将具有与elems
相同的(可能是嵌套的)结构。如果提供了一个,它的输出必须与initializer
具有相同的结构,否则它必须与elems
具有相同的结构。 -
elems
张量或(可能是嵌套的)张量序列,每个张量都将沿其第一维展开。结果切片的嵌套序列将是fn
的第一个参数。 -
initializer
(可选)张量或(可能是嵌套的)张量序列、累加器的初始值以及fn
的预期输出类型。 -
parallel_iterations
(可选)允许并行运行的迭代次数。 -
back_prop
(可选)已弃用。 False 禁用对反向传播的支持。更喜欢使用tf.stop_gradient
。 -
swap_memory
(可选)True 启用GPU-CPU 内存交换。 -
infer_shape
(可选)False 禁用一致输出形状的测试。 -
reverse
(可选)True 从后到前扫描张量(而不是从前到后)。 -
name
(可选)返回张量的名称前缀。
返回
-
张量或(可能是嵌套的)张量序列。每个张量将应用
fn
的结果沿第一个维度从elems
解压缩的张量打包,以及前一个累加器值,从第一个到最后一个(或从最后一个到第一个,如果reverse=True
)。
抛出
-
TypeError
如果fn
不可调用或fn
和initializer
的输出结构不匹配。 -
ValueError
如果fn
和initializer
的输出长度不匹配。
警告:不推荐使用某些参数值:(back_prop=False)
。它们将在未来的版本中被删除。更新说明:back_prop=False 已弃用。考虑改用 tf.stop_gradient。代替:results = tf.scan(fn, elems, back_prop=False) 使用:results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))
scan
的最简单版本重复地将可调用的fn
应用于从第一个到最后一个元素的序列。元素由从维度 0 的 elems
解压缩的张量组成。可调用的 fn 将两个张量作为参数。第一个参数是从前面的 fn 调用计算的累加值,第二个参数是 elems
的当前位置的值。如果initializer
为None,则elems
必须至少包含一个元素,并且它的第一个元素用作初始值设定项。
假设 elems
被解压缩到 values
中,这是一个张量列表。结果张量的形状是 [len(values)] + fn(initializer, values[0]).shape
。如果 reverse=True,则为 fn(initializer, values[-1]).shape。
此方法还允许multi-arity elems
和累加器。如果elems
是一个(可能是嵌套的)张量列表或元组,那么这些张量中的每一个都必须具有匹配的第一个(解包)维度。 fn
的第二个参数必须与 elems
的结构匹配。
如果没有提供initializer
,则假设fn
的输出结构和dtypes与其输入相同;在这种情况下,fn
的第一个参数必须与 elems
的结构匹配。
如果提供了 initializer
,则 fn
的输出必须与 initializer
具有相同的结构;并且fn
的第一个参数必须匹配这个结构。
例如,如果 elems
是 (t1, [t2, t3])
并且 initializer
是 [i1, i2]
那么 python2
中 fn
的适当签名是:fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):
和 fn
必须返回一个列表 [acc_n1, acc_n2]
。 fn
的另一种正确签名以及在 python3
中工作的签名是: fn = lambda a, t:
,其中 a
和 t
对应于输入元组。
例子:
elems = np.array([1, 2, 3, 4, 5, 6])
sum = scan(lambda a, x:a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
sum = scan(lambda a, x:a + x, elems, reverse=True)
# sum == [21, 20, 18, 15, 11, 6]
elems = np.array([1, 2, 3, 4, 5, 6])
initializer = np.array(0)
sum_one = scan(
lambda a, x:x[0] - x[1] + a, (elems + 1, elems), initializer)
# sum_one == [1, 2, 3, 4, 5, 6]
elems = np.array([1, 0, 0, 0, 0, 0])
initializer = (np.array(0), np.array(1))
fibonaccis = scan(lambda a, _:(a[1], a[0] + a[1]), elems, initializer)
# fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
相关用法
- Python tf.scatter_nd用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.strings.substr用法及代码示例
- Python tf.strings.reduce_join用法及代码示例
- Python tf.sparse.cross用法及代码示例
- Python tf.sparse.mask用法及代码示例
- Python tf.strings.regex_full_match用法及代码示例
- Python tf.sparse.split用法及代码示例
- Python tf.strings.regex_replace用法及代码示例
- Python tf.signal.overlap_and_add用法及代码示例
- Python tf.strings.length用法及代码示例
- Python tf.strided_slice用法及代码示例
- Python tf.sparse.to_dense用法及代码示例
- Python tf.strings.bytes_split用法及代码示例
- Python tf.summary.text用法及代码示例
- Python tf.shape用法及代码示例
- Python tf.sparse.expand_dims用法及代码示例
- Python tf.signal.frame用法及代码示例
- Python tf.sparse.maximum用法及代码示例
- Python tf.signal.linear_to_mel_weight_matrix用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.scan。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。