将参差不齐的元素批量转换为 tf.RaggedTensor
的转换。
用法
tf.data.experimental.dense_to_ragged_batch(
batch_size, drop_remainder=False, row_splits_dtype=tf.dtypes.int64
)
参数
-
batch_size
tf.int64
标量tf.Tensor
,表示要在单个批次中组合的此数据集的连续元素的数量。 -
drop_remainder
(可选。)一个tf.bool
标量tf.Tensor
,表示在最后一批少于batch_size
元素的情况下是否应删除它;默认行为是不丢弃较小的批次。 -
row_splits_dtype
应该用于任何新不规则张量的row_splits
的 dtype。现有tf.RaggedTensor
元素的 row_splits dtype 未更改。
返回
-
Dataset
一个Dataset
。
此转换将输入数据集的多个连续元素组合成一个元素。
与 tf.data.Dataset.batch
一样,结果元素的组件将有一个额外的外部维度,即 batch_size
(如果 batch_size
不均匀地划分输入元素的数量 N
,则为最后一个元素的 N % batch_size
并且drop_remainder
是 False
)。如果您的程序依赖于具有相同外部尺寸的批次,则应将 drop_remainder
参数设置为 True
以防止生成较小的批次。
与 tf.data.Dataset.batch
不同,要批处理的输入元素可能具有不同的形状:
- 如果输入元素是一个
tf.Tensor
,其静态tf.TensorShape
已完全定义,则它按正常方式进行批处理。 - 如果输入元素是
tf.Tensor
,其静态tf.TensorShape
包含一个或多个尺寸未知的轴(即shape[i]=None
),则输出将包含一个tf.RaggedTensor
,它的大小不一。 - 如果输入元素是
tf.RaggedTensor
或任何其他类型,则正常批处理。
例子:
dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
dataset = dataset.map(lambda x:tf.range(x))
dataset.element_spec.shape
TensorShape([None])
dataset = dataset.apply(
tf.data.experimental.dense_to_ragged_batch(batch_size=2))
for batch in dataset:
print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>
相关用法
- Python tf.data.experimental.dense_to_sparse_batch用法及代码示例
- Python tf.data.experimental.RandomDataset.group_by_window用法及代码示例
- Python tf.data.experimental.SqlDataset.enumerate用法及代码示例
- Python tf.data.experimental.make_saveable_from_iterator用法及代码示例
- Python tf.data.experimental.SqlDataset.zip用法及代码示例
- Python tf.data.experimental.Counter用法及代码示例
- Python tf.data.experimental.SqlDataset.shard用法及代码示例
- Python tf.data.experimental.CsvDataset.window用法及代码示例
- Python tf.data.experimental.RandomDataset.cache用法及代码示例
- Python tf.data.experimental.SqlDataset.snapshot用法及代码示例
- Python tf.data.experimental.CsvDataset.apply用法及代码示例
- Python tf.data.experimental.DatasetInitializer用法及代码示例
- Python tf.data.experimental.ignore_errors用法及代码示例
- Python tf.data.experimental.unbatch用法及代码示例
- Python tf.data.experimental.RandomDataset.map用法及代码示例
- Python tf.data.experimental.CsvDataset.flat_map用法及代码示例
- Python tf.data.experimental.assert_cardinality用法及代码示例
- Python tf.data.experimental.CsvDataset.random用法及代码示例
- Python tf.data.experimental.save用法及代码示例
- Python tf.data.experimental.CsvDataset.cardinality用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.experimental.dense_to_ragged_batch。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。