当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.experimental.dense_to_ragged_batch用法及代码示例


将参差不齐的元素批量转换为 tf.RaggedTensor 的转换。

用法

tf.data.experimental.dense_to_ragged_batch(
    batch_size, drop_remainder=False, row_splits_dtype=tf.dtypes.int64
)

参数

  • batch_size tf.int64 标量 tf.Tensor ,表示要在单个批次中组合的此数据集的连续元素的数量。
  • drop_remainder (可选。)一个 tf.bool 标量 tf.Tensor ,表示在最后一批少于 batch_size 元素的情况下是否应删除它;默认行为是不丢弃较小的批次。
  • row_splits_dtype 应该用于任何新不规则张量的 row_splits 的 dtype。现有 tf.RaggedTensor 元素的 row_splits dtype 未更改。

返回

  • Dataset 一个Dataset

此转换将输入数据集的多个连续元素组合成一个元素。

tf.data.Dataset.batch 一样,结果元素的组件将有一个额外的外部维度,即 batch_size (如果 batch_size 不均匀地划分输入元素的数量 N ,则为最后一个元素的 N % batch_size 并且drop_remainderFalse )。如果您的程序依赖于具有相同外部尺寸的批次,则应将 drop_remainder 参数设置为 True 以防止生成较小的批次。

tf.data.Dataset.batch 不同,要批处理的输入元素可能具有不同的形状:

  • 如果输入元素是一个 tf.Tensor,其静态 tf.TensorShape 已完全定义,则它按正常方式进行批处理。
  • 如果输入元素是 tf.Tensor,其静态 tf.TensorShape 包含一个或多个尺寸未知的轴(即 shape[i]=None ),则输出将包含一个 tf.RaggedTensor ,它的大小不一。
  • 如果输入元素是tf.RaggedTensor 或任何其他类型,则正常批处理。

例子:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
dataset = dataset.map(lambda x:tf.range(x))
dataset.element_spec.shape
TensorShape([None])
dataset = dataset.apply(
    tf.data.experimental.dense_to_ragged_batch(batch_size=2))
for batch in dataset:
  print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.experimental.dense_to_ragged_batch。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。