用法
bucket_by_sequence_length(
element_length_func, bucket_boundaries, bucket_batch_sizes, padded_shapes=None,
padding_values=None, pad_to_bucket_boundary=False, no_padding=False,
drop_remainder=False, name=None
)参数
-
element_length_func从Dataset到tf.int32中的元素的函数,确定元素的长度,这将确定它进入的桶。 -
bucket_boundarieslist<int>,桶的上长度边界。 -
bucket_batch_sizeslist<int>,每个桶的批量大小。长度应为len(bucket_boundaries) + 1。 -
padded_shapestf.TensorShape的嵌套结构传递给tf.data.Dataset.padded_batch。如果未提供,将使用dataset.output_shapes,这将导致可变长度维度在每批中被填充到最大长度。 -
padding_values要填充的值,传递给tf.data.Dataset.padded_batch。默认填充为 0。 -
pad_to_bucket_boundarybool,如果False,将批量填充未知大小的尺寸到最大长度。如果True,会将未知大小的维度填充到桶边界减 1(即每个桶中的最大长度),并且调用者必须确保源Dataset不包含任何长度超过max(bucket_boundaries)的元素。 -
no_paddingbool,表示是否填充批量特征(特征需要是tf.sparse.SparseTensor类型或相同形状)。 -
drop_remainder(可选。)一个tf.bool标量tf.Tensor,表示在最后一批少于batch_size元素的情况下是否应删除它;默认行为是不丢弃较小的批次。 -
name(可选。) tf.data 操作的名称。
返回
-
一个
Dataset。
抛出
-
ValueError如果len(bucket_batch_sizes) != len(bucket_boundaries) + 1.
一种按长度对Dataset 中的元素进行分桶的转换。
Dataset 的元素按长度分组在一起,然后进行填充和批处理。
这对于元素具有可变长度的序列任务很有用。将具有相似长度的元素组合在一起可以减少批次中填充的总比例,从而提高训练步骤的效率。
下面是一个基于序列长度将输入数据分桶到 3 个桶“[0, 3), [3, 5), [5, inf)” 的示例,批量大小为 2。
elements = [
[0], [1, 2, 3, 4], [5, 6, 7],
[7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
dataset = tf.data.Dataset.from_generator(
lambda:elements, tf.int64, output_shapes=[None])
dataset = dataset.bucket_by_sequence_length(
element_length_func=lambda elem:tf.shape(elem)[0],
bucket_boundaries=[3, 5],
bucket_batch_sizes=[2, 2, 2])
for elem in dataset.as_numpy_iterator():
print(elem)
[[1 2 3 4]
[5 6 7 0]]
[[ 7 8 9 10 11 0]
[13 14 15 16 19 20]]
[[ 0 0]
[21 22]]
相关用法
- Python tf.compat.v1.data.TFRecordDataset.batch用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.concatenate用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.cardinality用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.prefetch用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.enumerate用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.filter用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.unbatch用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.window用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.repeat用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.group_by_window用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.with_options用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.reduce用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.skip用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.snapshot用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.map用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.random用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.take用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.make_initializable_iterator用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.scan用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.data.TFRecordDataset.bucket_by_sequence_length。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
