当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.TFRecordDataset.scan用法及代码示例


用法

scan(
    initial_state, scan_func, name=None
)

参数

  • initial_state 张量的嵌套结构,表示累加器的初始状态。
  • scan_func (old_state, input_element) 映射到 (new_state, output_element) 的函数。它必须接受两个参数并返回一对张量的嵌套结构。 new_state 必须与 initial_state 的结构匹配。
  • name (可选。) tf.data 操作的名称。

返回

  • 一个Dataset

一种跨输入数据集扫描函数的转换。

此转换是 tf.data.Dataset.map 的有状态相对。除了将scan_func 映射到输入数据集的元素之外,scan() 还会累积一个或多个状态张量,其初始值为 initial_state

dataset = tf.data.Dataset.range(10)
initial_state = tf.constant(0, dtype=tf.int64)
scan_func = lambda state, i:(state + i, state + i)
dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func)
list(dataset.as_numpy_iterator())
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45]

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.TFRecordDataset.scan。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。