當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.data.experimental.dense_to_ragged_batch用法及代碼示例


將參差不齊的元素批量轉換為 tf.RaggedTensor 的轉換。

用法

tf.data.experimental.dense_to_ragged_batch(
    batch_size, drop_remainder=False, row_splits_dtype=tf.dtypes.int64
)

參數

  • batch_size tf.int64 標量 tf.Tensor ,表示要在單個批次中組合的此數據集的連續元素的數量。
  • drop_remainder (可選。)一個 tf.bool 標量 tf.Tensor ,表示在最後一批少於 batch_size 元素的情況下是否應刪除它;默認行為是不丟棄較小的批次。
  • row_splits_dtype 應該用於任何新不規則張量的 row_splits 的 dtype。現有 tf.RaggedTensor 元素的 row_splits dtype 未更改。

返回

  • Dataset 一個Dataset

此轉換將輸入數據集的多個連續元素組合成一個元素。

tf.data.Dataset.batch 一樣,結果元素的組件將有一個額外的外部維度,即 batch_size (如果 batch_size 不均勻地劃分輸入元素的數量 N ,則為最後一個元素的 N % batch_size 並且drop_remainderFalse )。如果您的程序依賴於具有相同外部尺寸的批次,則應將 drop_remainder 參數設置為 True 以防止生成較小的批次。

tf.data.Dataset.batch 不同,要批處理的輸入元素可能具有不同的形狀:

  • 如果輸入元素是一個 tf.Tensor,其靜態 tf.TensorShape 已完全定義,則它按正常方式進行批處理。
  • 如果輸入元素是 tf.Tensor,其靜態 tf.TensorShape 包含一個或多個尺寸未知的軸(即 shape[i]=None ),則輸出將包含一個 tf.RaggedTensor ,它的大小不一。
  • 如果輸入元素是tf.RaggedTensor 或任何其他類型,則正常批處理。

例子:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
dataset = dataset.map(lambda x:tf.range(x))
dataset.element_spec.shape
TensorShape([None])
dataset = dataset.apply(
    tf.data.experimental.dense_to_ragged_batch(batch_size=2))
for batch in dataset:
  print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.dense_to_ragged_batch。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。