当前位置: 首页>>代码示例>>Python>>正文


Python sampler.WeightedRandomSampler方法代码示例

本文整理汇总了Python中torch.utils.data.sampler.WeightedRandomSampler方法的典型用法代码示例。如果您正苦于以下问题:Python sampler.WeightedRandomSampler方法的具体用法?Python sampler.WeightedRandomSampler怎么用?Python sampler.WeightedRandomSampler使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data.sampler的用法示例。


在下文中一共展示了sampler.WeightedRandomSampler方法的12个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: train

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def train(self):
    """Run the embed / cluster / train cycle for ``self.max_epochs`` epochs.

    Every epoch:
      1. embeds the whole dataset with ``self.embed_all()``;
      2. clusters the embedding with ``self.cluster`` (which yields sampling
         weights, labels and cluster centers) and hands the result to
         ``self.each_cluster``;
      3. stores the new labels on ``self.data`` and builds a fresh DataLoader
         that draws ``4 * len(self.data)`` indices with replacement according
         to those weights;
      4. runs ``self.step`` on each batch, checkpointing whenever
         ``self.step_id`` is a multiple of ``self.checkpoint_interval``.

    Returns:
        ``self.net`` once all epochs are done.
    """
    for current_epoch in range(self.max_epochs):
        self.epoch_id = current_epoch

        embedding = self.embed_all()
        weights, labels, centers = self.cluster(embedding)
        self.each_cluster(embedding, labels)
        self.data.labels = labels

        # Drop the previous epoch's loader before building the next one.
        self.train_data = None
        self.train_data = DataLoader(
            self.data,
            batch_size=self.batch_size,
            num_workers=8,
            sampler=WeightedRandomSampler(weights, len(self.data) * 4, replacement=True),
        )

        for batch, target in self.train_data:
            self.step(batch, target, centers)
            if not self.step_id % self.checkpoint_interval:
                self.checkpoint()
            self.step_id += 1

    return self.net
开发者ID:mjendrusch,项目名称:torchsupport,代码行数:23,代码来源:clustering.py

示例2: setup_sampler

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def setup_sampler(sampler_type, num_iters, batch_size):
    """Build a test sampler and return it together with the batch size to use.

    Args:
        sampler_type: ``None``, ``"weighted"`` or ``"distributed"``.
        num_iters: number of iterations; the sampled index space has
            ``num_iters * batch_size`` items.
        batch_size: global batch size.

    Returns:
        A ``(sampler, batch_size)`` tuple:
          * ``None`` -> ``(None, batch_size)``;
          * ``"weighted"`` -> a ``WeightedRandomSampler`` drawing
            ``num_iters * batch_size`` indices with replacement, where items in
            the i-th batch-sized chunk have weight ``i + 1``;
          * ``"distributed"`` -> a ``DistributedSampler`` over a dummy dataset
            of ``num_iters * batch_size`` items, and ``batch_size`` divided by
            the world size (world size 1 / rank 0 when torch.distributed is
            unavailable or not initialized).

    Raises:
        ValueError: for any other ``sampler_type``. Previously the function
            silently returned ``None`` here, which made callers that unpack
            the tuple fail with an unrelated ``TypeError``.
    """
    if sampler_type is None:
        return None, batch_size

    if sampler_type == "weighted":
        from torch.utils.data.sampler import WeightedRandomSampler

        # Weight i + 1 for every item of the i-th chunk of ``batch_size`` items,
        # so later chunks are sampled proportionally more often.
        w = torch.arange(1, num_iters + 1, dtype=torch.float).repeat_interleave(batch_size)
        return WeightedRandomSampler(w, num_samples=num_iters * batch_size, replacement=True), batch_size

    if sampler_type == "distributed":
        from torch.utils.data.distributed import DistributedSampler
        import torch.distributed as dist

        num_replicas = 1
        rank = 0
        if dist.is_available() and dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()

        dataset = torch.zeros(num_iters * batch_size)
        # NOTE(review): this is 0 when batch_size < num_replicas; confirm callers never hit that.
        return DistributedSampler(dataset, num_replicas=num_replicas, rank=rank), batch_size // num_replicas

    raise ValueError("Unknown sampler_type: {!r}".format(sampler_type))
开发者ID:pytorch,项目名称:ignite,代码行数:26,代码来源:__init__.py

示例3: test_dist_proxy_sampler

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_dist_proxy_sampler():
    """Check that DistributedProxySampler shards together cover the wrapped sampler.

    A WeightedRandomSampler over 100 items (first 50 weighted 2, the rest 1)
    is wrapped by four DistributedProxySampler replicas. After
    ``torch.manual_seed(0)`` the wrapped sampler is iterated once. Then each
    replica is set to epoch 0 and iterated. The union of the indices yielded
    by all replicas must equal the set of indices drawn by the wrapped sampler.
    """
    import torch
    from torch.utils.data import WeightedRandomSampler

    weights = torch.ones(100)
    weights[:50] += 1
    base = WeightedRandomSampler(weights, 100)

    world_size = 4
    shards = [DistributedProxySampler(base, num_replicas=world_size, rank=r) for r in range(world_size)]

    torch.manual_seed(0)
    expected = set(base)

    collected = []
    for shard in shards:
        shard.set_epoch(0)
        collected.extend(shard)

    assert set(collected) == expected
开发者ID:pytorch,项目名称:ignite,代码行数:23,代码来源:test_auto.py

示例4: sampler

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def sampler(self, examples_per_epoch=None):
    """Return a uniform ``WeightedRandomSampler`` over the indices of ``self``.

    Args:
        examples_per_epoch: number of indices to draw per epoch; defaults to
            ``len(self)``.

    Returns:
        A ``WeightedRandomSampler`` with equal float64 weights. It samples with
        replacement only when more examples are requested than there are
        items. Otherwise each index is drawn at most once.
    """
    n_items = len(self)
    n_draws = n_items if examples_per_epoch is None else examples_per_epoch
    uniform = torch.ones(n_items, dtype=torch.float64)
    # Replacement is only needed when asking for more draws than items exist.
    return WeightedRandomSampler(uniform, n_draws, replacement=n_draws > n_items)
开发者ID:anibali,项目名称:margipose,代码行数:15,代码来源:__init__.py

示例5: _test_auto_dataloader

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def _test_auto_dataloader(ws, nproc, batch_size, num_workers=1, sampler_name=None, dl_type=DataLoader):
    """Check that ``auto_dataloader`` adapts a DataLoader to the distributed setup.

    Args:
        ws: expected world size (asserted against ``idist.get_world_size()``).
        nproc: number of processes per node, used to split ``num_workers``.
        batch_size: global batch size passed to ``auto_dataloader``.
        num_workers: global number of workers passed to ``auto_dataloader``.
        sampler_name: ``None`` or ``"WeightedRandomSampler"``.
        dl_type: expected type of the returned loader.

    Raises:
        RuntimeError: for an unsupported ``sampler_name``.
    """
    data = torch.rand(100, 3, 12, 12)

    if sampler_name == "WeightedRandomSampler":
        sampler = WeightedRandomSampler(weights=torch.ones(100), num_samples=100)
    elif sampler_name is None:
        sampler = None
    else:
        raise RuntimeError("Unknown sampler name: {}".format(sampler_name))

    # Test auto_dataloader
    assert idist.get_world_size() == ws
    loader = auto_dataloader(
        data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, shuffle=sampler is None
    )
    assert isinstance(loader, dl_type)

    # Some wrappers keep the real DataLoader in ``_loader``; inspect that one.
    inner = getattr(loader, "_loader", loader)

    # The batch size is split across ranks only when there are more samples than ranks.
    expected_batch_size = batch_size // ws if ws < batch_size else batch_size
    assert inner.batch_size == expected_batch_size

    # Workers are split per process (rounded up) when there are at least ``ws`` of them.
    expected_workers = (num_workers + nproc - 1) // nproc if ws <= num_workers else num_workers
    assert inner.num_workers == expected_workers

    if ws < 2:
        expected_sampler = RandomSampler if sampler is None else type(sampler)
    else:
        expected_sampler = DistributedSampler if sampler is None else DistributedProxySampler
    assert isinstance(inner.sampler, expected_sampler)

    if isinstance(inner, DataLoader):
        assert inner.pin_memory == ("cuda" in idist.device().type)
开发者ID:pytorch,项目名称:ignite,代码行数:39,代码来源:test_auto.py

示例6: test_auto_methods_no_dist

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_auto_methods_no_dist():
    """Run the auto_* helper checks with world size 1 and a single process."""
    dataloader_cases = (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 10, "sampler_name": "WeightedRandomSampler"},
    )
    for case in dataloader_cases:
        _test_auto_dataloader(1, 1, **case)

    _test_auto_model_optimizer(1, "cpu")
开发者ID:pytorch,项目名称:ignite,代码行数:9,代码来源:test_auto.py

示例7: test_auto_methods_gloo

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_auto_methods_gloo(distributed_context_single_node_gloo):
    """Run the auto_* helper checks on a single-node gloo distributed context.

    Only the ``"world_size"`` entry of the context is read. It is used both
    as the world size and as the number of processes on the node.
    """
    world_size = distributed_context_single_node_gloo["world_size"]

    dataloader_cases = (
        {"batch_size": 1},
        {"batch_size": 10, "num_workers": 10},
        {"batch_size": 10, "sampler_name": "WeightedRandomSampler"},
    )
    for case in dataloader_cases:
        _test_auto_dataloader(ws=world_size, nproc=world_size, **case)

    _test_auto_model_optimizer(world_size, "cpu")
开发者ID:pytorch,项目名称:ignite,代码行数:10,代码来源:test_auto.py

示例8: test_auto_methods_nccl

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def test_auto_methods_nccl(distributed_context_single_node_nccl):
    """Run the auto_* helper checks on a single-node NCCL distributed context.

    Only the ``"world_size"`` entry of the context is read. It is used both
    as the world size and as the number of processes on the node. The model
    and optimizer check runs on ``"cuda"``.
    """
    ws = distributed_context_single_node_nccl["world_size"]

    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1)
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=10, num_workers=10)
    # NOTE(review): the no-dist/gloo variants use batch_size=10 for the
    # weighted-sampler case; confirm batch_size=1 here is intentional.
    _test_auto_dataloader(ws=ws, nproc=ws, batch_size=1, sampler_name="WeightedRandomSampler")

    _test_auto_model_optimizer(ws, "cuda")
开发者ID:pytorch,项目名称:ignite,代码行数:12,代码来源:test_auto.py

示例9: fit

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def fit(self, X, y=None):
    """Fit a ProdLDA autoencoder to ``X``.

    Args:
        X: 2-D array-like unpacked as ``documents, features = X.shape``. It
            is cast with ``X.astype(np.float32)`` and wrapped in a
            ``CountTensorDataset``.
        y: unused; kept for the scikit-learn ``fit(X, y=None)`` signature.

    Returns:
        ``self``. Per the scikit-learn estimator contract, ``fit`` returns the
        estimator so calls can be chained and generic ``fit_transform``
        implementations work. The previous version returned ``None``.
    """
    documents, features = X.shape
    ds = CountTensorDataset(X.astype(np.float32))
    self.autoencoder = ProdLDA(
        in_dimension=features,
        hidden1_dimension=self.hidden1_dimension,
        hidden2_dimension=self.hidden2_dimension,
        topics=self.topics,
    )
    if self.cuda:
        self.autoencoder.cuda()
    ae_optimizer = Adam(
        self.autoencoder.parameters(), lr=self.lr, betas=(0.99, 0.999)
    )
    train(
        ds,
        self.autoencoder,
        cuda=self.cuda,
        validation=None,
        epochs=self.epochs,
        batch_size=self.batch_size,
        optimizer=ae_optimizer,
        # Uniform weights, drawing min(documents, self.samples) indices per
        # pass over the sampler, with replacement (WeightedRandomSampler default).
        sampler=WeightedRandomSampler(
            torch.ones(documents), min(documents, self.samples)
        ),
        silent=True,
        num_workers=0,  # TODO causes a bug to change this on Mac
    )
    return self
开发者ID:vlukiyanov,项目名称:pt-avitm,代码行数:30,代码来源:sklearn_api.py

示例10: __init__

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def __init__(self, data_source):
    """Build a sampler that balances items across languages.

    Item ``i`` gets weight ``N / count(language_i)``, where ``N`` is the number
    of items. Every language therefore has the same total weight ``N`` and is
    drawn equally often in expectation. ``len(data_source)`` indices are drawn,
    with replacement (WeightedRandomSampler's default).

    Args:
        data_source: object supporting ``len()`` whose ``items[idx]`` are
            mappings with a ``'language'`` key.
    """
    from collections import Counter

    # Read each label once; both the counts and the per-item weights use it.
    languages = [data_source.items[idx]['language'] for idx in range(len(data_source))]
    label_freq = Counter(languages)

    total = float(len(languages))  # equals the sum of all label counts
    weights = [total / label_freq[lang] for lang in languages]

    self._sampler = WeightedRandomSampler(weights, len(weights))
开发者ID:Tomiinek,项目名称:Multilingual_Text_to_Speech,代码行数:14,代码来源:samplers.py

示例11: get_train_data_source

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def get_train_data_source(ds_metainfo,
                          batch_size,
                          num_workers):
    """
    Create the training-subset DataLoader described by ``ds_metainfo``.

    The dataset is built in ``"train"`` mode with the metainfo's training
    transform and any extra constructor kwargs, and ``ds_metainfo`` is then
    updated from it. Batches are drawn by plain shuffling or, when
    ``ds_metainfo.train_use_weighted_sampler`` is set, by a
    ``WeightedRandomSampler`` over ``dataset.sample_weights`` that draws
    ``len(dataset)`` indices.

    Parameters
    ----------
    ds_metainfo : DatasetMetaInfo
        Dataset metainfo.
    batch_size : int
        Batch size.
    num_workers : int
        Number of background workers.

    Returns
    -------
    DataLoader
        Data source.
    """
    train_transform = ds_metainfo.train_transform(ds_metainfo=ds_metainfo)
    extra_kwargs = ds_metainfo.dataset_class_extra_kwargs
    if extra_kwargs is None:
        extra_kwargs = {}

    dataset = ds_metainfo.dataset_class(
        root=ds_metainfo.root_dir_path,
        mode="train",
        transform=train_transform,
        **extra_kwargs)
    ds_metainfo.update_from_dataset(dataset)

    loader_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True)
    if ds_metainfo.train_use_weighted_sampler:
        # DataLoader rejects a sampler combined with shuffle=True, so only one is set.
        loader_kwargs["sampler"] = WeightedRandomSampler(
            weights=dataset.sample_weights,
            num_samples=len(dataset))
    else:
        loader_kwargs["shuffle"] = True
    return DataLoader(**loader_kwargs)
开发者ID:osmr,项目名称:imgclsmob,代码行数:48,代码来源:dataset_utils.py

示例12: train_valid

# 需要导入模块: from torch.utils.data import sampler [as 别名]
# 或者: from torch.utils.data.sampler import WeightedRandomSampler [as 别名]
def train_valid(data_root, name, img_enc, pc_enc, pc_dec, optimizer, scheduler, adain=True, projection=True,
                decimation=None, color_img=False, n_points=250, bs=4, lr=5e-5, weight_decay=1e-5, gamma=.3,
                milestones=(5, 8), n_epochs=10, print_freq=1000, val_freq=10000, checkpoint_folder=None):
    """Train a PointcloudDeformNet on ShapeNet data with periodic validation.

    A fresh run is started when ``checkpoint_folder`` is None. Otherwise
    model, optimizer and (if present) scheduler state are restored from
    ``training.pt`` in that folder and training resumes.

    Args:
        data_root: ShapeNet root directory (also listed with ``os.listdir``
            to size the validation set).
        name: run name passed to ``nnt.Monitor`` for a fresh run.
        img_enc, pc_enc, pc_dec: network components forwarded to
            ``PointcloudDeformNet``. When ``decimation`` is given, ``pc_dec``
            is first partially applied with ``decimation=decimation``.
        optimizer: callable invoked as ``optimizer(x, lr, weight_decay=weight_decay)``.
        scheduler: callable invoked as ``scheduler(x, milestones=milestones, gamma=gamma)``.
        color_img: use 3-channel images if True, else grayscale (1 channel).
        n_points: number of points per point cloud.
        bs: batch size (incomplete batches are dropped).
        n_epochs, print_freq, val_freq: forwarded to the monitor / training loop.
        checkpoint_folder: folder to resume from, or None for a fresh run.
    """
    if decimation is not None:
        pc_dec = partial(pc_dec, decimation=decimation)

    net = PointcloudDeformNet((bs,) + (3 if color_img else 1, 224, 224), (bs, n_points, 3), img_enc, pc_enc, pc_dec,
                              adain=adain, projection=projection,
                              optimizer=lambda x: optimizer(x, lr, weight_decay=weight_decay),
                              scheduler=lambda x: scheduler(x, milestones=milestones, gamma=gamma),
                              weight_decay=None)
    print(net)

    train_data = ShapeNet(path=data_root, grayscale=not color_img, type='train', n_points=n_points)
    # Draw len(train_data) indices per epoch, with replacement, weighted by sample_weights.
    sampler = WeightedRandomSampler(train_data.sample_weights, len(train_data), True)
    train_loader = DataLoader(train_data, batch_size=bs, num_workers=1, collate_fn=collate, drop_last=True,
                              sampler=sampler)
    # With len(train_data) draws and drop_last=True, one pass yields exactly this many batches.
    iters_per_epoch = len(train_data) // bs

    # NOTE(review): presumably 10 validation samples per entry of data_root; confirm with ShapeNet.
    val_data = ShapeNet(path=data_root, grayscale=not color_img, type='valid', num_vals=10 * len(os.listdir(data_root)),
                        n_points=n_points)
    val_loader = DataLoader(val_data, batch_size=bs, shuffle=False, num_workers=1, collate_fn=collate, drop_last=True)

    if checkpoint_folder is None:
        mon = nnt.Monitor(name, print_freq=print_freq, num_iters=iters_per_epoch, use_tensorboard=True)
        mon.copy_files(backup_files)

        mon.dump_rep('network', net)
        mon.dump_rep('optimizer', net.optim['optimizer'])
        if net.optim['scheduler']:
            mon.dump_rep('scheduler', net.optim['scheduler'])

        states = {
            'model_state_dict': net.state_dict(),
            'opt_state_dict': net.optim['optimizer'].state_dict()
        }
        if net.optim['scheduler']:
            states['scheduler_state_dict'] = net.optim['scheduler'].state_dict()

        mon.schedule(mon.dump, beginning=False, name='training.pt', obj=states, type='torch', keep=5)
        print('Training...')
    else:
        mon = nnt.Monitor(current_folder=checkpoint_folder, print_freq=print_freq, num_iters=iters_per_epoch,
                          use_tensorboard=True)
        states = mon.load('training.pt', type='torch')
        # Completed iterations = epochs * batches-per-epoch. The previous
        # ``mon.get_epoch() * len(train_data) // bs`` parsed as
        # ``(epoch * len) // bs`` and overshot whenever len % bs != 0.
        mon.set_iter(mon.get_epoch() * iters_per_epoch)

        net.load_state_dict(states['model_state_dict'])
        net.optim['optimizer'].load_state_dict(states['opt_state_dict'])
        if net.optim['scheduler']:
            net.optim['scheduler'].load_state_dict(states['scheduler_state_dict'])

        # NOTE(review): unlike the fresh-run branch, no periodic ``mon.dump`` of
        # training.pt is scheduled here; confirm nnt.Monitor restores it,
        # otherwise no further checkpoints are written after resuming.
        print('Resume from epoch %d...' % mon.get_epoch())

    mon.run_training(net, train_loader, n_epochs, val_loader, valid_freq=val_freq, reduce='mean')
    print('Training finished!')
开发者ID:justanhduc,项目名称:graphx-conv,代码行数:57,代码来源:train.py


注:本文中的torch.utils.data.sampler.WeightedRandomSampler方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。