当前位置: 首页>>代码示例>>Python>>正文


Python data.WeightedRandomSampler方法代码示例

本文整理汇总了Python中torch.utils.data.WeightedRandomSampler方法的典型用法代码示例。如果您正苦于以下问题:Python data.WeightedRandomSampler方法的具体用法?Python data.WeightedRandomSampler怎么用?Python data.WeightedRandomSampler使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data的用法示例。


在下文中一共展示了data.WeightedRandomSampler方法的5个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: _test_auto_methods_xla

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def _test_auto_methods_xla(index, ws):

    dl_type = DataLoader
    if ws > 1:

        from ignite.distributed.auto import _MpDeviceLoader

        dl_type = _MpDeviceLoader
        try:
            from torch_xla.distributed.parallel_loader import MpDeviceLoader

            dl_type = MpDeviceLoader
        except ImportError:
            pass

    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, dl_type=dl_type)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10, dl_type=dl_type)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler", dl_type=dl_type)

    device = "xla"
    _test_auto_model_optimizer(ws, device) 
开发者ID:pytorch,项目名称:ignite,代码行数:23,代码来源:test_auto.py

示例2: _test_auto_dataloader

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):

    data = torch.rand(100, 3, 12, 12)

    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )

    assert isinstance(dataloader, dl_type)
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader
    if ws < batch_size:
        assert dataloader.batch_size == batch_size // ws
    else:
        assert dataloader.batch_size == batch_size
    if ws <= num_workers:
        assert dataloader.num_workers == (num_workers + nproc - 1) // nproc
    else:
        assert dataloader.num_workers == num_workers

    if ws < 2:
        sampler_type = RandomSampler if sampler is None else type(sampler)
        assert isinstance(dataloader.sampler, sampler_type)
    else:
        sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
        assert isinstance(dataloader.sampler, sampler_type)
    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type) 
开发者ID:pytorch,项目名称:ignite,代码行数:39,代码来源:test_auto.py

示例3: test_auto_methods_no_dist

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def test_auto_methods_no_dist():

    _test_auto_dataloader(1, 1, batch_size=1)
    _test_auto_dataloader(1, 1, batch_size=10, num_workers=10)
    _test_auto_dataloader(1, 1, batch_size=10, sampler_name="WeightedRandomSampler")

    _test_auto_model_optimizer(1, "cpu") 
开发者ID:pytorch,项目名称:ignite,代码行数:9,代码来源:test_auto.py

示例4: test_auto_methods_gloo

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def test_auto_methods_gloo(distributed_context_single_node_gloo):

    ws = distributed_context_single_node_gloo["world_size"]
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, sampler_name="WeightedRandomSampler")

    _test_auto_model_optimizer(ws, "cpu") 
开发者ID:pytorch,项目名称:ignite,代码行数:10,代码来源:test_auto.py

示例5: test_auto_methods_nccl

# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def test_auto_methods_nccl(distributed_context_single_node_nccl):

    ws = distributed_context_single_node_nccl["world_size"]
    lrank = distributed_context_single_node_nccl["local_rank"]
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler")

    device = "cuda"
    _test_auto_model_optimizer(ws, device) 
开发者ID:pytorch,项目名称:ignite,代码行数:12,代码来源:test_auto.py


注:本文中的torch.utils.data.WeightedRandomSampler方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。