當前位置: 首頁>>代碼示例>>Python>>正文


Python data.WeightedRandomSampler方法代碼示例

本文整理匯總了Python中torch.utils.data.WeightedRandomSampler方法的典型用法代碼示例。如果您正苦於以下問題:Python data.WeightedRandomSampler方法的具體用法?Python data.WeightedRandomSampler怎麽用?Python data.WeightedRandomSampler使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.utils.data的用法示例。


在下文中一共展示了data.WeightedRandomSampler方法的5個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: _test_auto_methods_xla

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 別名]
def _test_auto_methods_xla(index, ws):

    dl_type = DataLoader
    if ws > 1:

        from ignite.distributed.auto import _MpDeviceLoader

        dl_type = _MpDeviceLoader
        try:
            from torch_xla.distributed.parallel_loader import MpDeviceLoader

            dl_type = MpDeviceLoader
        except ImportError:
            pass

    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, dl_type=dl_type)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10, dl_type=dl_type)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler", dl_type=dl_type)

    device = "xla"
    _test_auto_model_optimizer(ws, device) 
開發者ID:pytorch,項目名稱:ignite,代碼行數:23,代碼來源:test_auto.py

示例2: _test_auto_dataloader

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 別名]
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):

    data = torch.rand(100, 3, 12, 12)

    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )

    assert isinstance(dataloader, dl_type)
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader
    if ws < batch_size:
        assert dataloader.batch_size == batch_size // ws
    else:
        assert dataloader.batch_size == batch_size
    if ws <= num_workers:
        assert dataloader.num_workers == (num_workers + nproc - 1) // nproc
    else:
        assert dataloader.num_workers == num_workers

    if ws < 2:
        sampler_type = RandomSampler if sampler is None else type(sampler)
        assert isinstance(dataloader.sampler, sampler_type)
    else:
        sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
        assert isinstance(dataloader.sampler, sampler_type)
    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type) 
開發者ID:pytorch,項目名稱:ignite,代碼行數:39,代碼來源:test_auto.py

示例3: test_auto_methods_no_dist

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 別名]
def test_auto_methods_no_dist():

    _test_auto_dataloader(1, 1, batch_size=1)
    _test_auto_dataloader(1, 1, batch_size=10, num_workers=10)
    _test_auto_dataloader(1, 1, batch_size=10, sampler_name="WeightedRandomSampler")

    _test_auto_model_optimizer(1, "cpu") 
開發者ID:pytorch,項目名稱:ignite,代碼行數:9,代碼來源:test_auto.py

示例4: test_auto_methods_gloo

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 別名]
def test_auto_methods_gloo(distributed_context_single_node_gloo):

    ws = distributed_context_single_node_gloo["world_size"]
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, sampler_name="WeightedRandomSampler")

    _test_auto_model_optimizer(ws, "cpu") 
開發者ID:pytorch,項目名稱:ignite,代碼行數:10,代碼來源:test_auto.py

示例5: test_auto_methods_nccl

# 需要導入模塊: from torch.utils import data [as 別名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 別名]
def test_auto_methods_nccl(distributed_context_single_node_nccl):

    ws = distributed_context_single_node_nccl["world_size"]
    lrank = distributed_context_single_node_nccl["local_rank"]
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler")

    device = "cuda"
    _test_auto_model_optimizer(ws, device) 
開發者ID:pytorch,項目名稱:ignite,代碼行數:12,代碼來源:test_auto.py


注:本文中的torch.utils.data.WeightedRandomSampler方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。