本文整理汇总了Python中torch.utils.data.WeightedRandomSampler类的典型用法代码示例。如果您正苦于以下问题：Python data.WeightedRandomSampler的具体用法？Python data.WeightedRandomSampler怎么用？Python data.WeightedRandomSampler使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch.utils.data的用法示例。
在下文中一共展示了data.WeightedRandomSampler类的5个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: _test_auto_methods_xla
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def _test_auto_methods_xla(index, ws):
    """Run the shared auto_* checks on an XLA device.

    Args:
        index: accepted to fit the per-process callback signature (presumably the
            process index supplied by the spawner); not used here.
        ws: expected world size.

    With a single process the returned loader is expected to be a plain
    ``DataLoader``. With several processes it is expected to be ignite's
    ``_MpDeviceLoader``, or torch_xla's ``MpDeviceLoader`` when that can be imported.
    """
    loader_cls = DataLoader
    if ws > 1:
        from ignite.distributed.auto import _MpDeviceLoader as loader_cls

        try:
            from torch_xla.distributed.parallel_loader import MpDeviceLoader as loader_cls
        except ImportError:
            # Name is rebound only on a successful import, so ignite's wrapper stays the fallback.
            pass

    shared = {"ws": ws, "nproc": ws, "dl_type": loader_cls}
    _test_auto_dataloader(batch_size=1, **shared)
    _test_auto_dataloader(batch_size=10, num_workers=10, **shared)
    _test_auto_dataloader(batch_size=1, sampler_name="WeightedRandomSampler", **shared)

    _test_auto_model_optimizer(ws, "xla")
示例2: _test_auto_dataloader
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):
    """Check that ``auto_dataloader`` adapts its arguments to the current world size.

    Builds a random tensor of 100 samples, passes it through ``auto_dataloader`` and
    asserts the returned loader's type, batch size, worker count, sampler type and
    pin-memory flag.

    Args:
        ws: expected world size; must equal ``idist.get_world_size()``.
        nproc: processes per node, used to compute the expected ``num_workers``.
        batch_size: total batch size handed to ``auto_dataloader``.
        num_workers: total worker count handed to ``auto_dataloader``.
        sampler_name: ``None`` for no explicit sampler, or ``"WeightedRandomSampler"``.
        dl_type: expected type of the object ``auto_dataloader`` returns.

    Raises:
        RuntimeError: if ``sampler_name`` is not recognised.
        AssertionError: if any expectation about the loader does not hold.
    """
    # Named ``dataset`` so it does not shadow the ``torch.utils.data`` module alias.
    dataset = torch.rand(100, 3, 12, 12)
    if sampler_name is None:
        sampler = None
    elif sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    dataloader = auto_dataloader(
        dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )
    assert isinstance(dataloader, dl_type)
    # If the result exposes ``_loader`` (presumably a device-loader wrapper), check the inner loader.
    if hasattr(dataloader, "_loader"):
        dataloader = dataloader._loader

    # Batch size is divided by the world size when it is larger than OR EQUAL to it.
    # The boundary is inclusive, consistent with the num_workers check below.
    if ws <= batch_size:
        assert dataloader.batch_size == batch_size // ws
    else:
        assert dataloader.batch_size == batch_size

    # NOTE(review): the condition compares against ``ws`` while the formula divides by
    # ``nproc``; every caller in this file passes nproc == ws, so the two coincide.
    if ws <= num_workers:
        assert dataloader.num_workers == (num_workers + nproc - 1) // nproc
    else:
        assert dataloader.num_workers == num_workers

    if ws < 2:
        # Non-distributed: shuffle=True yields a RandomSampler; an explicit sampler is kept as-is.
        sampler_type = RandomSampler if sampler is None else type(sampler)
    else:
        # Distributed: an explicit sampler is expected to be wrapped in DistributedProxySampler.
        sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
    assert isinstance(dataloader.sampler, sampler_type)

    # pin_memory is expected to be on exactly when the current device is CUDA.
    if isinstance(dataloader, DataLoader):
        assert dataloader.pin_memory == ("cuda" in idist.device().type)
示例3: test_auto_methods_no_dist
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def test_auto_methods_no_dist():
    """Run the shared auto_* checks with a world size of one on CPU."""
    world_size = 1
    scenarios = (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 10, "sampler_name": "WeightedRandomSampler"},
    )
    for extra in scenarios:
        _test_auto_dataloader(world_size, world_size, **extra)
    _test_auto_model_optimizer(world_size, "cpu")
示例4: test_auto_methods_gloo
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def test_auto_methods_gloo(distributed_context_single_node_gloo):
    """Run the shared auto_* checks under a single-node gloo context on CPU.

    Args:
        distributed_context_single_node_gloo: presumably a pytest fixture; only its
            ``"world_size"`` entry is read.
    """
    world_size = distributed_context_single_node_gloo["world_size"]
    for extra in (
        dict(batch_size=1),
        dict(batch_size=10, num_workers=10),
        dict(batch_size=10, sampler_name="WeightedRandomSampler"),
    ):
        _test_auto_dataloader(ws=world_size, nproc=world_size, **extra)
    _test_auto_model_optimizer(world_size, "cpu")
示例5: test_auto_methods_nccl
# 需要导入模块: from torch.utils import data [as 别名]
# 或者: from torch.utils.data import WeightedRandomSampler [as 别名]
def test_auto_methods_nccl(distributed_context_single_node_nccl):
    """Run the shared auto_* checks under a single-node NCCL context on CUDA.

    Args:
        distributed_context_single_node_nccl: presumably a pytest fixture; only its
            ``"world_size"`` entry is used (the unused ``"local_rank"`` lookup was dropped).
    """
    ws = distributed_context_single_node_nccl["world_size"]
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler")
    _test_auto_model_optimizer(ws, "cuda")