當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.data.experimental.RandomDataset.shuffle用法及代碼示例


用法

shuffle(
    buffer_size, seed=None, reshuffle_each_iteration=None, name=None
)

參數

  • buffer_size 一個 tf.int64 標量 tf.Tensor ,表示新數據集將從中采樣的該數據集中的元素數。
  • seed (可選。)tf.int64 標量 tf.Tensor ,表示將用於創建分布的隨機種子。有關行為,請參閱tf.random.set_seed
  • reshuffle_each_iteration (可選。)一個布爾值,如果為真,則表示每次迭代數據集時都應該偽隨機地重新洗牌。 (默認為 True 。)
  • name (可選。) tf.data 操作的名稱。

返回

  • Dataset 一個Dataset

隨機打亂這個數據集的元素。

該數據集用buffer_size 元素填充緩衝區,然後從該緩衝區中隨機采樣元素,用新元素替換所選元素。對於完美的混洗,需要大於或等於數據集完整大小的緩衝區大小。

例如,如果您的數據集包含 10,000 個元素,但 buffer_size 設置為 1,000,則 shuffle 最初將僅從緩衝區中的前 1,000 個元素中選擇一個隨機元素。一旦選擇了一個元素,它在緩衝區中的空間就會被下一個(即第 1,001 個)元素替換,從而保持 1,000 個元素的緩衝區。

reshuffle_each_iteration 控製每個 epoch 的 shuffle 順序是否應該不同。在 TF 1.X 中,創建 epoch 的慣用方式是通過 repeat 轉換:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
dataset = dataset.repeat(2)
# [1, 0, 2, 1, 2, 0]

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
dataset = dataset.repeat(2)
# [1, 0, 2, 1, 0, 2]

在 TF 2.0 中,tf.data.Dataset 對象是 Python 可迭代對象,這使得也可以通過 Python 迭代創建 epoch:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 2, 0]
dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 0, 2]

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.RandomDataset.shuffle。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。