當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.estimator.inputs.numpy_input_fn用法及代碼示例


返回將 numpy 數組的 dict 輸入模型的輸入函數。

用法

tf.compat.v1.estimator.inputs.numpy_input_fn(
    x, y=None, batch_size=128, num_epochs=1, shuffle=None, queue_capacity=1000,
    num_threads=1
)

參數

  • x numpy 數組對象或 numpy 數組對象的字典。如果是數組,則該數組將被視為單個特征。
  • y numpy 數組對象或 numpy 數組對象的字典。 None 如果不存在。
  • batch_size 整數,要返回的批次大小。
  • num_epochs 整數,迭代數據的時期數。如果None 將永遠運行。
  • shuffle 布爾值,如果 True 打亂隊列。避免在預測時洗牌。
  • queue_capacity 整數,要累積的隊列大小。
  • num_threads 整數,用於讀取和入隊的線程數。為了具有預測和可重複的讀取和入隊順序,例如在預測和評估模式下,num_threads 應該是 1。

返回

  • 函數,簽名為 ()->(dict of features , targets )

拋出

  • ValueError 如果 y 的形狀與 x 中的值的形狀不匹配(即,x 中的值具有相同的形狀)。
  • ValueError 如果 y 是字典時,xy 中都有重複鍵。
  • ValueError 如果 x 或 y 是空字典。
  • TypeError x 不是字典或數組。
  • ValueError 如果未提供 'shuffle' 或布爾值。

這將返回一個基於 numpy 數組的字典輸出 featurestargets 的函數。 dict featuresx 具有相同的鍵。如果y 是字典,則字典targetsy 具有相同的鍵。

例子:

age = np.arange(4) * 1.0
height = np.arange(32, 36)
x = {'age':age, 'height':height}
y = np.arange(-32, -28)

with tf.Session() as session:
  input_fn = numpy_io.numpy_input_fn(
      x, y, batch_size=2, shuffle=False, num_epochs=1)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.estimator.inputs.numpy_input_fn。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。