当前位置: 首页>>代码示例>>Python>>正文


Python dataloader.default_collate方法代码示例

本文整理汇总了Python中torch.utils.data.dataloader.default_collate方法的典型用法代码示例。如果您正苦于以下问题:Python dataloader.default_collate方法的具体用法?Python dataloader.default_collate怎么用?Python dataloader.default_collate使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data.dataloader的用法示例。


在下文中一共展示了dataloader.default_collate方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: ctc_collate

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def ctc_collate(batch):
    """Stack samples into CTC-style inputs.

    Modified based on ``default_collate()`` in PyTorch. By Yuan-Hang Zhang.

    Args:
        batch: sequence of ``(x, y, length, index)`` samples. Every ``x`` must
            share one shape so ``default_collate`` can stack them; ``y`` is an
            iterable of integer labels; ``length`` is an int; ``index`` is
            anything ``default_collate`` accepts.

    Returns:
        tuple: ``(x, y, lengths, y_lengths, ids)`` where
            * ``x`` is the stacked inputs, trimmed along dim 2 of the batch
              (i.e. dim 1 of each sample) to the largest ``length``;
            * ``y`` is every label sequence concatenated into one ``IntTensor``;
            * ``lengths`` is an ``IntTensor`` of the per-sample ``length``;
            * ``y_lengths`` is an ``IntTensor`` of the per-sample label counts;
            * ``ids`` is ``default_collate`` of the indices.
    """
    xs, ys, lens, indices = zip(*batch)
    max_len = max(lens)
    # ``narrow`` returns a new view instead of modifying the tensor in place,
    # so its result must be kept; the original discarded it and never trimmed.
    x = default_collate(xs).narrow(2, 0, max_len)
    y = torch.IntTensor([label for sub in ys for label in sub])
    lengths = torch.IntTensor(lens)
    y_lengths = torch.IntTensor([len(label) for label in ys])
    ids = default_collate(indices)

    return x, y, lengths, y_lengths, ids
开发者ID:sailordiary,项目名称:LipNet-PyTorch,代码行数:20,代码来源:dataloader.py

示例2: collate

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collate(self, batch):
    """Recursively collate a list of samples into a mini-batch.

    Dispatch is on the type of the first sample:
    ``Data`` -> ``Batch.from_data_list(batch, self.follow_batch)``;
    tensors -> ``default_collate``; ``float`` -> float tensor;
    ``int_classes`` -> tensor; ``string_classes`` -> returned unchanged;
    mappings, namedtuples and other sequences -> collated field by field
    through ``self.collate``.

    Raises:
        TypeError: if the first sample's type matches none of the above.
    """
    first = batch[0]

    if isinstance(first, Data):
        return Batch.from_data_list(batch, self.follow_batch)
    if isinstance(first, torch.Tensor):
        return default_collate(batch)
    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float)
    if isinstance(first, int_classes):
        return torch.tensor(batch)
    if isinstance(first, string_classes):
        return batch

    if isinstance(first, container_abcs.Mapping):
        return {field: self.collate([sample[field] for sample in batch])
                for field in first}

    is_namedtuple = isinstance(first, tuple) and hasattr(first, '_fields')
    if is_namedtuple:
        columns = (self.collate(column) for column in zip(*batch))
        return type(first)(*columns)

    if isinstance(first, container_abcs.Sequence):
        return [self.collate(column) for column in zip(*batch)]

    raise TypeError('DataLoader found invalid type: {}'.format(type(first)))
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:22,代码来源:dataloader.py

示例3: collater

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collater(self, samples):
    """Merge a list of samples to form a mini-batch.

    Each key of ``self.defn`` is collated by that entry's own ``collater``.
    If that raises ``NotImplementedError``, the values fall back to
    ``default_collate``. The flat result is passed to ``_unflatten``.

    Args:
        samples (List[dict]): samples to collate

    Returns:
        dict: a mini-batch suitable for forwarding with a Model
            (``{}`` when ``samples`` is empty)
    """
    if len(samples) == 0:
        return {}

    merged = OrderedDict()
    for key, sub_dataset in self.defn.items():
        try:
            merged[key] = sub_dataset.collater([sample[key] for sample in samples])
        except NotImplementedError:
            # This sub-dataset has no custom collater; use PyTorch's default.
            merged[key] = default_collate([sample[key] for sample in samples])
    return _unflatten(merged)
开发者ID:pytorch,项目名称:fairseq,代码行数:20,代码来源:nested_dictionary_dataset.py

示例4: ava_collate_fn

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def ava_collate_fn(batch):
    """Zero-pad every sample to the longest one in *batch*, then collate.

    Each sample ``b`` is read positionally:
      * ``b[0]`` is indexed as at least 4-D; its first axis is padded with
        zeros up to the longest first axis in the batch.
      * ``b[1]`` is 2-D; its first ``len(b[0])`` rows are copied into a
        zero buffer of ``longest`` rows.
      * ``b[2]`` is passed through unchanged.
    A float32 mask holds 1 on the valid (unpadded) positions and 0 elsewhere.

    Returns:
        ``default_collate`` of ``[video_to_tensor(padded b[0]), mask,
        padded b[1], b[2]]`` built for each sample.
    """
    longest = 0
    for sample in batch:
        longest = max(longest, sample[0].shape[0])

    padded = []
    for sample in batch:
        feats, labels = sample[0], sample[1]
        valid = feats.shape[0]

        feat_buf = np.zeros((longest, feats.shape[1], feats.shape[2], feats.shape[3]), np.float32)
        mask = np.zeros((longest), np.float32)
        label_buf = np.zeros((longest, labels.shape[1]), np.float32)

        feat_buf[:valid] = feats
        mask[:valid] = 1
        label_buf[:valid, :] = labels

        padded.append([
            video_to_tensor(feat_buf),
            torch.from_numpy(mask),
            torch.from_numpy(label_buf),
            sample[2],
        ])

    return default_collate(padded)
开发者ID:piergiaj,项目名称:super-events-cvpr18,代码行数:20,代码来源:ava_i3d_per_video.py

示例5: mt_collate_fn

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def mt_collate_fn(batch):
    """Pad all samples of *batch* to a common length and collate them.

    For every sample ``b``, ``b[0]`` (indexed as at least 4-D) and ``b[1]``
    (2-D) get zero rows appended along their first axis until they reach the
    largest ``b[0].shape[0]`` in the batch. A float32 mask flags the original
    rows with 1. ``b[2]`` is forwarded untouched.

    Returns:
        The ``default_collate`` of one ``[video_to_tensor(frames), mask,
        labels, b[2]]`` list per sample.
    """
    target = max([0] + [item[0].shape[0] for item in batch])

    out = []
    for item in batch:
        n = item[0].shape[0]
        frames = np.zeros((target, item[0].shape[1], item[0].shape[2], item[0].shape[3]), np.float32)
        valid_mask = np.zeros((target), np.float32)
        targets = np.zeros((target, item[1].shape[1]), np.float32)
        frames[:n] = item[0]
        valid_mask[:n] = 1
        targets[:n, :] = item[1]
        out.append([video_to_tensor(frames), torch.from_numpy(valid_mask),
                    torch.from_numpy(targets), item[2]])

    return default_collate(out)
开发者ID:piergiaj,项目名称:super-events-cvpr18,代码行数:20,代码来源:charades_i3d_per_video.py

示例6: __init__

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    """Set up split samplers and forward to the parent loader constructor.

    Stores ``validation_split``, ``shuffle``, a zero ``batch_idx`` and
    ``n_samples = len(dataset)``. It then obtains
    ``(self.sampler, self.valid_sampler)`` from
    ``self._split_sampler(validation_split)``. The constructor keyword
    arguments are kept in ``self.init_kwargs`` and passed, together with
    ``sampler=self.sampler``, to ``super().__init__``.
    """
    self.validation_split = validation_split
    self.shuffle = shuffle
    self.batch_idx = 0
    self.n_samples = len(dataset)

    # Run the split only after the attributes above exist. self.shuffle is
    # read afterwards, so any value _split_sampler leaves there is the one used.
    self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

    self.init_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=self.shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )
    super().__init__(sampler=self.sampler, **self.init_kwargs)
开发者ID:victoresque,项目名称:pytorch-template,代码行数:19,代码来源:base_data_loader.py

示例7: make_collate_fn

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def make_collate_fn(padding_values):
    """Build a ``collate_fn`` that pads variable-length fields before collating.

    Args:
        padding_values: mapping of field name -> padding spec. The spec is
            either the string ``"edge"`` (repeat the boundary value via
            ``np.pad(mode="edge")``) or a constant that is passed as
            ``constant_values`` to ``np.pad(mode="constant")``.

    Returns:
        A function ``collate(batch)``. It takes a non-empty list of dict-like
        samples whose listed fields are numpy arrays. It pads each listed
        field along its first axis to the longest length in the batch and
        returns ``default_collate`` of the padded samples. The samples passed
        in are not modified: padding is applied to shallow copies, so
        dataset-owned (e.g. cached) samples keep their original lengths.
    """
    import copy

    def _pad_field(array, target_length, padding_value):
        # Pad only the first axis; every remaining axis keeps its size.
        pad_width = [(0, target_length - len(array))] + [(0, 0)] * (array.ndim - 1)
        if isinstance(padding_value, str) and padding_value == "edge":
            return np.pad(array, pad_width, mode="edge")
        return np.pad(array, pad_width, mode="constant", constant_values=padding_value)

    def _collate_fn(batch):
        # Shallow-copy each sample so the padded arrays are not written back
        # into the caller's dicts (the original mutated them in place).
        samples = [copy.copy(sample) for sample in batch]

        for name, padding_value in padding_values.items():
            max_length = max(len(sample[name]) for sample in samples)
            for sample in samples:
                if len(sample[name]) < max_length:
                    sample[name] = _pad_field(sample[name], max_length, padding_value)

        return default_collate(samples)

    return _collate_fn
开发者ID:ex4sperans,项目名称:freesound-classification,代码行数:27,代码来源:padding.py

示例8: detection_collate

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def detection_collate(batch):
    """Collate dict samples, keeping ``target_size`` as a per-sample list.

    The keys in ``custom_keys`` (currently only ``'target_size'``) are not
    stacked. Their values are gathered into a plain Python list, one entry
    per sample, in batch order. Every other key of ``batch[0]`` is collated
    with ``default_collate``.

    Arguments:
        batch: (list of dict) non-empty list of samples. Each sample must
            contain ``'target_size'`` and every key present in ``batch[0]``.

    Return:
        dict: the ``default_collate``-d fields plus
        ``'target_size' -> [sample['target_size'] for sample in batch]``.

    Raises:
        KeyError: if a sample lacks ``'target_size'`` or a key of ``batch[0]``.
    """
    custom_keys = ('target_size',)
    # Gather the non-stackable fields first, exactly as given per sample.
    custom = {key: [sample[key] for sample in batch] for key in custom_keys}
    other = {key: default_collate([sample[key] for sample in batch])
             for key in batch[0].keys() if key not in custom}
    return {**other, **custom}
开发者ID:orsic,项目名称:swiftnet,代码行数:22,代码来源:base.py

示例9: custom_collate

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def custom_collate(batch, del_orig_labels=False):
    """Collate dict samples, keeping a few per-batch fields out of ``default_collate``.

    The keys ``target_size``, ``target_size_feats``, ``alphas`` and
    ``target_level`` are read from the FIRST sample only and attached
    unchanged to the collated batch. The other samples' values for those
    keys are dropped. Any ``mux_indices`` tensor is flattened with
    ``view(-1)`` before collation.

    The input samples are not modified. Key removal and the ``mux_indices``
    reshape are done on shallow copies, so datasets that cache or reuse their
    sample dicts still have those keys on later epochs.

    Args:
        batch: non-empty list of dict samples.
        del_orig_labels: if true, drop each sample's ``'original_labels'``
            before collating (a missing entry raises ``KeyError``).

    Returns:
        The ``default_collate`` result of the remaining fields, with the
        per-batch fields from the first sample added back.
    """
    import copy

    keys = ['target_size', 'target_size_feats', 'alphas', 'target_level']
    values = {k: batch[0][k] for k in keys if k in batch[0]}

    prepared = []
    for sample in batch:
        b = copy.copy(sample)
        if del_orig_labels:
            del b['original_labels']
        for k in values:
            del b[k]
        if 'mux_indices' in b:
            b['mux_indices'] = b['mux_indices'].view(-1)
        prepared.append(b)

    collated = default_collate(prepared)
    for k, v in values.items():
        collated[k] = v
    return collated
开发者ID:orsic,项目名称:swiftnet,代码行数:21,代码来源:base.py

示例10: collate_minibatch

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collate_minibatch(list_of_blobs):
    """Collate blobs in chunks of ``cfg.TRAIN.IMS_PER_BATCH``.

    Each chunk is padded with ``pad_image_data`` and collated on its own,
    because image sizes may differ between chunks.

    Side effect: ``'roidb'`` is popped from every blob in ``list_of_blobs``.

    Returns:
        dict mapping each key of ``list_of_blobs[0]`` (including ``'roidb'``)
        to a list with one collated entry per chunk. ``'roidb'`` holds, per
        chunk, the plain list of that chunk's roidb entries.
    """
    result = {key: [] for key in list_of_blobs[0]}
    # roidb entries have variable length and cannot be stacked into a tensor,
    # so they are taken out before collation and re-attached per chunk as-is.
    roidbs = [blob.pop('roidb') for blob in list_of_blobs]

    step = cfg.TRAIN.IMS_PER_BATCH
    for start in range(0, len(list_of_blobs), step):
        stop = start + step
        chunk = pad_image_data(list_of_blobs[start:stop])
        minibatch = default_collate(chunk)
        minibatch['roidb'] = roidbs[start:stop]
        for key, value in minibatch.items():
            result[key].append(value)

    return result
开发者ID:roytseng-tw,项目名称:Detectron.pytorch,代码行数:21,代码来源:loader.py

示例11: decimal_friendly_collate

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def decimal_friendly_collate(batch):
    """A wrapper on top of ``default_collate`` function that allows decimal.Decimal types to be collated.

    We use ``decimal.Decimal`` types in petastorm dataset to represent timestamps. PyTorch's ``default_collate``
    implementation does not support collating ``decimal.Decimal`` types. ``decimal_friendly_collate`` collates
    ``decimal.Decimal`` separately and then combines with the rest of the fields collated by a standard
    ``default_collate``.

    Decimals and strings/bytes are returned as the given list. Mappings are
    collated key by key, and other sequences position by position,
    recursively. Anything else goes to ``default_collate``.

    :param batch: A non-empty list of dictionaries to collate
    :return: A dictionary of lists/pytorch.Tensor types
    """
    # ``collections.Mapping``/``collections.Sequence`` were removed in
    # Python 3.10; the ABCs live in ``collections.abc``.
    import collections.abc

    first = batch[0]
    if isinstance(first, decimal.Decimal):
        return batch
    if isinstance(first, collections.abc.Mapping):
        return {key: decimal_friendly_collate([d[key] for d in batch]) for key in first}
    # Must precede the Sequence check: str/bytes are Sequences too.
    if isinstance(first, (str, bytes)):
        return batch
    if isinstance(first, collections.abc.Sequence):
        return [decimal_friendly_collate(samples) for samples in zip(*batch)]
    return default_collate(batch)
开发者ID:uber,项目名称:petastorm,代码行数:25,代码来源:pytorch.py

示例12: __init__

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    """Record split settings, build samplers and call the parent constructor.

    ``self._split_sampler(validation_split)`` supplies
    ``(self.sampler, self.valid_sampler)``. The keyword arguments sent to
    ``super(BaseDataLoader, self).__init__`` are also stored in
    ``self.init_kwargs``.
    """
    self.validation_split, self.shuffle = validation_split, shuffle
    self.batch_idx = 0
    self.n_samples = len(dataset)

    # self.shuffle is read only after the split, so any value
    # _split_sampler leaves in it is what the parent receives.
    samplers = self._split_sampler(self.validation_split)
    self.sampler, self.valid_sampler = samplers

    self.init_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=self.shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )
    super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
开发者ID:daili0015,项目名称:ModelFeast,代码行数:20,代码来源:base_data_loader.py

示例13: collate_minibatch

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collate_minibatch(list_of_blobs):
    """Collate ``data``/``rois``/``labels`` in chunks of ``cfg.TRAIN.IMS_PER_BATCH``.

    Each chunk is collated on its own, because sample sizes may differ
    between chunks.

    Side effect: ``'data'``, ``'rois'`` and ``'labels'`` are popped from every
    blob in ``list_of_blobs``. Any other key is left in the blob and not
    collated.

    Returns:
        dict mapping every key of ``list_of_blobs[0]`` (captured before the
        pops) to a list with one collated entry per chunk. Only ``'data'``,
        ``'rois'`` and ``'labels'`` receive entries.
    """
    result = {key: [] for key in list_of_blobs[0]}

    wanted = ('data', 'rois', 'labels')
    trimmed = [{name: blob.pop(name) for name in wanted} for blob in list_of_blobs]

    step = cfg.TRAIN.IMS_PER_BATCH
    for start in range(0, len(list_of_blobs), step):
        minibatch = default_collate(trimmed[start:start + step])
        for key, value in minibatch.items():
            result[key].append(value)

    return result
开发者ID:ppengtang,项目名称:pcl.pytorch,代码行数:22,代码来源:loader.py

示例14: __init__

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    """Initialise split bookkeeping and delegate to the parent loader.

    Attributes set: ``validation_split``, ``shuffle``, ``batch_idx`` (0),
    ``n_samples`` (``len(dataset)``), ``sampler``/``valid_sampler`` (from
    ``self._split_sampler``) and ``init_kwargs``. The parent constructor is
    called with ``sampler=self.sampler`` plus ``init_kwargs``.
    """
    self.validation_split = validation_split
    self.shuffle = shuffle
    self.batch_idx = 0
    self.n_samples = len(dataset)

    # The split runs before self.shuffle is read below, so any value it
    # leaves there is the one forwarded to the parent.
    self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

    kwargs = {'dataset': dataset, 'batch_size': batch_size}
    kwargs['shuffle'] = self.shuffle
    kwargs['collate_fn'] = collate_fn
    kwargs['num_workers'] = num_workers
    self.init_kwargs = kwargs

    super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
开发者ID:yjlolo,项目名称:vae-audio,代码行数:19,代码来源:base_data_loader.py

示例15: __init__

# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
             num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
             timeout=0, base_seed=None, worker_init_fn=None, worker_init_args=None, worker_init_kwargs=None,
             worker_recv_fn=None, **kwargs):
    """DataLoader constructor with per-worker init arguments and seeding.

    Defaults filled in here:
      * ``worker_init_args``: one empty tuple per worker.
      * ``worker_init_kwargs``: one empty dict per worker.
      * ``base_seed``: ``gen_seed()``.
    When ``worker_recv_fn`` is given, ``self.pipe_master`` is a
    ``DataLoaderPipeMaster(num_workers)``; otherwise it is ``None``.

    The user ``worker_init_fn`` is wrapped in ``_InitFunctionWrapper``
    together with the seed, per-worker args/kwargs and the pipe endpoints, and
    that wrapper is what the parent constructor receives as
    ``worker_init_fn``. All other arguments are forwarded unchanged.
    """
    if worker_init_args is None:
        worker_init_args = [() for _ in range(num_workers)]
    if worker_init_kwargs is None:
        worker_init_kwargs = [{} for _ in range(num_workers)]
    if base_seed is None:
        base_seed = gen_seed()

    self.worker_recv_fn = worker_recv_fn
    self.pipe_master = None if worker_recv_fn is None else DataLoaderPipeMaster(num_workers)

    wrapped_init = _InitFunctionWrapper(
        base_seed, worker_init_fn, worker_init_args, worker_init_kwargs,
        self.pipe_master, DataLoaderPipeSlave(worker_recv_fn)
    )
    super().__init__(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=wrapped_init,
        **kwargs
    )
开发者ID:vacancy,项目名称:Jacinle,代码行数:24,代码来源:dataloader.py


注:本文中的torch.utils.data.dataloader.default_collate方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。