当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.estimator.inputs.numpy_input_fn用法及代码示例


返回将 numpy 数组的 dict 输入模型的输入函数。

用法

tf.compat.v1.estimator.inputs.numpy_input_fn(
    x, y=None, batch_size=128, num_epochs=1, shuffle=None, queue_capacity=1000,
    num_threads=1
)

参数

  • x numpy 数组对象或 numpy 数组对象的字典。如果是数组,则该数组将被视为单个特征。
  • y numpy 数组对象或 numpy 数组对象的字典。 None 如果不存在。
  • batch_size 整数,要返回的批次大小。
  • num_epochs 整数,迭代数据的时期数。如果None 将永远运行。
  • shuffle 布尔值,如果 True 打乱队列。避免在预测时洗牌。
  • queue_capacity 整数,要累积的队列大小。
  • num_threads 整数,用于读取和入队的线程数。为了具有预测和可重复的读取和入队顺序,例如在预测和评估模式下,num_threads 应该是 1。

返回

  • 函数,签名为 ()->(dict of features , targets )

抛出

  • ValueError 如果 y 的形状与 x 中的值的形状不匹配(即,x 中的值具有相同的形状)。
  • ValueError 如果 y 是字典时,xy 中都有重复键。
  • ValueError 如果 x 或 y 是空字典。
  • TypeError x 不是字典或数组。
  • ValueError 如果未提供 'shuffle' 或布尔值。

这将返回一个基于 numpy 数组的字典输出 featurestargets 的函数。 dict featuresx 具有相同的键。如果y 是字典,则字典targetsy 具有相同的键。

例子:

age = np.arange(4) * 1.0
height = np.arange(32, 36)
x = {'age':age, 'height':height}
y = np.arange(-32, -28)

with tf.Session() as session:
  input_fn = numpy_io.numpy_input_fn(
      x, y, batch_size=2, shuffle=False, num_epochs=1)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.estimator.inputs.numpy_input_fn。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。