当前位置: 首页>>代码示例>>Python>>正文


Python collate.default_collate方法代码示例

本文整理汇总了Python中torch.utils.data._utils.collate.default_collate方法的典型用法代码示例。如果您正苦于以下问题:Python collate.default_collate方法的具体用法?Python collate.default_collate怎么用?Python collate.default_collate使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在torch.utils.data._utils.collate的用法示例。


在下文中一共展示了collate.default_collate方法的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: __init__

# 需要导入模块: from torch.utils.data._utils import collate [as 别名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 别名]
def __init__(self, dataset, **kwargs):
    """Build a DataLoader over a ``SafeDataset``.

    ``drop_last`` is intercepted: the caller's value is remembered in
    ``self.drop_last_original`` and the parent DataLoader is always built with
    ``drop_last=False``. This is because ``_SafeDataLoaderIter`` applies
    ``drop_last`` itself, and the flag cannot be changed once the DataLoader
    exists.

    After construction, ``self.safe_dataset`` holds the SafeDataset passed in
    and ``self.dataset`` is replaced with an ``_OriginalDataset`` wrapper
    around it. If the stock ``default_collate`` is in use, it is swapped for
    ``SafeDataLoader._safe_default_collate``.

    Args:
        dataset: must be a ``SafeDataset`` instance (checked with ``assert``).
        **kwargs: forwarded to the parent DataLoader constructor.
    """
    assert isinstance(
        dataset, SafeDataset
    ), "dataset must be an instance of SafeDataset."

    # Remember what the caller asked for, then force the parent to keep
    # partial batches; only touch kwargs when the key was actually given.
    requested_drop_last = kwargs.get("drop_last", False)
    if "drop_last" in kwargs:
        kwargs["drop_last"] = False
    self.drop_last_original = requested_drop_last

    super(SafeDataLoader, self).__init__(dataset, **kwargs)

    # NOTE(review): recent torch DataLoader versions raise when `dataset` is
    # reassigned after __init__; confirm this class overrides __setattr__ or
    # targets an older torch release.
    wrapped_safe_dataset = self.dataset
    self.safe_dataset = wrapped_safe_dataset
    self.dataset = _OriginalDataset(wrapped_safe_dataset)

    # Only replace the collate function when the caller kept the default one.
    uses_stock_collate = self.collate_fn is default_collate
    if uses_stock_collate:
        self.collate_fn = SafeDataLoader._safe_default_collate
开发者ID:msamogh,项目名称:nonechucks,代码行数:21,代码来源:dataloader.py

示例2: list_data_collate

# 需要导入模块: from torch.utils.data._utils import collate [as 别名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 别名]
def list_data_collate(batch):
    """Collate a batch whose samples may themselves be lists of samples.

    Some transforms return a list of samples for a single dataset item. When
    the first element of ``batch`` is a ``list``, every element is treated
    as such a list and the lists are flattened into one list of samples
    before calling ``default_collate``. Otherwise ``batch`` is passed to
    ``default_collate`` unchanged, which matches the stock behavior.

    Note:
        Use this collate function whenever a transform can emit batch data.
    """
    if not isinstance(batch[0], list):
        return default_collate(batch)

    merged = []
    for sample_list in batch:
        merged.extend(sample_list)
    return default_collate(merged)
开发者ID:Project-MONAI,项目名称:MONAI,代码行数:15,代码来源:utils.py

示例3: mixup_collate

# 需要导入模块: from torch.utils.data._utils import collate [as 别名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 别名]
def mixup_collate(data, alpha=0.1):
    """Implements a batch collate function with MixUp strategy from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_

    The samples are first collated with ``default_collate`` into
    ``(inputs, targets)``. Each input is then blended with another input of
    the same batch, chosen by a random permutation ``perm``:
    ``lam * inputs + (1 - lam) * inputs[perm]``.

    Args:
        data (list): list of ``(input, target)`` elements
        alpha (float, optional): mixup factor. When ``alpha > 0``, ``lam`` is
            drawn from ``Beta(alpha, alpha)`` via numpy's global RNG;
            otherwise ``lam = 1`` (inputs are left unmixed).

    Returns:
        tuple: ``(inputs, targets_a, targets_b, lam)``.

        - ``inputs``: the mixed inputs.
        - ``targets_a``: the original targets.
        - ``targets_b``: the targets of the permuted samples.
        - ``lam``: the weight given to ``targets_a``.

    Example::
        >>> import torch
        >>> from holocron import utils
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size, collate_fn=utils.data.mixup_collate)
    """

    inputs, targets = default_collate(data)

    # Sample lambda. Convert numpy's scalar to a plain float so the
    # arithmetic below is ordinary tensor * Python-scalar arithmetic.
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1

    # Random pairing of samples within the batch
    index = torch.randperm(inputs.size(0))

    # Index only the batch dimension: `inputs[index, :]` fails on 1-D inputs
    # (one scalar per sample), while `inputs[index]` works for any rank.
    inputs = lam * inputs + (1 - lam) * inputs[index]
    targets_a, targets_b = targets, targets[index]

    return inputs, targets_a, targets_b, lam
开发者ID:frgfm,项目名称:Holocron,代码行数:33,代码来源:collate.py

示例4: collate

# 需要导入模块: from torch.utils.data._utils import collate [as 别名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 别名]
def collate(batch, *, root=True):
    """Collate a list of samples, keeping non-tensor fields as plain lists.

    Dispatches on the type of the first sample:

    * tensors and numpy values (except numpy string scalars) are passed to
      ``default_collate``;
    * ints, floats, strings and ``CameraIntrinsics`` objects are returned as
      the original list;
    * mappings are collated key by key (recursing with ``root=False``) at the
      root level, and returned unchanged below it;
    * sequences are mapped element-wise, collating each sample's own sequence
      with ``root=False``.

    Args:
        batch: list of samples. An empty batch is returned unchanged.
        root: whether this is the top-level call.

    Raises:
        TypeError: if the first sample's type is not handled.
    """
    if len(batch) == 0:
        return batch

    first = batch[0]
    first_type = type(first)

    if torch.is_tensor(first) or (
        first_type.__module__ == 'numpy'
        and first_type.__name__ not in ('str_', 'string_')
    ):
        return default_collate(batch)

    # Scalars, strings and camera intrinsics are left as a list.
    if isinstance(first, (int_classes, float, string_classes, CameraIntrinsics)):
        return batch

    if isinstance(first, Mapping):
        if not root:
            return batch
        return {
            key: collate([sample[key] for sample in batch], root=False)
            for key in first
        }

    if isinstance(first, Sequence):
        return [collate(item, root=False) for item in batch]

    raise TypeError(
        "batch must contain tensors, numbers, dicts or lists; found {}".format(first_type)
    )
开发者ID:anibali,项目名称:margipose,代码行数:32,代码来源:__init__.py

示例5: detection_collate

# 需要导入模块: from torch.utils.data._utils import collate [as 别名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 别名]
def detection_collate(batch):
    """
    Collate function for detection task. Concatenate bboxes, labels and
    metadata from different samples in the first dimension instead of
    stacking them to have a batch-size dimension.

    Args:
        batch (tuple or list): data batch to collate; each sample is a
            4-tuple ``(inputs, labels, video_idx, extra_data)`` where
            ``extra_data`` is a dict.
    Returns:
        (tuple): ``(inputs, labels, video_idx, collated_extra_data)``.

        - ``inputs`` and ``video_idx`` go through ``default_collate``.
        - ``labels`` are concatenated along axis 0 into a float tensor.
        - ``"boxes"`` and ``"ori_boxes"`` get the sample index prepended as
          an extra first column, then are concatenated into a float tensor.
        - ``"metadata"`` is flattened one level and viewed as ``(-1, 2)``.
        - Any other key goes through ``default_collate``.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs = default_collate(inputs)
    video_idx = default_collate(video_idx)
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0]:
        values = [sample_extra[key] for sample_extra in extra_data]
        if key in ("boxes", "ori_boxes"):
            # Prefix every box with the index of the sample it came from, so
            # boxes stay attributable after concatenation.
            tagged = [
                np.concatenate(
                    [np.full((boxes.shape[0], 1), float(sample_idx)), boxes], axis=1
                )
                for sample_idx, boxes in enumerate(values)
            ]
            collated_extra_data[key] = torch.tensor(
                np.concatenate(tagged, axis=0)
            ).float()
        elif key == "metadata":
            flat_metadata = list(itertools.chain.from_iterable(values))
            collated_extra_data[key] = torch.tensor(flat_metadata).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(values)

    return inputs, labels, video_idx, collated_extra_data
开发者ID:facebookresearch,项目名称:SlowFast,代码行数:37,代码来源:loader.py

示例6: collate_batches

# 需要导入模块: from torch.utils.data._utils import collate [as 别名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 别名]
def collate_batches(batches, collate_fn=default_collate):
    """Merge several already-collated batches into a single batch.

    Dispatches on the type of the first batch:

    * tensors are concatenated along dimension 0;
    * sequences are chained, in order, into one flat list;
    * mappings are merged key by key, passing each key's list of values to
      ``collate_fn``.

    Args:
        batches: non-empty sequence of batches of the same kind.
        collate_fn: function applied to the list of values gathered for each
            key of mapping batches. Defaults to ``default_collate``. (It was
            previously accepted but silently ignored.)

    Returns:
        The merged batch (tensor, list or dict).

    Raises:
        TypeError: if the first batch is not a tensor, sequence or mapping.
    """
    # The collections.Sequence / collections.Mapping aliases were removed in
    # Python 3.10; the abstract base classes live in collections.abc.
    from collections.abc import Mapping, Sequence
    from itertools import chain

    error_msg = "batches must be tensors, dicts, or lists; found {}"
    first = batches[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(batches, 0)
    if isinstance(first, Sequence):
        return list(chain(*batches))
    if isinstance(first, Mapping):
        return {key: collate_fn([d[key] for d in batches]) for key in first}
    raise TypeError(error_msg.format(type(first)))
开发者ID:msamogh,项目名称:nonechucks,代码行数:12,代码来源:utils.py

示例7: _safe_default_collate

# 需要导入模块: from torch.utils.data._utils import collate [as 别名]
# 或者: from torch.utils.data._utils.collate import default_collate [as 别名]
def _safe_default_collate(batch):
        filtered_batch = [x for x in batch if x is not None]
        if len(filtered_batch) == 0:
            return []
        return default_collate(filtered_batch) 
开发者ID:msamogh,项目名称:nonechucks,代码行数:7,代码来源:dataloader.py


注:本文中的torch.utils.data._utils.collate.default_collate方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。