当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.data.TFRecordDataset.rejection_resample用法及代码示例


用法

rejection_resample(
    class_func, target_dist, initial_dist=None, seed=None, name=None
)

参数

  • class_func 将输入数据集的元素映射到标量 tf.int32 张量的函数。值应该在 [0, num_classes) 中。
  • target_dist 浮点型张量,形状为 [num_classes]
  • initial_dist (可选。)浮点型张量,形状为 [num_classes] 。如果未提供,则以流媒体方式实时估计真实的类分布。
  • seed (可选。)重采样器的 Python 整数种子。
  • name (可选。) tf.data 操作的名称。

返回

  • Dataset

将数据集重新采样到目标分布的转换。

让我们考虑以下示例,其中初始数据分布为 init_dist 的数据集需要重新采样到具有 target_dist 分布的数据集。

import collections
initial_dist = [0.5, 0.5]
target_dist = [0.6, 0.4]
num_classes = len(initial_dist)
num_samples = 100000
data_np = np.random.choice(num_classes, num_samples, p=initial_dist)
dataset = tf.data.Dataset.from_tensor_slices(data_np)
x = collections.defaultdict(int)
for i in dataset:
  x[i.numpy()] += 1

根据 initial_dist 分布,x 的值将接近 {0:50000, 1:50000}

dataset = dataset.rejection_resample(
   class_func=lambda x:x % 2,
   target_dist=target_dist,
   initial_dist=initial_dist)
y = collections.defaultdict(int)
for i in dataset:
  cls, _ = i
  y[cls.numpy()] += 1

y 的值现在将接近 {0:75000, 1:50000} 从而满足 target_dist 分布。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.data.TFRecordDataset.rejection_resample。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。