當前位置: 首頁>>代碼示例>>Python>>正文


Python dataloader.default_collate方法代碼示例

本文整理匯總了Python中torch.utils.data.dataloader.default_collate方法的典型用法代碼示例。如果您正苦於以下問題:Python dataloader.default_collate方法的具體用法?Python dataloader.default_collate怎麽用?Python dataloader.default_collate使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.utils.data.dataloader的用法示例。


在下文中一共展示了dataloader.default_collate方法的15個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: ctc_collate

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def ctc_collate(batch):
    '''
    Stack samples into CTC style inputs.

    Modified based on default_collate() in PyTorch.
    By Yuan-Hang Zhang.

    Args:
        batch: iterable of ``(x, y, length, index)`` samples. ``x`` must be
            collatable by ``default_collate`` and ``y`` is a sequence of
            integer labels.

    Returns:
        A tuple ``(x, y, lengths, y_lengths, ids)`` where ``x`` is the stacked
        input restricted to the first ``max(lengths)`` entries along
        dimension 2, ``y`` is every label sequence concatenated into one
        ``IntTensor``, ``lengths`` / ``y_lengths`` hold the per-sample input
        and label lengths, and ``ids`` is the collated sample indices.
    '''
    xs, ys, lens, indices = zip(*batch)
    max_len = max(lens)
    # Tensor.narrow() returns a new view and never modifies its receiver, so
    # the result has to be bound back to ``x`` for the trimming to take effect.
    # contiguous() keeps the output usable with .view() just like the
    # untrimmed default_collate result; it is a no-op when nothing is trimmed.
    x = default_collate(xs).narrow(2, 0, max_len).contiguous()
    # CTC targets are passed as one flat sequence plus per-sample lengths.
    y = torch.IntTensor([token for label in ys for token in label])
    lengths = torch.IntTensor(lens)
    y_lengths = torch.IntTensor([len(label) for label in ys])
    ids = default_collate(indices)

    return x, y, lengths, y_lengths, ids
開發者ID:sailordiary,項目名稱:LipNet-PyTorch,代碼行數:20,代碼來源:dataloader.py

示例2: collate

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def collate(self, batch):
        """Collate ``batch`` by dispatching on the type of its first element.

        The checks run in a fixed order and the first match wins:
        ``Data`` objects go through ``Batch.from_data_list`` (with
        ``self.follow_batch``), tensors through ``default_collate``, floats
        and integers become tensors, strings are returned unchanged, and
        mappings, namedtuples and other sequences are collated recursively
        field by field.

        Raises:
            TypeError: if the first element matches none of the handled types.
        """
        first = batch[0]

        if isinstance(first, Data):
            return Batch.from_data_list(batch, self.follow_batch)
        if isinstance(first, torch.Tensor):
            return default_collate(batch)
        if isinstance(first, float):
            return torch.tensor(batch, dtype=torch.float)
        if isinstance(first, int_classes):
            return torch.tensor(batch)
        # Strings must be tested before the generic Sequence branch below,
        # otherwise they would be split character by character.
        if isinstance(first, string_classes):
            return batch
        if isinstance(first, container_abcs.Mapping):
            return {key: self.collate([item[key] for item in batch]) for key in first}
        # A tuple exposing ``_fields`` is treated as a namedtuple and rebuilt
        # with the same type.
        if isinstance(first, tuple) and hasattr(first, '_fields'):
            columns = (self.collate(column) for column in zip(*batch))
            return type(first)(*columns)
        if isinstance(first, container_abcs.Sequence):
            return [self.collate(column) for column in zip(*batch)]

        raise TypeError('DataLoader found invalid type: {}'.format(type(first)))
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:22,代碼來源:dataloader.py

示例3: collater

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def collater(self, samples):
        """Combine per-sample dicts into a single mini-batch.

        Every entry of ``self.defn`` is asked to collate its own column of
        values; an entry whose ``collater`` raises ``NotImplementedError``
        falls back to ``default_collate``.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: ``{}`` for an empty ``samples``; otherwise the result of
            ``_unflatten`` applied to the collated columns.
        """
        if not len(samples):
            return {}

        def column(key):
            # Rebuilt on every call so the fallback receives a fresh list.
            return [item[key] for item in samples]

        merged = OrderedDict()
        for name, dataset in self.defn.items():
            try:
                merged[name] = dataset.collater(column(name))
            except NotImplementedError:
                merged[name] = default_collate(column(name))
        return _unflatten(merged)
開發者ID:pytorch,項目名稱:fairseq,代碼行數:20,代碼來源:nested_dictionary_dataset.py

示例4: ava_collate_fn

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def ava_collate_fn(batch):
    """Zero-pad every sample to the longest first dimension and collate.

    Each sample is indexed as ``(features, labels, extra)``. Features and
    labels are copied into zero-filled float32 buffers whose first dimension
    equals the largest ``features.shape[0]`` in the batch, and a mask holding
    1 for original positions and 0 for padding is added. ``extra``
    (``sample[2]``) is passed through unchanged.

    Returns:
        The ``default_collate`` result for the list of
        ``[features, mask, labels, extra]`` entries.
    """
    longest = max((sample[0].shape[0] for sample in batch), default=0)

    padded = []
    for sample in batch:
        features, labels = sample[0], sample[1]
        length = features.shape[0]

        feature_buf = np.zeros(
            (longest, features.shape[1], features.shape[2], features.shape[3]),
            np.float32)
        mask = np.zeros(longest, np.float32)
        label_buf = np.zeros((longest, labels.shape[1]), np.float32)

        feature_buf[:length] = features
        mask[:length] = 1
        label_buf[:length, :] = labels

        padded.append([
            video_to_tensor(feature_buf),
            torch.from_numpy(mask),
            torch.from_numpy(label_buf),
            sample[2],
        ])

    return default_collate(padded)
開發者ID:piergiaj,項目名稱:super-events-cvpr18,代碼行數:20,代碼來源:ava_i3d_per_video.py

示例5: mt_collate_fn

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def mt_collate_fn(batch):
    """Pad all samples to a common length with zeros, then collate them.

    A sample is read as ``(features, labels, extra)``. The target length is
    the largest ``features.shape[0]`` in the batch. Features and labels are
    written into zero-initialised float32 arrays of that length, together
    with a mask that is 1 where original data exists and 0 elsewhere.
    ``sample[2]`` is forwarded untouched.

    Returns:
        ``default_collate`` applied to the list of
        ``[features, mask, labels, extra]`` entries.
    """
    target_len = max((item[0].shape[0] for item in batch), default=0)

    def pad_one(item):
        data, target = item[0], item[1]
        n_valid = data.shape[0]

        data_out = np.zeros(
            (target_len, data.shape[1], data.shape[2], data.shape[3]), np.float32)
        valid = np.zeros(target_len, np.float32)
        target_out = np.zeros((target_len, target.shape[1]), np.float32)

        data_out[:n_valid] = data
        valid[:n_valid] = 1
        target_out[:n_valid, :] = target

        return [video_to_tensor(data_out), torch.from_numpy(valid),
                torch.from_numpy(target_out), item[2]]

    return default_collate([pad_one(item) for item in batch])
開發者ID:piergiaj,項目名稱:super-events-cvpr18,代碼行數:20,代碼來源:charades_i3d_per_video.py

示例6: __init__

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        """Record the split settings, build the samplers and initialise the parent.

        ``_split_sampler`` and the parent constructor are defined outside this
        snippet. The assignments keep their original order because
        ``_split_sampler`` may read or update the attributes set before it is
        called, and ``self.shuffle`` is deliberately re-read afterwards --
        TODO confirm against its definition.
        """
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        train_sampler, valid_sampler = self._split_sampler(self.validation_split)
        self.sampler = train_sampler
        self.valid_sampler = valid_sampler

        # The keyword arguments are stored on the instance as well as being
        # forwarded to the parent constructor.
        self.init_kwargs = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=self.shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )
        super().__init__(sampler=self.sampler, **self.init_kwargs)
開發者ID:victoresque,項目名稱:pytorch-template,代碼行數:19,代碼來源:base_data_loader.py

示例7: make_collate_fn

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def make_collate_fn(padding_values):
    """Build a collate function that pads selected fields to a common length.

    Args:
        padding_values: mapping from field name to either the string
            ``"edge"`` (repeat the last entry) or a constant fill value.

    Returns:
        A function taking a batch of samples. For every named field it pads
        the arrays along their first axis up to the longest one in the batch,
        writing the padded array back into the sample (samples are modified
        in place), and then returns ``default_collate(batch)``.
    """

    def _pad_first_axis(array, extra, fill):
        # Only the end of the first axis grows; all other axes are untouched.
        widths = [(0, extra)] + [(0, 0)] * (array.ndim - 1)
        if fill == "edge":
            return np.pad(array, widths, mode="edge")
        return np.pad(array, widths, mode="constant", constant_values=fill)

    def _collate_fn(batch):
        for name, padding_value in padding_values.items():
            sizes = [len(sample[name]) for sample in batch]
            longest = max(sizes)

            for sample, size in zip(batch, sizes):
                missing = longest - size
                if missing:
                    sample[name] = _pad_first_axis(
                        sample[name], missing, padding_value)

        return default_collate(batch)

    return _collate_fn
開發者ID:ex4sperans,項目名稱:freesound-classification,代碼行數:27,代碼來源:padding.py

示例8: detection_collate

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def detection_collate(batch):
    """Collate dict samples while keeping ``'target_size'`` as a plain list.

    Arguments:
        batch: sequence of dict samples; every sample must contain
            ``'target_size'``.

    Return:
        A dict in which every key of the first sample other than
        ``'target_size'`` holds the ``default_collate`` result for that key,
        and ``'target_size'`` holds the list of per-sample values in batch
        order.
    """
    passthrough_keys = ['target_size', ]
    # These values are gathered as-is and never handed to default_collate.
    passthrough = {key: [sample[key] for sample in batch] for key in passthrough_keys}
    collated = {
        key: default_collate([sample[key] for sample in batch])
        for key in batch[0].keys()
        if key not in passthrough
    }
    return {**collated, **passthrough}
開發者ID:orsic,項目名稱:swiftnet,代碼行數:22,代碼來源:base.py

示例9: custom_collate

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def custom_collate(batch, del_orig_labels=False):
    """Collate dict samples, carrying a few fields over from the first sample.

    The keys ``'target_size'``, ``'target_size_feats'``, ``'alphas'`` and
    ``'target_level'`` are not collated: when present in the first sample,
    its value is remembered, the key is deleted from every sample, and the
    remembered value is written into the collated result. Any
    ``'mux_indices'`` entry is flattened with ``view(-1)`` before collation.
    The samples in ``batch`` are modified in place.

    Args:
        batch: sequence of dict samples.
        del_orig_labels: when true, ``'original_labels'`` is deleted from
            every sample before collation.

    Returns:
        The ``default_collate`` result with the carried-over fields added.
    """
    carried_keys = ('target_size', 'target_size_feats', 'alphas', 'target_level')
    first = batch[0]
    carried = {key: first[key] for key in carried_keys if key in first}

    for sample in batch:
        if del_orig_labels:
            del sample['original_labels']
        for key in carried:
            del sample[key]
        if 'mux_indices' in sample:
            sample['mux_indices'] = sample['mux_indices'].view(-1)

    collated = default_collate(batch)
    # NOTE: a disabled step in the original used to concatenate 'image_next'
    # onto 'image' at this point; it is intentionally not performed.
    collated.update(carried)
    return collated
開發者ID:orsic,項目名稱:swiftnet,代碼行數:21,代碼來源:base.py

示例10: collate_minibatch

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def collate_minibatch(list_of_blobs):
    """Collate blobs in groups of ``cfg.TRAIN.IMS_PER_BATCH``.

    The input is cut into consecutive groups of ``cfg.TRAIN.IMS_PER_BATCH``
    blobs. Each group is passed through ``pad_image_data`` and
    ``default_collate`` on its own, so groups do not have to share one image
    size. ``'roidb'`` is removed from every blob beforehand (the input dicts
    are modified) and re-attached per group as a plain list slice, because
    its entries have variable length and cannot be stacked into a tensor.

    Returns:
        dict mapping every key of the first blob to a list with one entry
        per group.
    """
    # The key set is captured before 'roidb' is popped, so the result still
    # has a 'roidb' slot.
    collected = {key: [] for key in list_of_blobs[0]}
    roidbs = [blobs.pop('roidb') for blobs in list_of_blobs]

    group_size = cfg.TRAIN.IMS_PER_BATCH
    for start in range(0, len(list_of_blobs), group_size):
        stop = start + group_size
        padded = pad_image_data(list_of_blobs[start:stop])
        minibatch = default_collate(padded)
        minibatch['roidb'] = roidbs[start:stop]
        for key, value in minibatch.items():
            collected[key].append(value)

    return collected
開發者ID:roytseng-tw,項目名稱:Detectron.pytorch,代碼行數:21,代碼來源:loader.py

示例11: decimal_friendly_collate

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def decimal_friendly_collate(batch):
    """A wrapper on top of ``default_collate`` function that allows decimal.Decimal types to be collated.

    We use ``decimal.Decimal`` types in petastorm dataset to represent timestamps. PyTorch's ``default_collate``
    implementation does not support collating ``decimal.Decimal`` types. ``decimal_friendly_collate`` collates
    ``decimal.Decimal`` separately and then combines with the rest of the fields collated by a standard
    ``default_collate``.

    The type of the first element decides how the whole batch is handled: decimals and strings are returned
    as the unchanged input list, mappings are collated key by key, other sequences are transposed and collated
    position by position, and everything else is delegated to ``default_collate``.

    :param batch: A list of dictionaries to collate
    :return: A dictionary of lists/pytorch.Tensor types
    """
    # ``collections.Mapping`` / ``collections.Sequence`` were deprecated aliases that Python 3.10 removed;
    # the ABCs are only available from ``collections.abc``. Imported locally so the function does not depend
    # on the submodule having been imported elsewhere.
    from collections.abc import Mapping, Sequence

    first = batch[0]
    if isinstance(first, decimal.Decimal):
        return batch
    elif isinstance(first, Mapping):
        return {key: decimal_friendly_collate([d[key] for d in batch]) for key in first}
    # Strings are sequences too, so they must be caught before the Sequence branch.
    elif isinstance(first, _string_classes):
        return batch
    elif isinstance(first, Sequence):
        transposed = zip(*batch)
        return [decimal_friendly_collate(samples) for samples in transposed]
    else:
        return default_collate(batch)
開發者ID:uber,項目名稱:petastorm,代碼行數:25,代碼來源:pytorch.py

示例12: __init__

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers,
     collate_fn=default_collate):
        """Store the split configuration, create the samplers and call the parent.

        ``_split_sampler`` and ``BaseDataLoader``'s parent are not part of
        this snippet. The attribute assignments precede the
        ``_split_sampler`` call, and ``self.shuffle`` is read again after it,
        exactly as in the original -- presumably ``_split_sampler`` depends
        on them; verify against its definition.
        """
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        samplers = self._split_sampler(self.validation_split)
        self.sampler, self.valid_sampler = samplers

        # Saved on the instance in addition to being passed to the parent.
        loader_kwargs = {}
        loader_kwargs['dataset'] = dataset
        loader_kwargs['batch_size'] = batch_size
        loader_kwargs['shuffle'] = self.shuffle
        loader_kwargs['collate_fn'] = collate_fn
        loader_kwargs['num_workers'] = num_workers
        self.init_kwargs = loader_kwargs

        super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
開發者ID:daili0015,項目名稱:ModelFeast,代碼行數:20,代碼來源:base_data_loader.py

示例13: collate_minibatch

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def collate_minibatch(list_of_blobs):
    """Collate ``data``, ``rois`` and ``labels`` in groups of ``cfg.TRAIN.IMS_PER_BATCH``.

    ``'data'``, ``'rois'`` and ``'labels'`` are popped from every blob (the
    input dicts are modified), and consecutive groups of
    ``cfg.TRAIN.IMS_PER_BATCH`` of them are passed to ``default_collate``
    separately, so different groups may hold differently sized images.

    Returns:
        dict mapping every key of the first blob to a list. The three
        collated keys receive one entry per group; any other key keeps an
        empty list.
    """
    # Keys are captured before anything is popped from the blobs.
    collected = {key: [] for key in list_of_blobs[0]}

    extracted = [
        {'data': blobs.pop('data'),
         'rois': blobs.pop('rois'),
         'labels': blobs.pop('labels')}
        for blobs in list_of_blobs
    ]

    group_size = cfg.TRAIN.IMS_PER_BATCH
    for start in range(0, len(list_of_blobs), group_size):
        minibatch = default_collate(extracted[start:start + group_size])
        for key, value in minibatch.items():
            collected[key].append(value)

    return collected
開發者ID:ppengtang,項目名稱:pcl.pytorch,代碼行數:22,代碼來源:loader.py

示例14: __init__

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        """Set up split bookkeeping and samplers, then initialise the parent loader.

        ``_split_sampler`` and the parent of ``BaseDataLoader`` live outside
        this snippet. The order of the statements is unchanged from the
        original: the bookkeeping attributes are assigned before
        ``_split_sampler`` runs, and ``self.shuffle`` is read back only
        afterwards -- NOTE(review): confirm whether ``_split_sampler``
        relies on this.
        """
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        sampler_pair = self._split_sampler(self.validation_split)
        self.sampler, self.valid_sampler = sampler_pair

        # Kept on the instance and also forwarded to the parent constructor.
        self.init_kwargs = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=self.shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )
        super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
開發者ID:yjlolo,項目名稱:vae-audio,代碼行數:19,代碼來源:base_data_loader.py

示例15: __init__

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import default_collate [as 別名]
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, base_seed=None, worker_init_fn=None, worker_init_args=None, worker_init_kwargs=None,
                 worker_recv_fn=None, **kwargs):
        """Initialise the loader with a wrapped per-worker init function.

        Defaults filled in here: ``worker_init_args`` becomes one empty tuple
        per worker, ``worker_init_kwargs`` one empty dict per worker, and
        ``base_seed`` the value of ``gen_seed()``. A ``DataLoaderPipeMaster``
        is created only when ``worker_recv_fn`` is given; otherwise
        ``self.pipe_master`` is ``None``. The user's ``worker_init_fn`` is
        wrapped in ``_InitFunctionWrapper`` and the wrapper is what the parent
        constructor receives. All remaining arguments are forwarded to the
        parent unchanged.
        """
        if worker_init_args is None:
            worker_init_args = [tuple() for _ in range(num_workers)]
        if worker_init_kwargs is None:
            worker_init_kwargs = [{} for _ in range(num_workers)]
        if base_seed is None:
            base_seed = gen_seed()

        self.worker_recv_fn = worker_recv_fn
        self.pipe_master = DataLoaderPipeMaster(num_workers) if worker_recv_fn is not None else None

        wrapped_init_fn = _InitFunctionWrapper(
            base_seed, worker_init_fn, worker_init_args, worker_init_kwargs,
            self.pipe_master, DataLoaderPipeSlave(worker_recv_fn)
        )
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=wrapped_init_fn,
            **kwargs
        )
開發者ID:vacancy,項目名稱:Jacinle,代碼行數:24,代碼來源:dataloader.py


注:本文中的torch.utils.data.dataloader.default_collate方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。