當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch CachedDataset用法及代碼示例


本文簡要介紹python語言中 torch_xla.utils.cached_dataset.CachedDataset 的用法。

用法:

class torch_xla.utils.cached_dataset.CachedDataset(data_set, path, max_files_per_folder=1000, compress=True)

參數

  • data_set(torch.utils.data.Dataset) -要緩存的原始torch.utils.data.Dataset。如果所有輸入樣本都存儲在 path 文件夾中,則可以將其設置為 None

  • path(string) -應存儲/加載數據集樣本的路徑。 path 需要可寫,除非所有樣本都已存儲。 path 可以是 GCS 路徑(以 gs:// 為前綴)。

  • max_files_per_folder(Python:int) -單個文件夾中可存儲的最大文件數量。如果data_setNone,則該值將被忽略並從緩存的元數據中獲取。默認值:1000

  • compress(bool) -是否應壓縮保存的樣本。壓縮可以節省空間,但會占用壓縮/解壓縮所需的 CPU 資源。如果data_setNone,則該值將被忽略並從緩存的元數據中獲取。默認值:真

通過提供文件緩存來包裝現有的torch.utils.data.Dataset

CachedDataset 可用於將處理原始數據集所需的 CPU/RAM 資源與存儲/網絡資源進行交易。例子:

train_dataset = datasets.MNIST(
    FLAGS.datadir,
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
train_dataset = CachedDataset(train_dataset, FLAGS.dscache_dir)

CachedDataset 將透明地緩存原始 Dataset 樣本,以便第一次運行後的每次運行都不會觸發與原始樣本處理相關的任何更多 CPU/RAM 使用。一旦CachedDataset被完全緩存,它就可以被導出(即tar.gz)並在不同的機器中使用。隻需解壓 tar.gz 並將 None 作為原始 Dataset 傳遞:示例:

train_dataset = CachedDataset(None, FLAGS.dscache_dir)

要完全緩存CachedDataset,隻需運行warmup() API。保存在 GCS 上的 CachedDataset 的優點是無需顯式導出即可在不同機器上使用。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch_xla.utils.cached_dataset.CachedDataset。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。