当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.TFRecordDataset.batch用法及代码示例


用法

batch(
    batch_size, drop_remainder=False, num_parallel_calls=None, deterministic=None,
    name=None
)

参数

  • batch_size tf.int64 标量 tf.Tensor ,表示要在单个批次中组合的此数据集的连续元素的数量。
  • drop_remainder (可选。)一个 tf.bool 标量 tf.Tensor ,表示在最后一批少于 batch_size 元素的情况下是否应删除它;默认行为是不丢弃较小的批次。
  • num_parallel_calls (可选。)tf.int64 标量 tf.Tensor ,表示要并行异步计算的批次数。如果未指定,批次将按顺序计算。如果使用值tf.data.AUTOTUNE,则并行调用的数量根据可用资源动态设置。
  • deterministic (可选。)指定 num_parallel_calls 时,如果指定了此布尔值( TrueFalse ),它将控制转换生成元素的顺序。如果设置为 False ,则允许转换产生无序元素,以用确定性换取性能。如果未指定,则 tf.data.Options.deterministic 选项(默认为 True)控制行为。
  • name (可选。) tf.data 操作的名称。

返回

  • Dataset 一个Dataset

将此数据集的连续元素组合成批次。

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5])]

结果元素的组件将有一个额外的外部维度,它将是batch_size(或者如果batch_size没有将输入元素N的数量除以drop_remainder是最后一个元素,则为N % batch_size,并且drop_remainderFalse)。如果您的程序依赖于具有相同外部尺寸的批次,则应将 drop_remainder 参数设置为 True 以防止生成较小的批次。

注意:如果您的程序要求数据具有静态已知的形状(例如,使用 XLA 时),您应该使用 drop_remainder=True 。如果没有drop_remainder=True,输出数据集的形状将具有未知的前导维度,因为最终批次可能较小。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.TFRecordDataset.batch。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。