当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.Dataset.from_generator用法及代码示例


用法

@staticmethod
from_generator(
    generator, output_types=None, output_shapes=None, args=None,
    output_signature=None, name=None
)

参数

  • generator 返回支持iter() 协议的对象的可调用对象。如果未指定args,则generator 必须不带任何参数;否则,它必须采用与 args 中的值一样多的参数。
  • output_types (可选。)tf.DType 对象的(嵌套)结构对应于 generator 产生的元素的每个组件。
  • output_shapes (可选。)tf.TensorShape 对象的(嵌套)结构对应于 generator 产生的元素的每个组件。
  • args (可选。)tf.Tensor 对象的元组将被评估并作为NumPy-array 参数传递给generator
  • output_signature (可选。)tf.TypeSpec 对象的(嵌套)结构对应于 generator 产生的元素的每个组件。
  • name (可选。)from_generator 使用的 tf.data 操作的名称。

返回

  • Dataset 一个Dataset

创建一个 Dataset,其元素由 generator 生成。 (不推荐使用的参数)

警告:不推荐使用某些参数:(output_shapes, output_types)。它们将在未来的版本中被删除。更新说明:改用output_signature

generator 参数必须是可调用对象,该对象返回支持 iter() 协议的对象(例如生成器函数)。

generator 生成的元素必须与给定的 output_signature 参数或给定的 output_types 和(可选)output_shapes 参数兼容,以指定者为准。

调用from_generator 的推荐方法是使用output_signature 参数。在这种情况下,将假定输出由具有由output_signature 参数中的tf.TypeSpec 对象定义的类、形状和类型的对象组成:

def gen():
  ragged_tensor = tf.ragged.constant([[1, 2], [3]])
  yield 42, ragged_tensor

dataset = tf.data.Dataset.from_generator(
     gen,
     output_signature=(
         tf.TensorSpec(shape=(), dtype=tf.int32),
         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))

list(dataset.take(1))
[(<tf.Tensor:shape=(), dtype=int32, numpy=42>,
<tf.RaggedTensor [[1, 2], [3]]>)]

还有一种不推荐使用的方法来调用from_generator,可以单独使用output_types 参数或与output_shapes 参数一起使用。在这种情况下,函数的输出将假定由 tf.Tensor 对象组成,其类型由 output_types 定义,形状未知或由 output_shapes 定义。

注意:Dataset.from_generator() 的当前实现使用 tf.numpy_function 并继承相同的约束。特别是,它要求将数据集和迭代器相关操作放置在与调用 Dataset.from_generator() 的 Python 程序相同的进程中的设备上。 generator 的主体不会在 GraphDef 中序列化,如果您需要序列化模型并在不同的环境中恢复它,则不应使用此方法。

注意:如果 generator 依赖于可变全局变量或其他外部状态,请注意运行时可能会调用 generator 多次(以支持重复 Dataset )以及在调用 Dataset.from_generator() 和从生成器生产第一个元素。改变全局变量或外部状态可能会导致未定义的行为,我们建议您在调用 Dataset.from_generator() 之前显式缓存 generator 中的任何外部状态。

注意:虽然output_signature 参数可以产生Dataset 元素,但Dataset.from_generator() 的范围应限于无法通过tf.data 操作表达的逻辑。在生成器函数中使用 tf.data 操作是一种反模式,可能会导致内存增长。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.Dataset.from_generator。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。