用法
padded_batch(
batch_size, padded_shapes=None, padding_values=None, drop_remainder=False,
name=None
)
参数
-
batch_size
tf.int64
标量tf.Tensor
,表示要在单个批次中组合的此数据集的连续元素的数量。 -
padded_shapes
(可选。)tf.TensorShape
或tf.int64
向量tensor-like 对象的(嵌套)结构,表示每个输入元素的相应组件在批处理之前应填充到的形状。任何未知维度都将被填充到每批中该维度的最大大小。如果未设置,则将所有组件的所有尺寸填充到批次中的最大尺寸。如果任何组件具有未知等级,则必须设置padded_shapes
。 -
padding_values
(可选。)标量形tf.Tensor
的(嵌套)结构,表示用于各个组件的填充值。 None 表示应该用默认值填充(嵌套)结构。数字类型的默认值为0
,字符串类型的默认值为空字符串。padding_values
应该具有与输入数据集相同的(嵌套)结构。如果padding_values
是单个元素并且输入数据集有多个组件,那么相同的padding_values
将用于填充数据集的每个组件。如果padding_values
是一个标量,那么它的值将被广播以匹配每个组件的形状。 -
drop_remainder
(可选。)一个tf.bool
标量tf.Tensor
,表示在最后一批少于batch_size
元素的情况下是否应删除它;默认行为是不丢弃较小的批次。 -
name
(可选。) tf.data 操作的名称。
返回
-
Dataset
一个Dataset
。
抛出
-
ValueError
如果组件具有未知等级,并且未设置padded_shapes
参数。 -
TypeError
如果组件的类型不受支持。支持的类型列表记录在https://www.tensorflow.org/guide/data#dataset_structure
将此数据集的连续元素组合成填充批次。
此转换将输入数据集的多个连续元素组合成一个元素。
与 tf.data.Dataset.batch
一样,结果元素的组件将有一个额外的外部维度,即 batch_size
(如果 batch_size
不均匀地划分输入元素的数量 N
,则为最后一个元素的 N % batch_size
并且drop_remainder
是 False
)。如果您的程序依赖于具有相同外部尺寸的批次,则应将 drop_remainder
参数设置为 True
以防止生成较小的批次。
与 tf.data.Dataset.batch
不同,要批处理的输入元素可能具有不同的形状,并且此转换会将每个组件填充到 padded_shapes
中的相应形状。 padded_shapes
参数确定输出元素中每个组件的每个维度的结果形状:
- 如果尺寸是常数,则组件将在该尺寸中填充到该长度。
- 如果维度未知,组件将被填充到该维度中所有元素的最大长度。
A = (tf.data.Dataset
.range(1, 5, output_type=tf.int32)
.map(lambda x:tf.fill([x], x)))
# Pad to the smallest per-batch size that fits all elements.
B = A.padded_batch(2)
for element in B.as_numpy_iterator():
print(element)
[[1 0]
[2 2]]
[[3 3 3 0]
[4 4 4 4]]
# Pad to a fixed size.
C = A.padded_batch(2, padded_shapes=5)
for element in C.as_numpy_iterator():
print(element)
[[1 0 0 0 0]
[2 2 0 0 0]]
[[3 3 3 0 0]
[4 4 4 4 0]]
# Pad with a custom value.
D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
for element in D.as_numpy_iterator():
print(element)
[[ 1 -1 -1 -1 -1]
[ 2 2 -1 -1 -1]]
[[ 3 3 3 -1 -1]
[ 4 4 4 4 -1]]
# Components of nested elements can be padded independently.
elements = [([1, 2, 3], [10]),
([4, 5], [11, 12])]
dataset = tf.data.Dataset.from_generator(
lambda:iter(elements), (tf.int32, tf.int32))
# Pad the first component of the tuple to length 4, and the second
# component to the smallest size that fits.
dataset = dataset.padded_batch(2,
padded_shapes=([4], [None]),
padding_values=(-1, 100))
list(dataset.as_numpy_iterator())
[(array([[ 1, 2, 3, -1], [ 4, 5, -1, -1]], dtype=int32),
array([[ 10, 100], [ 11, 12]], dtype=int32))]
# Pad with a single value and multiple components.
E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1)
for element in E.as_numpy_iterator():
print(element)
(array([[ 1, -1],
[ 2, 2]], dtype=int32), array([[ 1, -1],
[ 2, 2]], dtype=int32))
(array([[ 3, 3, 3, -1],
[ 4, 4, 4, 4]], dtype=int32), array([[ 3, 3, 3, -1],
[ 4, 4, 4, 4]], dtype=int32))
另请参见 tf.data.experimental.dense_to_sparse_batch
,它将可能具有不同形状的元素组合成 tf.sparse.SparseTensor
。
相关用法
- Python tf.data.experimental.CsvDataset.prefetch用法及代码示例
- Python tf.data.experimental.CsvDataset.window用法及代码示例
- Python tf.data.experimental.CsvDataset.apply用法及代码示例
- Python tf.data.experimental.CsvDataset.flat_map用法及代码示例
- Python tf.data.experimental.CsvDataset.random用法及代码示例
- Python tf.data.experimental.CsvDataset.cardinality用法及代码示例
- Python tf.data.experimental.CsvDataset.interleave用法及代码示例
- Python tf.data.experimental.CsvDataset.group_by_window用法及代码示例
- Python tf.data.experimental.CsvDataset.as_numpy_iterator用法及代码示例
- Python tf.data.experimental.CsvDataset.from_generator用法及代码示例
- Python tf.data.experimental.CsvDataset.range用法及代码示例
- Python tf.data.experimental.CsvDataset.unique用法及代码示例
- Python tf.data.experimental.CsvDataset.shard用法及代码示例
- Python tf.data.experimental.CsvDataset.choose_from_datasets用法及代码示例
- Python tf.data.experimental.CsvDataset.batch用法及代码示例
- Python tf.data.experimental.CsvDataset.enumerate用法及代码示例
- Python tf.data.experimental.CsvDataset.from_tensors用法及代码示例
- Python tf.data.experimental.CsvDataset.bucket_by_sequence_length用法及代码示例
- Python tf.data.experimental.CsvDataset.concatenate用法及代码示例
- Python tf.data.experimental.CsvDataset.unbatch用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.experimental.CsvDataset.padded_batch。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。