当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.data.experimental.RandomDataset.shuffle用法及代码示例


用法

shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None, name=None
)

参数

  • buffer_size 一个 tf.int64 标量 tf.Tensor ,表示新数据集将从中采样的该数据集中的元素数。
  • seed (可选。)tf.int64 标量 tf.Tensor ,表示将用于创建分布的随机种子。有关行为,请参阅tf.random.set_seed
  • reshuffle_each_iteration (可选。)一个布尔值,如果为真,则表示每次迭代数据集时都应该伪随机地重新洗牌。 (默认为 True 。)
  • name (可选。) tf.data 操作的名称。

返回

  • Dataset 一个Dataset

随机打乱这个数据集的元素。

该数据集用buffer_size 元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。对于完美的混洗,需要大于或等于数据集完整大小的缓冲区大小。

例如,如果您的数据集包含 10,000 个元素,但 buffer_size 设置为 1,000,则 shuffle 最初将仅从缓冲区中的前 1,000 个元素中选择一个随机元素。一旦选择了一个元素,它在缓冲区中的空间就会被下一个(即第 1,001 个)元素替换,从而保持 1,000 个元素的缓冲区。

reshuffle_each_iteration 控制每个 epoch 的 shuffle 顺序是否应该不同。在 TF 1.X 中,创建 epoch 的惯用方式是通过 repeat 转换:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
dataset = dataset.repeat(2)
# [1, 0, 2, 1, 2, 0]

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
dataset = dataset.repeat(2)
# [1, 0, 2, 1, 0, 2]

在 TF 2.0 中,tf.data.Dataset 对象是 Python 可迭代对象,这使得也可以通过 Python 迭代创建 epoch:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 2, 0]
dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 0, 2]

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.data.experimental.RandomDataset.shuffle。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。