當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.data.experimental.CsvDataset.rejection_resample用法及代碼示例


用法

rejection_resample(
    class_func, target_dist, initial_dist=None, seed=None, name=None
)

參數

  • class_func 將輸入數據集的元素映射到標量 tf.int32 張量的函數。值應該在 [0, num_classes) 中。
  • target_dist 浮點型張量,形狀為 [num_classes]
  • initial_dist (可選。)浮點型張量,形狀為 [num_classes] 。如果未提供,則以流媒體方式實時估計真實的類分布。
  • seed (可選。)重采樣器的 Python 整數種子。
  • name (可選。) tf.data 操作的名稱。

返回

  • Dataset

將數據集重新采樣到目標分布的轉換。

讓我們考慮以下示例,其中初始數據分布為 init_dist 的數據集需要重新采樣到具有 target_dist 分布的數據集。

import collections
initial_dist = [0.5, 0.5]
target_dist = [0.6, 0.4]
num_classes = len(initial_dist)
num_samples = 100000
data_np = np.random.choice(num_classes, num_samples, p=initial_dist)
dataset = tf.data.Dataset.from_tensor_slices(data_np)
x = collections.defaultdict(int)
for i in dataset:
  x[i.numpy()] += 1

根據 initial_dist 分布,x 的值將接近 {0:50000, 1:50000}

dataset = dataset.rejection_resample(
   class_func=lambda x:x % 2,
   target_dist=target_dist,
   initial_dist=initial_dist)
y = collections.defaultdict(int)
for i in dataset:
  cls, _ = i
  y[cls.numpy()] += 1

y 的值現在將接近 {0:75000, 1:50000} 從而滿足 target_dist 分布。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.data.experimental.CsvDataset.rejection_resample。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。