当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch CachedDataset用法及代码示例


本文简要介绍python语言中 torch_xla.utils.cached_dataset.CachedDataset 的用法。

用法:

class torch_xla.utils.cached_dataset.CachedDataset(data_set, path, max_files_per_folder=1000, compress=True)

参数

  • data_set(torch.utils.data.Dataset) -要缓存的原始torch.utils.data.Dataset。如果所有输入样本都存储在 path 文件夹中,则可以将其设置为 None

  • path(string) -应存储/加载数据集样本的路径。 path 需要可写,除非所有样本都已存储。 path 可以是 GCS 路径(以 gs:// 为前缀)。

  • max_files_per_folder(Python:int) -单个文件夹中可存储的最大文件数量。如果data_setNone,则该值将被忽略并从缓存的元数据中获取。默认值:1000

  • compress(bool) -是否应压缩保存的样本。压缩可以节省空间,但会占用压缩/解压缩所需的 CPU 资源。如果data_setNone,则该值将被忽略并从缓存的元数据中获取。默认值:真

通过提供文件缓存来包装现有的torch.utils.data.Dataset

CachedDataset 可用于将处理原始数据集所需的 CPU/RAM 资源与存储/网络资源进行交易。例子:

train_dataset = datasets.MNIST(
    FLAGS.datadir,
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
train_dataset = CachedDataset(train_dataset, FLAGS.dscache_dir)

CachedDataset 将透明地缓存原始 Dataset 样本,以便第一次运行后的每次运行都不会触发与原始样本处理相关的任何更多 CPU/RAM 使用。一旦CachedDataset被完全缓存,它就可以被导出(即tar.gz)并在不同的机器中使用。只需解压 tar.gz 并将 None 作为原始 Dataset 传递:示例:

train_dataset = CachedDataset(None, FLAGS.dscache_dir)

要完全缓存CachedDataset,只需运行warmup() API。保存在 GCS 上的 CachedDataset 的优点是无需显式导出即可在不同机器上使用。

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch_xla.utils.cached_dataset.CachedDataset。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。