當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.compat.v1.data.Iterator.from_structure用法及代碼示例

用法

@staticmethod
from_structure(
    output_types, output_shapes=None, shared_name=None, output_classes=None
)

參數

  • output_types tf.DType 對象的(嵌套)結構,對應於該數據集元素的每個組件。
  • output_shapes (可選。)tf.TensorShape 對象的(嵌套)結構對應於此數據集元素的每個組件。如果省略,每個組件將具有不受約束的形狀。
  • shared_name (可選。)如果非空,此迭代器將在共享相同設備的多個會話中以給定名稱共享(例如,當使用遠程服務器時)。
  • output_classes (可選。)Python type 對象的(嵌套)結構,對應於此迭代器元素的每個組件。如果省略,則假定每個組件的類型為 tf.Tensor

返回

  • 一個 Iterator

拋出

  • TypeError 如果output_shapesoutput_types的結構不同。

使用給定的結構創建一個新的、未初始化的 Iterator

此iterator-constructing 方法可用於創建可與許多不同數據集重用的迭代器。

返回的迭代器未綁定到特定數據集,並且沒有 initializer 。要初始化迭代器,請運行 Iterator.make_initializer(dataset) 返回的操作。

下麵是一個例子

iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

dataset_range = Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)

dataset_evens = dataset_range.filter(lambda x:x % 2 == 0)
evens_initializer = iterator.make_initializer(dataset_evens)

# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())

# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
  # Initialize the iterator to `dataset_range`
  sess.run(range_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

  # Initialize the iterator to `dataset_evens`
  sess.run(evens_initializer)
  while True:
    try:
      pred, loss_val = sess.run([prediction, loss])
    except tf.errors.OutOfRangeError:
      break

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.data.Iterator.from_structure。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。