This page collects typical usage examples of the Python method torch.utils.data._utils.collate.default_collate. If you are unsure how collate.default_collate is used in practice, the curated code examples below may help; you can also explore further usage examples from the module torch.utils.data._utils.collate to which the method belongs.
Seven code examples of the collate.default_collate method are shown below, sorted by popularity by default.
Example 1: __init__
# Required module import: from torch.utils.data._utils import collate [as alias]
# Or: from torch.utils.data._utils.collate import default_collate [as alias]
def __init__(self, dataset, **kwargs):
    # drop_last is handled transparently by _SafeDataLoaderIter (bypassing
    # DataLoader). Since drop_last cannot be changed after initializing the
    # DataLoader instance, it needs to be intercepted here.
    assert isinstance(
        dataset, SafeDataset
    ), "dataset must be an instance of SafeDataset."
    self.drop_last_original = False
    if "drop_last" in kwargs:
        self.drop_last_original = kwargs["drop_last"]
        kwargs["drop_last"] = False
    super(SafeDataLoader, self).__init__(dataset, **kwargs)
    self.safe_dataset = self.dataset
    self.dataset = _OriginalDataset(self.safe_dataset)
    if self.collate_fn is default_collate:
        self.collate_fn = SafeDataLoader._safe_default_collate
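Below is a minimal usage sketch, not part of the original example: it assumes SafeDataset and SafeDataLoader are the wrapper classes this __init__ belongs to, and the toy MaybeBrokenDataset is purely hypothetical.

import torch
from torch.utils.data import Dataset

class MaybeBrokenDataset(Dataset):
    """Hypothetical dataset in which some samples fail to load."""
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        if idx % 10 == 0:
            raise ValueError("corrupt sample")  # simulates an unreadable sample
        return torch.randn(3), idx % 2

safe_ds = SafeDataset(MaybeBrokenDataset())                     # wrapper that skips failing samples
loader = SafeDataLoader(safe_ds, batch_size=8, drop_last=True)  # drop_last intercepted in __init__
for inputs, labels in loader:
    pass  # each batch contains only samples that loaded successfully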
Example 2: list_data_collate
# Required module import: from torch.utils.data._utils import collate [as alias]
# Or: from torch.utils.data._utils.collate import default_collate [as alias]
def list_data_collate(batch):
    """
    Enhancement of the default PyTorch DataLoader collate.
    If the dataset already returns a list of batch data generated by transforms,
    all items are first merged into a single list, after which the default
    collate behavior applies.

    Note:
        Use this collate function when applying transforms that can generate batch data.
    """
    elem = batch[0]
    data = [i for k in batch for i in k] if isinstance(elem, list) else batch
    return default_collate(data)
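A minimal sketch of where this collate helps (the ListDataset below is hypothetical): each __getitem__ returns a list of samples, e.g. produced by a patch-generating transform, and list_data_collate flattens those lists before falling back to default collation.

import torch
from torch.utils.data import Dataset, DataLoader

class ListDataset(Dataset):
    """Hypothetical dataset whose transform yields two patches per item."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return [torch.ones(3) * idx, torch.ones(3) * idx + 0.5]

loader = DataLoader(ListDataset(), batch_size=2, collate_fn=list_data_collate)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 3]) -- 2 items x 2 patches, flattened then stacked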
Example 3: mixup_collate
# Required module import: from torch.utils.data._utils import collate [as alias]
# Or: from torch.utils.data._utils.collate import default_collate [as alias]
def mixup_collate(data, alpha=0.1):
    """Implements a batch collate function with MixUp strategy from
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/pdf/1710.09412.pdf>`_

    Args:
        data (list): list of elements
        alpha (float, optional): mixup factor

    Example::
        >>> import torch
        >>> from holocron import utils
        >>> loader = torch.utils.data.DataLoader(dataset, batch_size, collate_fn=utils.data.mixup_collate)
    """
    inputs, targets = default_collate(data)

    # Sample lambda
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # Mix batch indices
    batch_size = inputs.size()[0]
    index = torch.randperm(batch_size)

    # Create the new input and targets
    inputs = lam * inputs + (1 - lam) * inputs[index, :]
    targets_a, targets_b = targets, targets[index]

    return inputs, targets_a, targets_b, lam
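A quick sketch with dummy tensors (the dataset shapes are made up) showing how mixup_collate can be plugged into a DataLoader via functools.partial and what it returns:

import numpy as np
import torch
from functools import partial
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 10, (32,)))
loader = DataLoader(dataset, batch_size=8, collate_fn=partial(mixup_collate, alpha=0.2))
inputs, targets_a, targets_b, lam = next(iter(loader))
# A typical training step would then combine both targets:
#   loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)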
Example 4: collate
# Required module import: from torch.utils.data._utils import collate [as alias]
# Or: from torch.utils.data._utils.collate import default_collate [as alias]
def collate(batch, *, root=True):
    "Puts each data field into a tensor with outer dimension batch size"
    if len(batch) == 0:
        return batch

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if torch.is_tensor(batch[0]):
        return default_collate(batch)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        return default_collate(batch)
    elif isinstance(batch[0], int_classes):
        return batch
    elif isinstance(batch[0], float):
        return batch
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], CameraIntrinsics):
        return batch
    elif isinstance(batch[0], Mapping):
        if root:
            return {key: collate([d[key] for d in batch], root=False) for key in batch[0]}
        else:
            return batch
    elif isinstance(batch[0], Sequence):
        return [collate(e, root=False) for e in batch]
    raise TypeError(error_msg.format(type(batch[0])))
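A small sketch with a dummy batch (the field names are invented; CameraIntrinsics, int_classes, string_classes, Mapping, and Sequence are assumed to be imported as in the original project):

import torch

batch = [
    {"image": torch.randn(3, 4, 4), "joints": [torch.randn(2), torch.randn(2)], "subject": "S1"},
    {"image": torch.randn(3, 4, 4), "joints": [torch.randn(2), torch.randn(2)], "subject": "S2"},
]
out = collate(batch)
# out["image"]   -> stacked tensor of shape [2, 3, 4, 4]
# out["joints"]  -> list with one collated tensor of shape [2, 2] per sample
# out["subject"] -> plain list ["S1", "S2"]; strings are passed through untouched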
Example 5: detection_collate
# Required module import: from torch.utils.data._utils import collate [as alias]
# Or: from torch.utils.data._utils.collate import default_collate [as alias]
def detection_collate(batch):
    """
    Collate function for the detection task. Concatenate bboxes, labels and
    metadata from different samples in the first dimension instead of
    stacking them to have a batch-size dimension.

    Args:
        batch (tuple or list): data batch to collate.
    Returns:
        (tuple): collated detection data batch.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs, video_idx = default_collate(inputs), default_collate(video_idx)
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0].keys():
        data = [d[key] for d in extra_data]
        if key == "boxes" or key == "ori_boxes":
            # Append idx info to the bboxes before concatenating them.
            bboxes = [
                np.concatenate(
                    [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1
                )
                for i in range(len(data))
            ]
            bboxes = np.concatenate(bboxes, axis=0)
            collated_extra_data[key] = torch.tensor(bboxes).float()
        elif key == "metadata":
            collated_extra_data[key] = torch.tensor(
                list(itertools.chain(*data))
            ).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(data)

    return inputs, labels, video_idx, collated_extra_data
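A minimal sketch with dummy numpy data (the clip shape, 80-class labels, and the fake_sample helper are all invented for illustration); note that detection_collate above also relies on itertools, numpy, and torch being imported:

import itertools  # used by detection_collate above
import numpy as np
import torch

def fake_sample(video_idx, num_boxes):
    """Hypothetical sample: (clip, per-box labels, video index, extra data)."""
    inputs = torch.randn(3, 4, 8, 8)                       # C x T x H x W clip
    labels = np.zeros((num_boxes, 80), dtype=np.float32)   # one multi-hot row per box
    extra = {
        "boxes": np.random.rand(num_boxes, 4).astype(np.float32),
        "ori_boxes": np.random.rand(num_boxes, 4).astype(np.float32),
        "metadata": [[video_idx, i] for i in range(num_boxes)],
    }
    return inputs, labels, video_idx, extra

inputs, labels, video_idx, extra = detection_collate([fake_sample(0, 2), fake_sample(1, 3)])
print(inputs.shape, labels.shape)                     # [2, 3, 4, 8, 8] and [5, 80]
print(extra["boxes"].shape, extra["metadata"].shape)  # [5, 5] (idx column prepended) and [5, 2]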
Example 6: collate_batches
# Required module import: from torch.utils.data._utils import collate [as alias]
# Or: from torch.utils.data._utils.collate import default_collate [as alias]
def collate_batches(batches, collate_fn=default_collate):
    """Collate multiple batches."""
    error_msg = "batches must be tensors, dicts, or lists; found {}"
    if isinstance(batches[0], torch.Tensor):
        return torch.cat(batches, 0)
    elif isinstance(batches[0], collections.abc.Sequence):
        return list(chain(*batches))
    elif isinstance(batches[0], collections.abc.Mapping):
        # Collate each key with the provided collate_fn (default_collate by default).
        return {key: collate_fn([d[key] for d in batches]) for key in batches[0]}
    raise TypeError(error_msg.format(type(batches[0])))
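A small sketch merging two already-collated batches with dummy tensors; the collate_fn override shown assumes the parameter is honored for mapping values, as in the cleaned-up version above:

import collections            # used by collate_batches above
from itertools import chain   # used by collate_batches above
import torch

a = {"x": torch.randn(4, 3), "y": torch.arange(4)}
b = {"x": torch.randn(2, 3), "y": torch.arange(2)}
merged = collate_batches([a, b], collate_fn=lambda vals: torch.cat(vals, 0))
print(merged["x"].shape, merged["y"].shape)  # torch.Size([6, 3]) torch.Size([6])

# Plain tensor batches are concatenated along the batch dimension directly:
print(collate_batches([torch.randn(4, 3), torch.randn(2, 3)]).shape)  # torch.Size([6, 3])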
Example 7: _safe_default_collate
# Required module import: from torch.utils.data._utils import collate [as alias]
# Or: from torch.utils.data._utils.collate import default_collate [as alias]
def _safe_default_collate(batch):
    """Drop None entries (failed samples) before applying default_collate."""
    filtered_batch = [x for x in batch if x is not None]
    if len(filtered_batch) == 0:
        return []
    return default_collate(filtered_batch)
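A minimal sketch showing how None entries are dropped before collation; the example batch is made up, and the function is called directly as shown in this snippet (in Example 1 it is attached to SafeDataLoader as a static method):

import torch

batch = [(torch.ones(2), 0), None, (torch.zeros(2), 1), None]
inputs, labels = _safe_default_collate(batch)
print(inputs.shape, labels)                  # torch.Size([2, 2]) tensor([0, 1])
print(_safe_default_collate([None, None]))   # [] -- an all-None batch collapses to an empty list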