当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python tf.compat.v1.data.Iterator.from_structure用法及代码示例

用法

@staticmethod
from_structure(
    output_types, output_shapes=None, shared_name=None, output_classes=None
)

参数

  • output_types tf.DType 对象的(嵌套)结构,对应于该数据集元素的每个组件。
  • output_shapes (可选。)tf.TensorShape 对象的(嵌套)结构对应于此数据集元素的每个组件。如果省略,每个组件将具有不受约束的形状。
  • shared_name (可选。)如果非空,此迭代器将在共享相同设备的多个会话中以给定名称共享(例如,当使用远程服务器时)。
  • output_classes (可选。)Python type 对象的(嵌套)结构,对应于此迭代器元素的每个组件。如果省略,则假定每个组件的类型为 tf.Tensor

返回

  • 一个 Iterator

抛出

  • TypeError 如果output_shapesoutput_types的结构不同。

使用给定的结构创建一个新的、未初始化的 Iterator

此iterator-constructing 方法可用于创建可与许多不同数据集重用的迭代器。

返回的迭代器未绑定到特定数据集,并且没有 initializer 。要初始化迭代器,请运行 Iterator.make_initializer(dataset) 返回的操作。

下面是一个例子

iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

dataset_range = Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)

dataset_evens = dataset_range.filter(lambda x:x % 2 == 0)
evens_initializer = iterator.make_initializer(dataset_evens)

# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())

# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
  # Initialize the iterator to `dataset_range`
  sess.run(range_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

  # Initialize the iterator to `dataset_evens`
  sess.run(evens_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.data.Iterator.from_structure。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。