当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch rand_split_train_val用法及代码示例


本文简要介绍python语言中 torchrec.datasets.utils.rand_split_train_val 的用法。

用法:

torchrec.datasets.utils.rand_split_train_val(datapipe: torch.utils.data.dataset.IterDataPipe, train_perc: float, random_seed: int = 0) → Tuple[torch.utils.data.dataset.IterDataPipe, torch.utils.data.dataset.IterDataPipe]

参数

  • datapipe(IterDataPipe) -要拆分的数据管道。

  • train_perc(浮点数) -范围 (0.0, 1.0) 中的值,指定要包含在训练拆分中的数据管道样本的目标比例。请注意,不能保证实际比例与train_perc 完全匹配。

  • random_seed(int) -确定给定样本和train_perc 的拆分成员资格。在调用中使用相同的值来生成一致的拆分。

通过均匀随机采样,生成两个 IterDataPipe 实例,表示给定 IterDataPipe 的不相交的 train 和 val 分割。

例子:

datapipe = criteo_terabyte(
    ("/home/datasets/criteo/day_0.tsv", "/home/datasets/criteo/day_1.tsv")
)
train_datapipe, val_datapipe = rand_split_train_val(datapipe, 0.75)
train_batch = next(iter(train_datapipe))
val_batch = next(iter(val_datapipe))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torchrec.datasets.utils.rand_split_train_val。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。