本文整理汇总了Python中torch.utils.data.sampler.WeightedRandomSampler类的典型用法代码示例。如果您正苦于以下问题：Python sampler.WeightedRandomSampler类的具体用法？Python sampler.WeightedRandomSampler怎么用？Python sampler.WeightedRandomSampler使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torch.utils.data.sampler的用法示例。
在下文中一共展示了sampler.WeightedRandomSampler类的12个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: train
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def train(self):
    """Run ``self.max_epochs`` epochs of cluster-weighted training.

    For every epoch this calls ``self.embed_all()``, clusters the result with
    ``self.cluster()``, hands the embeddings and labels to
    ``self.each_cluster()``, and stores the labels on ``self.data.labels``.
    It then trains on a loader that draws ``4 * len(self.data)`` indices
    with replacement, weighted by the weights returned from ``self.cluster``.
    ``self.checkpoint()`` is called whenever ``self.step_id`` is a multiple
    of ``self.checkpoint_interval`` (checked before the counter is bumped).

    Returns:
        ``self.net`` after the last epoch.
    """
    for epoch in range(self.max_epochs):
        self.epoch_id = epoch

        all_embeddings = self.embed_all()
        sample_weights, cluster_labels, cluster_centers = self.cluster(all_embeddings)
        self.each_cluster(all_embeddings, cluster_labels)
        self.data.labels = cluster_labels

        # Drop the previous epoch's loader before building a new one
        # (presumably so its resources can be released first).
        self.train_data = None
        draws_per_epoch = len(self.data) * 4
        weighted_sampler = WeightedRandomSampler(sample_weights, draws_per_epoch, replacement=True)
        self.train_data = DataLoader(
            self.data,
            batch_size=self.batch_size,
            num_workers=8,
            sampler=weighted_sampler,
        )

        for batch, target in self.train_data:
            self.step(batch, target, cluster_centers)
            if self.step_id % self.checkpoint_interval == 0:
                self.checkpoint()
            self.step_id += 1

    return self.net
示例2: setup_sampler
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def setup_sampler(sampler_type, num_iters, batch_size):
    """Create a sampler of the requested kind plus the batch size to use with it.

    Args:
        sampler_type: ``None``, ``"weighted"`` or ``"distributed"``.
        num_iters: number of batches the sampler should cover.
        batch_size: batch size before any per-replica split.

    Returns:
        A ``(sampler, batch_size)`` tuple:

        * ``None`` -> ``(None, batch_size)``.
        * ``"weighted"`` -> a ``WeightedRandomSampler`` drawing
          ``num_iters * batch_size`` indices with replacement. Every index in
          the i-th block of ``batch_size`` consecutive indices has weight
          ``i + 1``. The batch size is returned unchanged.
        * ``"distributed"`` -> a ``DistributedSampler`` over a dummy dataset of
          ``num_iters * batch_size`` items. It uses the world size and rank of
          the initialized process group when there is one (otherwise one
          replica, rank 0). The batch size is divided by the world size.

    Raises:
        ValueError: for any other ``sampler_type``, instead of silently
            returning ``None`` (which unpacking callers cannot handle).
    """
    if sampler_type is None:
        return None, batch_size

    if sampler_type == "weighted":
        from torch.utils.data.sampler import WeightedRandomSampler

        # Block i (batch_size consecutive indices) gets weight i + 1.
        w = torch.arange(1, num_iters + 1, dtype=torch.float).repeat_interleave(batch_size)
        return WeightedRandomSampler(w, num_samples=num_iters * batch_size, replacement=True), batch_size

    if sampler_type == "distributed":
        from torch.utils.data.distributed import DistributedSampler
        import torch.distributed as dist

        num_replicas = 1
        rank = 0
        if dist.is_available() and dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()

        dataset = torch.zeros(num_iters * batch_size)
        return DistributedSampler(dataset, num_replicas=num_replicas, rank=rank), batch_size // num_replicas

    raise ValueError(
        "Unknown sampler_type: {!r} (expected None, 'weighted' or 'distributed')".format(sampler_type)
    )
示例3: test_dist_proxy_sampler
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_dist_proxy_sampler():
    """Every index drawn by a WeightedRandomSampler must show up on some rank.

    Wraps one sampler in ``DistributedProxySampler`` for each of 4 ranks and
    checks that the union of the per-rank indices equals the set of indices
    the wrapped sampler yields under ``torch.manual_seed(0)``.
    """
    import torch
    from torch.utils.data import WeightedRandomSampler

    # Indices 0..49 get weight 2, indices 50..99 keep weight 1.
    weights = torch.ones(100)
    weights[:50] += 1
    base_sampler = WeightedRandomSampler(weights, 100)

    world_size = 4
    proxies = [
        DistributedProxySampler(base_sampler, num_replicas=world_size, rank=r)
        for r in range(world_size)
    ]

    torch.manual_seed(0)
    expected = list(base_sampler)

    # NOTE(review): this comparison assumes DistributedProxySampler reseeds the
    # RNG from the epoch set below (0), matching manual_seed(0) above — confirm.
    gathered = []
    for proxy in proxies:
        proxy.set_epoch(0)
        gathered.extend(proxy)

    assert set(gathered) == set(expected)
示例4: sampler
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def sampler(self, examples_per_epoch=None):
    """Return a uniform ``WeightedRandomSampler`` over ``len(self)`` items.

    Args:
        examples_per_epoch: number of indices to draw; defaults to
            ``len(self)``.

    Every item has weight 1. Indices are drawn with replacement only when
    more are requested than there are items. Otherwise the draws are
    distinct.
    """
    n_items = len(self)
    n_draws = n_items if examples_per_epoch is None else examples_per_epoch
    uniform = torch.ones(n_items, dtype=torch.double)
    return WeightedRandomSampler(uniform, n_draws, replacement=n_draws > n_items)
示例5: _test_auto_dataloader
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):
    """Check how ``auto_dataloader`` adapts its loader to world size ``ws``.

    Builds a loader over random data, optionally with a
    ``WeightedRandomSampler``, and asserts the resulting loader type, batch
    size, worker count, sampler type and (for plain ``DataLoader``s)
    ``pin_memory`` setting.

    Raises:
        RuntimeError: if ``sampler_name`` is neither ``None`` nor
            ``"WeightedRandomSampler"``.
    """
    data = torch.rand(100, 3, 12, 12)

    if sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    elif sampler_name is None:
        sampler = None
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    assert idist.get_world_size() == ws

    loader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )
    assert isinstance(loader, dl_type)
    # Some loader wrappers keep the real DataLoader on `_loader`.
    loader = getattr(loader, "_loader", loader)

    expected_batch_size = batch_size // ws if ws < batch_size else batch_size
    assert loader.batch_size == expected_batch_size

    expected_workers = (num_workers + nproc - 1) // nproc if ws <= num_workers else num_workers
    assert loader.num_workers == expected_workers

    if ws < 2:
        expected_sampler_type = RandomSampler if sampler is None else type(sampler)
    else:
        expected_sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
    assert isinstance(loader.sampler, expected_sampler_type)

    if isinstance(loader, DataLoader):
        assert loader.pin_memory == ("cuda" in idist.device().type)
示例6: test_auto_methods_no_dist
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_auto_methods_no_dist():
    """Run the auto_* helper checks with world size 1 and one process."""
    loader_cases = (
        dict(batch_size=1),
        dict(batch_size=10, num_workers=10),
        dict(batch_size=10, sampler_name="WeightedRandomSampler"),
    )
    for case in loader_cases:
        _test_auto_dataloader(1, 1, **case)
    _test_auto_model_optimizer(1, "cpu")
示例7: test_auto_methods_gloo
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_auto_methods_gloo(distributed_context_single_node_gloo):
    """Run the auto_* helper checks with the ``distributed_context_single_node_gloo`` fixture on CPU."""
    world_size = distributed_context_single_node_gloo["world_size"]
    loader_cases = (
        dict(batch_size=1),
        dict(batch_size=10, num_workers=10),
        dict(batch_size=10, sampler_name="WeightedRandomSampler"),
    )
    for case in loader_cases:
        _test_auto_dataloader(ws=world_size, nproc=world_size, **case)
    _test_auto_model_optimizer(world_size, "cpu")
示例8: test_auto_methods_nccl
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_auto_methods_nccl(distributed_context_single_node_nccl):
    """Run the auto_* helper checks with the ``distributed_context_single_node_nccl`` fixture on CUDA.

    The fixture's world size is used both as ``ws`` and ``nproc``.
    """
    ws = distributed_context_single_node_nccl["world_size"]
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler")
    device = "cuda"
    _test_auto_model_optimizer(ws, device)
示例9: fit
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def fit(self, X, y=None):
    """Train a fresh ProdLDA autoencoder on ``X``.

    Args:
        X: two-dimensional matrix of shape ``(documents, features)`` that
            supports ``.astype``. It is cast to ``np.float32`` and wrapped in a
            ``CountTensorDataset`` (presumably a document-term count matrix;
            confirm against callers).
        y: ignored; accepted for scikit-learn API compatibility.

    Returns:
        ``self``, following the scikit-learn estimator convention so that
        ``est.fit(X).transform(X)`` and ``fit_transform`` chaining work.
    """
    documents, features = X.shape
    ds = CountTensorDataset(X.astype(np.float32))
    self.autoencoder = ProdLDA(
        in_dimension=features,
        hidden1_dimension=self.hidden1_dimension,
        hidden2_dimension=self.hidden2_dimension,
        topics=self.topics,
    )
    if self.cuda:
        self.autoencoder.cuda()
    ae_optimizer = Adam(
        self.autoencoder.parameters(), lr=self.lr, betas=(0.99, 0.999)
    )
    # Uniform weights: draw min(documents, self.samples) document indices,
    # with replacement (WeightedRandomSampler's default).
    doc_sampler = WeightedRandomSampler(
        torch.ones(documents), min(documents, self.samples)
    )
    train(
        ds,
        self.autoencoder,
        cuda=self.cuda,
        validation=None,
        epochs=self.epochs,
        batch_size=self.batch_size,
        optimizer=ae_optimizer,
        sampler=doc_sampler,
        silent=True,
        num_workers=0,  # TODO causes a bug to change this on Mac
    )
    return self
示例10: __init__
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def __init__(self, data_source):
    """Build a sampler that balances draws across ``'language'`` labels.

    Each item gets weight ``total / count(label)``, so every distinct label
    carries the same total weight however many items share it.
    ``len(data_source)`` indices are drawn (with replacement, the
    ``WeightedRandomSampler`` default), and the sampler is stored on
    ``self._sampler``.

    Args:
        data_source: object supporting ``len()`` whose ``items[idx]`` is a
            mapping with a ``'language'`` key for each ``idx < len(data_source)``.
    """
    from collections import Counter

    # Read each item's label once, then count and weight from that list.
    labels = [data_source.items[idx]['language'] for idx in range(len(data_source))]
    label_freq = Counter(labels)
    total = float(len(labels))
    weights = [total / label_freq[label] for label in labels]
    self._sampler = WeightedRandomSampler(weights, len(weights))
示例11: get_train_data_source
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def get_train_data_source(ds_metainfo,
                          batch_size,
                          num_workers):
    """
    Get data source for training subset.

    Unless ``ds_metainfo.train_use_weighted_sampler`` is truthy, the loader
    shuffles the dataset. When it is truthy, the loader instead draws
    ``len(dataset)`` indices with a ``WeightedRandomSampler`` driven by
    ``dataset.sample_weights``.

    Parameters
    ----------
    ds_metainfo : DatasetMetaInfo
        Dataset metainfo.
    batch_size : int
        Batch size.
    num_workers : int
        Number of background workers.

    Returns
    -------
    DataLoader
        Data source.
    """
    train_transform = ds_metainfo.train_transform(ds_metainfo=ds_metainfo)
    extra_kwargs = ds_metainfo.dataset_class_extra_kwargs
    if extra_kwargs is None:
        extra_kwargs = {}
    dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="train",
        transform=train_transform,
        **extra_kwargs)
    ds_metainfo.update_from_dataset(dataset)

    # `shuffle` and `sampler` are mutually exclusive in DataLoader, so pass exactly one.
    if ds_metainfo.train_use_weighted_sampler:
        ordering = dict(sampler=WeightedRandomSampler(
            weights=dataset.sample_weights,
            num_samples=len(dataset)))
    else:
        ordering = dict(shuffle=True)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        **ordering)
示例12: train_valid
# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def train_valid(data_root, name, img_enc, pc_enc, pc_dec, optimizer, scheduler, adain=True, projection=True,
                decimation=None, color_img=False, n_points=250, bs=4, lr=5e-5, weight_decay=1e-5, gamma=.3,
                milestones=(5, 8), n_epochs=10, print_freq=1000, val_freq=10000, checkpoint_folder=None):
    """Train a ``PointcloudDeformNet`` on ShapeNet with periodic validation.

    A fresh run is started when ``checkpoint_folder`` is None. Otherwise
    model, optimizer and (if present) scheduler state are restored from
    ``training.pt`` in that folder and training resumes.

    Args:
        data_root: ShapeNet root directory. The number of entries it contains
            also sets the validation size (``10 * len(os.listdir(data_root))``).
        name: first positional argument of ``nnt.Monitor`` on a fresh run
            (presumably the run name).
        img_enc, pc_enc, pc_dec: forwarded to ``PointcloudDeformNet``.
            ``pc_dec`` is partially applied with ``decimation`` when that is
            not None.
        optimizer: called as ``optimizer(params, lr, weight_decay=weight_decay)``.
        scheduler: called as ``scheduler(opt, milestones=milestones, gamma=gamma)``.
        bs: batch size for both loaders (incomplete last batches are dropped).
        n_epochs, print_freq, val_freq: forwarded to ``nnt.Monitor`` /
            ``mon.run_training``.
    """
    if decimation is not None:
        pc_dec = partial(pc_dec, decimation=decimation)

    net = PointcloudDeformNet((bs,) + (3 if color_img else 1, 224, 224), (bs, n_points, 3), img_enc, pc_enc, pc_dec,
                              adain=adain, projection=projection,
                              optimizer=lambda x: optimizer(x, lr, weight_decay=weight_decay),
                              scheduler=lambda x: scheduler(x, milestones=milestones, gamma=gamma),
                              weight_decay=None)
    print(net)

    train_data = ShapeNet(path=data_root, grayscale=not color_img, type='train', n_points=n_points)
    # len(train_data) weighted draws per epoch, with replacement.
    sampler = WeightedRandomSampler(train_data.sample_weights, len(train_data), True)
    train_loader = DataLoader(train_data, batch_size=bs, num_workers=1, collate_fn=collate, drop_last=True,
                              sampler=sampler)
    val_data = ShapeNet(path=data_root, grayscale=not color_img, type='valid', num_vals=10 * len(os.listdir(data_root)),
                        n_points=n_points)
    val_loader = DataLoader(val_data, batch_size=bs, shuffle=False, num_workers=1, collate_fn=collate, drop_last=True)

    # With num_samples == len(train_data) and drop_last=True, each epoch yields exactly this many batches.
    iters_per_epoch = len(train_data) // bs

    if checkpoint_folder is None:
        mon = nnt.Monitor(name, print_freq=print_freq, num_iters=iters_per_epoch, use_tensorboard=True)
        mon.copy_files(backup_files)
        mon.dump_rep('network', net)
        mon.dump_rep('optimizer', net.optim['optimizer'])
        if net.optim['scheduler']:
            mon.dump_rep('scheduler', net.optim['scheduler'])

        states = {
            'model_state_dict': net.state_dict(),
            'opt_state_dict': net.optim['optimizer'].state_dict()
        }
        if net.optim['scheduler']:
            states['scheduler_state_dict'] = net.optim['scheduler'].state_dict()

        # Register a scheduled dump of `states` to training.pt (see nnt.Monitor.schedule for timing/`keep`).
        mon.schedule(mon.dump, beginning=False, name='training.pt', obj=states, type='torch', keep=5)
        print('Training...')
    else:
        mon = nnt.Monitor(current_folder=checkpoint_folder, print_freq=print_freq, num_iters=iters_per_epoch,
                          use_tensorboard=True)
        states = mon.load('training.pt', type='torch')
        # Bug fix: `epoch * len(train_data) // bs` parses as (epoch * N) // bs, which overshoots
        # epoch * (N // bs) whenever bs does not divide N; use the per-epoch batch count.
        mon.set_iter(mon.get_epoch() * iters_per_epoch)
        net.load_state_dict(states['model_state_dict'])
        net.optim['optimizer'].load_state_dict(states['opt_state_dict'])
        if net.optim['scheduler']:
            net.optim['scheduler'].load_state_dict(states['scheduler_state_dict'])
        # NOTE(review): the training.pt dump is only scheduled on a fresh run; a resumed run does not
        # call mon.schedule(mon.dump, ...) again — confirm whether nnt.Monitor restores it or this is an omission.
        print('Resume from epoch %d...' % mon.get_epoch())

    mon.run_training(net, train_loader, n_epochs, val_loader, valid_freq=val_freq, reduce='mean')
    print('Training finished!')