當前位置: 首頁>>代碼示例>>Python>>正文


Python collate.default_collate方法代碼示例

本文整理匯總了Python中torch.utils.data._utils.collate.default_collate方法的典型用法代碼示例。如果您正苦於以下問題:Python collate.default_collate方法的具體用法?Python collate.default_collate怎麼用?Python collate.default_collate使用的例子?那麼, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch.utils.data._utils.collate的用法示例。


在下文中一共展示了collate.default_collate方法的7個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: __init__

# 需要導入模塊: from torch.utils.data._utils import collate [as 別名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 別名]
def __init__(self, dataset, **kwargs):
        """Set up the loader around a ``SafeDataset``.

        Args:
            dataset: must be a ``SafeDataset`` instance (checked with an
                ``assert``).
            **kwargs: forwarded to the parent initializer. A ``drop_last``
                entry is remembered in ``self.drop_last_original`` and then
                forced to ``False`` before being forwarded.
        """
        assert isinstance(
            dataset, SafeDataset
        ), "dataset must be an instance of SafeDataset."

        # drop_last is handled transparently by _SafeDataLoaderIter (bypassing
        # DataLoader). It cannot be changed once the DataLoader instance is
        # initialized, so the caller's value is captured here and the parent
        # class is always given drop_last=False.
        requested_drop_last = kwargs.get("drop_last", False)
        if "drop_last" in kwargs:
            kwargs["drop_last"] = False
        self.drop_last_original = requested_drop_last

        super(SafeDataLoader, self).__init__(dataset, **kwargs)

        # Keep the safe wrapper reachable, and expose the dataset wrapped in
        # _OriginalDataset through ``self.dataset``.
        safe_dataset = self.dataset
        self.safe_dataset = safe_dataset
        self.dataset = _OriginalDataset(safe_dataset)

        # Only replace the collate function when the caller did not supply one.
        if self.collate_fn is default_collate:
            self.collate_fn = SafeDataLoader._safe_default_collate
開發者ID:msamogh,項目名稱:nonechucks,代碼行數:21,代碼來源:dataloader.py

示例2: list_data_collate

# 需要導入模塊: from torch.utils.data._utils import collate [as 別名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 別名]
def list_data_collate(batch):
    """
    Collate a batch whose samples may themselves be lists of items.

    If the first sample is a ``list`` (for example because a transform
    produced several items per sample), all samples are flattened into one
    single list before being passed to ``default_collate``. Otherwise the
    batch is passed to ``default_collate`` unchanged.

    Note:
        Use this collate function when applying transforms that can generate
        batch data.

    """
    if not isinstance(batch[0], list):
        return default_collate(batch)

    flattened = []
    for sample in batch:
        flattened.extend(sample)
    return default_collate(flattened)
開發者ID:Project-MONAI,項目名稱:MONAI,代碼行數:15,代碼來源:utils.py

示例3: mixup_collate

# 需要導入模塊: from torch.utils.data._utils import collate [as 別名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 別名]
def mixup_collate(data, alpha=0.1):
    """Implements a batch collate function with MixUp strategy from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_

    Args:
        data (list): list of elements; ``default_collate(data)`` must yield an
            ``(inputs, targets)`` pair
        alpha (float, optional): mixup factor, used for both parameters of the
            Beta distribution the mixing coefficient is sampled from. A
            non-positive value disables mixing (the coefficient is 1).

    Returns:
        tuple: ``(inputs, targets_a, targets_b, lam)`` where ``inputs`` is
        ``lam * inputs + (1 - lam) * inputs[perm]``, ``targets_a`` are the
        collated targets, ``targets_b`` are the targets in the permuted order
        and ``lam`` is the mixing coefficient.

    Example::
        >>> import torch
        >>> from holocron import utils
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size, collate_fn=utils.data.mixup_collate)
    """

    inputs, targets = default_collate(data)

    # Sample lambda
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # Random permutation of the batch used to pair each sample with another
    batch_size = inputs.size(0)
    index = torch.randperm(batch_size)

    # Create the new input and targets.
    # Index only the first (batch) dimension: unlike ``inputs[index, :]`` this
    # also works when the collated inputs are 1-D (one scalar per sample).
    inputs = lam * inputs + (1 - lam) * inputs[index]
    targets_a, targets_b = targets, targets[index]

    return inputs, targets_a, targets_b, lam
開發者ID:frgfm,項目名稱:Holocron,代碼行數:33,代碼來源:collate.py

示例4: collate

# 需要導入模塊: from torch.utils.data._utils import collate [as 別名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 別名]
def collate(batch, *, root=True):
    """Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the first element of ``batch``:

    * empty batch: returned unchanged;
    * tensors and numpy values (other than numpy strings): ``default_collate``;
    * ints, floats, strings and ``CameraIntrinsics``: returned unchanged;
    * mappings: collated key by key when ``root`` is true, otherwise returned
      unchanged;
    * sequences: each element is collated with ``root=False``.

    Args:
        batch: the samples to collate.
        root: keyword-only; whether this is the outermost call. Mappings are
            only merged key-wise at the root level.

    Raises:
        TypeError: if the first element matches none of the cases above.
    """
    if len(batch) == 0:
        return batch

    first = batch[0]
    first_type = type(first)

    if torch.is_tensor(first):
        return default_collate(batch)

    is_numpy_non_string = (
        first_type.__module__ == 'numpy'
        and first_type.__name__ not in ('str_', 'string_')
    )
    if is_numpy_non_string:
        return default_collate(batch)

    # Plain scalars, strings and camera intrinsics are left as a Python list.
    # The checks are evaluated lazily, in the same order as before.
    passthrough = (
        isinstance(first, int_classes)
        or isinstance(first, float)
        or isinstance(first, string_classes)
        or isinstance(first, CameraIntrinsics)
    )
    if passthrough:
        return batch

    if isinstance(first, Mapping):
        if not root:
            return batch
        return {key: collate([d[key] for d in batch], root=False) for key in first}

    if isinstance(first, Sequence):
        return [collate(e, root=False) for e in batch]

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    raise TypeError(error_msg.format(type(first)))
開發者ID:anibali,項目名稱:margipose,代碼行數:32,代碼來源:__init__.py

示例5: detection_collate

# 需要導入模塊: from torch.utils.data._utils import collate [as 別名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 別名]
def detection_collate(batch):
    """
    Collate function for detection task. Concatenate bboxes, labels and
    metadata from different samples in the first dimension instead of
    stacking them to have a batch-size dimension.
    Args:
        batch (tuple or list): data batch to collate. Each sample unpacks
            into ``(inputs, labels, video_idx, extra_data)``.
    Returns:
        (tuple): collated detection data batch, as
            ``(inputs, labels, video_idx, collated_extra_data)``.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs = default_collate(inputs)
    video_idx = default_collate(video_idx)
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0].keys():
        values = [sample[key] for sample in extra_data]
        if key in ("boxes", "ori_boxes"):
            # Prepend the sample's position in the batch as an extra first
            # column, so boxes stay attributable after concatenation.
            indexed_boxes = []
            for sample_idx, boxes in enumerate(values):
                idx_column = np.full((boxes.shape[0], 1), float(sample_idx))
                indexed_boxes.append(
                    np.concatenate([idx_column, boxes], axis=1)
                )
            stacked = np.concatenate(indexed_boxes, axis=0)
            collated = torch.tensor(stacked).float()
        elif key == "metadata":
            # Flatten the per-sample lists, then view as rows of two values.
            flat = list(itertools.chain.from_iterable(values))
            collated = torch.tensor(flat).view(-1, 2)
        else:
            collated = default_collate(values)
        collated_extra_data[key] = collated

    return inputs, labels, video_idx, collated_extra_data
開發者ID:facebookresearch,項目名稱:SlowFast,代碼行數:37,代碼來源:loader.py

示例6: collate_batches

# 需要導入模塊: from torch.utils.data._utils import collate [as 別名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 別名]
def collate_batches(batches, collate_fn=default_collate):
    """Collate multiple batches.

    Dispatches on the type of the first batch:

    * ``torch.Tensor``: all batches are concatenated along dimension 0;
    * sequence: all batches are chained into one flat list;
    * mapping: for every key of the first batch, the values of all batches
      are collated with ``default_collate``.

    Args:
        batches: non-empty sequence of batches of the same kind.
        collate_fn: accepted for interface compatibility; the current
            implementation does not use it (mappings always go through
            ``default_collate``).

    Raises:
        TypeError: if the first batch is not a tensor, sequence or mapping.
    """
    # The ``collections.Sequence`` / ``collections.Mapping`` aliases were
    # removed in Python 3.10; the ABCs live in ``collections.abc``.
    from collections.abc import Mapping, Sequence

    error_msg = "batches must be tensors, dicts, or lists; found {}"
    first = batches[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(batches, 0)
    elif isinstance(first, Sequence):
        return list(chain(*batches))
    elif isinstance(first, Mapping):
        return {key: default_collate([d[key] for d in batches]) for key in first}
    raise TypeError(error_msg.format(type(first)))
開發者ID:msamogh,項目名稱:nonechucks,代碼行數:12,代碼來源:utils.py

示例7: _safe_default_collate

# 需要導入模塊: from torch.utils.data._utils import collate [as 別名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 別名]
def _safe_default_collate(batch):
        """Collate ``batch`` with ``default_collate`` after dropping ``None`` entries.

        Returns an empty list when nothing is left after filtering.
        """
        kept_samples = [sample for sample in batch if sample is not None]
        if not kept_samples:
            return []
        return default_collate(kept_samples)
開發者ID:msamogh,項目名稱:nonechucks,代碼行數:7,代碼來源:dataloader.py


注:本文中的torch.utils.data._utils.collate.default_collate方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。