本文整理汇总了Python中torch.utils.data.dataloader.default_collate方法的典型用法代码示例。如果您正苦于以下问题:Python dataloader.default_collate方法的具体用法?Python dataloader.default_collate怎么用?Python dataloader.default_collate使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.utils.data.dataloader的用法示例。
在下文中一共展示了dataloader.default_collate方法的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: ctc_collate
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def ctc_collate(batch):
    '''
    Stack samples into CTC style inputs.
    Modified based on default_collate() in PyTorch.
    By Yuan-Hang Zhang.

    Args:
        batch: sequence of ``(x, label_seq, length, index)`` samples. Every
            ``x`` must have the same shape so ``default_collate`` can stack
            them, and its dim 1 must be at least ``max(length)`` long.

    Returns:
        tuple ``(x, y, lengths, y_lengths, ids)``:
            x: stacked inputs, trimmed along dim 2 to ``max(length)``.
            y: every label sequence concatenated into one flat IntTensor.
            lengths: IntTensor of the per-sample ``length`` values.
            y_lengths: IntTensor with the length of each label sequence.
            ids: the collated ``index`` values.
    '''
    xs, ys, lens, indices = zip(*batch)
    max_len = max(lens)
    x = default_collate(xs)
    # ``narrow`` returns a new view instead of trimming in place, so the
    # result must be reassigned for the trim to take effect. Dim 2 of the
    # stacked batch (dim 1 of each sample) is presumably the time axis --
    # TODO confirm against the dataset.
    x = x.narrow(2, 0, max_len)
    # Targets are flattened into a single 1-D tensor; per-sample boundaries
    # are recovered from ``y_lengths``.
    y = torch.IntTensor([label for sub in ys for label in sub])
    lengths = torch.IntTensor(lens)
    y_lengths = torch.IntTensor([len(label) for label in ys])
    ids = default_collate(indices)
    return x, y, lengths, y_lengths, ids
示例2: collate
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collate(self, batch):
    """Recursively merge a list of samples into a mini-batch.

    The type of the first sample selects the strategy. Checks run in this
    fixed order and the first match wins: ``Data``, ``torch.Tensor``,
    ``float``, ``int_classes``, ``string_classes``, mapping, namedtuple,
    other sequence.

    Raises:
        TypeError: if the first sample matches none of the supported types.
    """
    first = batch[0]

    if isinstance(first, Data):
        return Batch.from_data_list(batch, self.follow_batch)
    if isinstance(first, torch.Tensor):
        return default_collate(batch)
    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float)
    if isinstance(first, int_classes):
        return torch.tensor(batch)
    if isinstance(first, string_classes):
        # Strings are passed through as the original list.
        return batch
    if isinstance(first, container_abcs.Mapping):
        # Collate each field independently across the batch.
        return {key: self.collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # namedtuple: rebuild the same type from the per-field results.
        fields = (self.collate(column) for column in zip(*batch))
        return type(first)(*fields)
    if isinstance(first, container_abcs.Sequence):
        return [self.collate(column) for column in zip(*batch)]

    raise TypeError('DataLoader found invalid type: {}'.format(type(first)))
示例3: collater
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collater(self, samples):
    """Collate a list of nested-dict samples into one mini-batch.

    For every ``(key, ds)`` pair in ``self.defn``, ``ds.collater`` is given
    the values of ``key`` from all samples. If it raises
    ``NotImplementedError``, PyTorch's ``default_collate`` is used for that
    key instead.

    Args:
        samples (List[dict]): samples to collate

    Returns:
        dict: a mini-batch suitable for forwarding with a Model; ``{}`` when
        ``samples`` is empty.
    """
    if len(samples) == 0:
        return {}

    def column(key):
        # Fresh list of one key's values across all samples, in sample order.
        return [s[key] for s in samples]

    merged = OrderedDict()
    for key, sub_dataset in self.defn.items():
        try:
            merged[key] = sub_dataset.collater(column(key))
        except NotImplementedError:
            merged[key] = default_collate(column(key))
    # ``_unflatten`` is a module helper; presumably it re-nests flat keys.
    return _unflatten(merged)
示例4: ava_collate_fn
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def ava_collate_fn(batch):
    """Zero-pad every sample to the longest one in ``batch`` and collate.

    Each sample is indexed as ``sample[0]`` (a 4-D frame array whose first
    axis is the per-sample length), ``sample[1]`` (a 2-D label array,
    presumably one row per frame) and ``sample[2]`` (passed through as-is).

    The frames and labels are copied into float32 zero buffers whose first
    axis is the batch maximum. A float32 mask is 1 on real positions and 0
    on padding.

    Returns:
        ``default_collate`` of ``[video_to_tensor(frames), mask, labels,
        sample[2]]`` built for each sample.
    """
    longest = max([sample[0].shape[0] for sample in batch], default=0)

    padded = []
    for sample in batch:
        frames, labels = sample[0], sample[1]
        n_frames = frames.shape[0]
        video = np.zeros((longest, frames.shape[1], frames.shape[2], frames.shape[3]), np.float32)
        mask = np.zeros(longest, np.float32)
        target = np.zeros((longest, labels.shape[1]), np.float32)
        video[:n_frames] = frames
        mask[:n_frames] = 1
        # Labels are written up to the frame count, not their own length.
        target[:n_frames, :] = labels
        padded.append([video_to_tensor(video), torch.from_numpy(mask), torch.from_numpy(target), sample[2]])
    return default_collate(padded)
示例5: mt_collate_fn
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def mt_collate_fn(batch):
    """Pad every sample in ``batch`` to a common length, then collate.

    ``sample[0]`` is a 4-D frame array whose first axis is the sample length.
    ``sample[1]`` is a 2-D label array, presumably one row per frame.
    ``sample[2]`` is forwarded unchanged.

    All samples are zero-padded along the first axis to the longest length
    in the batch. A float32 mask marks real (1) and padded (0) positions.
    """
    max_len = 0
    for sample in batch:
        max_len = max(max_len, sample[0].shape[0])

    def _pad(sample):
        # Copy one sample into zero-initialised float32 buffers of length max_len.
        clip, frame_labels = sample[0], sample[1]
        used = clip.shape[0]
        clip_buf = np.zeros((max_len, clip.shape[1], clip.shape[2], clip.shape[3]), np.float32)
        mask_buf = np.zeros(max_len, np.float32)
        label_buf = np.zeros((max_len, frame_labels.shape[1]), np.float32)
        clip_buf[:used] = clip
        mask_buf[:used] = 1
        label_buf[:used, :] = frame_labels
        return [video_to_tensor(clip_buf), torch.from_numpy(mask_buf), torch.from_numpy(label_buf), sample[2]]

    return default_collate([_pad(sample) for sample in batch])
示例6: __init__
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    """Build train/validation samplers, then initialise the parent DataLoader.

    ``self._split_sampler`` runs after ``n_samples`` and ``shuffle`` are set
    and before the loader kwargs are assembled. The ``shuffle`` forwarded to
    the parent is therefore whatever ``self.shuffle`` holds after that call.
    The kwargs are kept on ``self.init_kwargs``.
    """
    self.validation_split = validation_split
    self.shuffle = shuffle
    self.batch_idx = 0
    self.n_samples = len(dataset)

    # Must precede building the kwargs: it may read/update the attributes above.
    samplers = self._split_sampler(self.validation_split)
    self.sampler, self.valid_sampler = samplers

    loader_kwargs = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=self.shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )
    self.init_kwargs = loader_kwargs
    super().__init__(sampler=self.sampler, **self.init_kwargs)
示例7: make_collate_fn
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def make_collate_fn(padding_values):
    """Build a collate function that right-pads variable-length fields.

    Args:
        padding_values: mapping from a sample key to how that field is padded
            along its first axis. The string ``"edge"`` repeats the last
            entry (``np.pad`` ``mode="edge"``). Any other value is used as
            the constant fill (``mode="constant"``).

    Returns:
        A ``collate_fn(batch)`` that takes a list of sample dicts and returns
        ``default_collate`` of the padded samples. Each listed field (a NumPy
        array, since ``.ndim`` and ``np.pad`` are used) is padded to the
        longest length in the batch. The incoming sample dicts are
        shallow-copied first, so the caller's dicts are never modified.
    """
    import copy

    def _collate_fn(batch):
        # Shallow-copy each sample: replacing a field must not write back into
        # dicts the dataset may still hold. Writing back would make the
        # padding persist across batches if the dataset caches samples.
        batch = [copy.copy(sample) for sample in batch]
        for name, padding_value in padding_values.items():
            max_length = max(len(sample[name]) for sample in batch)
            for sample in batch:
                field = sample[name]
                pad = max_length - len(field)
                if not pad:
                    continue
                # Pad only the end of the first axis; other axes are untouched.
                pad_width = [(0, pad)] + [(0, 0)] * (field.ndim - 1)
                if padding_value == "edge":
                    sample[name] = np.pad(field, pad_width, mode="edge")
                else:
                    sample[name] = np.pad(
                        field, pad_width,
                        mode="constant", constant_values=padding_value)
        return default_collate(batch)

    return _collate_fn
示例8: detection_collate
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def detection_collate(batch):
    """Collate a batch of sample dicts, keeping ``custom_keys`` as plain lists.

    Every key of ``batch[0]`` except those in ``custom_keys`` (currently only
    ``'target_size'``) is merged with ``default_collate``. The values of the
    custom keys are returned as-is in a list, one entry per sample, rather
    than being stacked.

    Arguments:
        batch: (list of dict) samples that share the keys of ``batch[0]`` and
            each contain every key in ``custom_keys``.
    Return:
        dict: the default-collated keys (in ``batch[0]`` order) followed by
        each custom key mapped to a list of per-sample values.
    Raises:
        KeyError: if a sample lacks a custom key.
    """
    custom_keys = ['target_size', ]
    custom = {k: [sample[k] for sample in batch] for k in custom_keys}
    other = {k: default_collate([b[k] for b in batch]) for k in batch[0] if k not in custom}
    return {**other, **custom}
示例9: custom_collate
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def custom_collate(batch, del_orig_labels=False):
    """Collate sample dicts, passing a few per-batch metadata keys through.

    The keys ``target_size``, ``target_size_feats``, ``alphas`` and
    ``target_level`` (whichever are present in ``batch[0]``) are read from the
    first sample only. They are removed from every sample before
    ``default_collate`` runs, then attached to the collated batch unchanged.
    ``mux_indices``, when present, is flattened with ``view(-1)`` per sample
    before collating.

    The sample dicts are shallow-copied first, so the caller's dicts are
    never modified.

    Args:
        batch: list of sample dicts.
        del_orig_labels: if True, drop ``'original_labels'`` from each sample;
            every sample must then contain that key (KeyError otherwise).

    Returns:
        dict: the default-collated batch plus the pass-through metadata keys.
    """
    import copy

    keys = ['target_size', 'target_size_feats', 'alphas', 'target_level']
    values = {k: batch[0][k] for k in keys if k in batch[0]}
    samples = []
    for b in batch:
        # Shallow copy so deleting/replacing keys leaves the source dict intact.
        b = copy.copy(b)
        if del_orig_labels:
            del b['original_labels']
        for k in values:
            del b[k]
        if 'mux_indices' in b:
            b['mux_indices'] = b['mux_indices'].view(-1)
        samples.append(b)
    collated = default_collate(samples)
    for k, v in values.items():
        collated[k] = v
    return collated
示例10: collate_minibatch
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collate_minibatch(list_of_blobs):
    """Split samples into minibatches and collate each minibatch separately.

    A batch contains NUM_GPUS minibatches, and image sizes may differ between
    minibatches. Each consecutive group of ``cfg.TRAIN.IMS_PER_BATCH``
    samples is therefore padded with ``pad_image_data`` and collated on its
    own.

    ``'roidb'`` is popped from each input dict (mutating it). It is
    re-attached to each minibatch as a plain list, because its entries vary
    in length and cannot be stacked into a tensor.

    Returns:
        dict mapping every key of ``list_of_blobs[0]`` to a list with one
        entry per minibatch.
    """
    # Build the output skeleton before 'roidb' is popped so it keeps that key.
    batches = {key: [] for key in list_of_blobs[0]}
    roidbs = [blob.pop('roidb') for blob in list_of_blobs]

    per_batch = cfg.TRAIN.IMS_PER_BATCH
    for start in range(0, len(list_of_blobs), per_batch):
        stop = start + per_batch
        minibatch = default_collate(pad_image_data(list_of_blobs[start:stop]))
        minibatch['roidb'] = roidbs[start:stop]
        for key, value in minibatch.items():
            batches[key].append(value)
    return batches
示例11: decimal_friendly_collate
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def decimal_friendly_collate(batch):
    """A wrapper on top of ``default_collate`` function that allows decimal.Decimal types to be collated.

    We use ``decimal.Decimal`` types in petastorm dataset to represent timestamps. PyTorch's ``default_collate``
    implementation does not support collating ``decimal.Decimal`` types. ``decimal_friendly_collate`` collates
    ``decimal.Decimal`` separately and then combines with the rest of the fields collated by a standard
    ``default_collate``.

    Mappings and sequences are handled recursively, field by field. ``Decimal`` and string values are returned
    as the unmodified batch sequence. Everything else is delegated to ``default_collate``.

    :param batch: A list of dictionaries to collate
    :return: A dictionary of lists/pytorch.Tensor types
    """
    # The ``collections.Mapping``/``collections.Sequence`` aliases were removed
    # in Python 3.10; the ABCs live in ``collections.abc``.
    from collections.abc import Mapping, Sequence

    first = batch[0]
    if isinstance(first, decimal.Decimal):
        return batch
    elif isinstance(first, Mapping):
        return {key: decimal_friendly_collate([d[key] for d in batch]) for key in first}
    elif isinstance(first, (str, bytes)):
        # Same as the old ``torch._six.string_classes`` on Python 3. It must be
        # tested before Sequence, because strings are sequences.
        return batch
    elif isinstance(first, Sequence):
        return [decimal_friendly_collate(samples) for samples in zip(*batch)]
    else:
        return default_collate(batch)
示例12: __init__
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers,
             collate_fn=default_collate):
    """Prepare samplers and construct the underlying DataLoader.

    ``self._split_sampler`` is invoked only after ``validation_split``,
    ``shuffle``, ``batch_idx`` and ``n_samples`` are stored. The ``shuffle``
    passed to the parent is read from ``self.shuffle`` afterwards, so any
    change that call makes is respected. The loader kwargs are saved on
    ``self.init_kwargs``.
    """
    self.validation_split = validation_split
    self.shuffle = shuffle
    self.batch_idx = 0
    self.n_samples = len(dataset)

    train_sampler, valid_sampler = self._split_sampler(self.validation_split)
    self.sampler = train_sampler
    self.valid_sampler = valid_sampler

    self.init_kwargs = dict(dataset=dataset, batch_size=batch_size, shuffle=self.shuffle,
                            collate_fn=collate_fn, num_workers=num_workers)
    super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
示例13: collate_minibatch
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def collate_minibatch(list_of_blobs):
    """Stack samples separately and return a list of minibatches.

    A batch contains NUM_GPUS minibatches, and image sizes in different
    minibatches may differ, so samples from each minibatch are stacked
    separately. Only the ``'data'``, ``'rois'`` and ``'labels'`` entries of
    each sample are collated, in consecutive groups of
    ``cfg.TRAIN.IMS_PER_BATCH``.

    Returns:
        dict mapping every key of ``list_of_blobs[0]`` to a list with one
        collated value per minibatch. Keys other than ``'data'``, ``'rois'``
        and ``'labels'`` are present but map to empty lists.

    Raises:
        KeyError: if a sample lacks one of the collated keys.
    """
    collated_keys = ('data', 'rois', 'labels')
    Batch = {key: [] for key in list_of_blobs[0]}
    # Select the collated fields by indexing rather than popping, so the
    # caller's sample dicts are left intact.
    lists = [{key: blobs[key] for key in collated_keys} for blobs in list_of_blobs]
    for i in range(0, len(lists), cfg.TRAIN.IMS_PER_BATCH):
        mini_list = lists[i:(i + cfg.TRAIN.IMS_PER_BATCH)]
        minibatch = default_collate(mini_list)
        for key in minibatch:
            Batch[key].append(minibatch[key])
    return Batch
示例14: __init__
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
    """Initialise sampler state and forward the loader kwargs to the parent.

    The order matters: the plain attributes are stored first, then
    ``self._split_sampler`` produces the train/validation samplers. Only
    after that are the kwargs built, so they pick up the current
    ``self.shuffle``. The kwargs are kept on ``self.init_kwargs``.
    """
    self.validation_split = validation_split
    self.shuffle = shuffle
    self.batch_idx = 0
    self.n_samples = len(dataset)

    split = self._split_sampler(self.validation_split)
    self.sampler, self.valid_sampler = split

    kwargs = {}
    kwargs['dataset'] = dataset
    kwargs['batch_size'] = batch_size
    kwargs['shuffle'] = self.shuffle
    kwargs['collate_fn'] = collate_fn
    kwargs['num_workers'] = num_workers
    self.init_kwargs = kwargs

    super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
示例15: __init__
# 需要导入模块: from torch.utils.data import dataloader [as 别名]
# 或者: from torch.utils.data.dataloader import default_collate [as 别名]
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
             num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
             timeout=0, base_seed=None, worker_init_fn=None, worker_init_args=None, worker_init_kwargs=None,
             worker_recv_fn=None, **kwargs):
    """Construct a DataLoader whose worker init function is wrapped.

    Missing per-worker init args and kwargs default to one empty tuple and
    one empty dict per worker. A missing ``base_seed`` is drawn from
    ``gen_seed()``.

    When ``worker_recv_fn`` is given, a ``DataLoaderPipeMaster`` is created on
    ``self.pipe_master``; otherwise it is ``None``. The user's
    ``worker_init_fn`` is wrapped in ``_InitFunctionWrapper`` together with
    the seed, the per-worker args, and both pipe ends. That wrapper is what
    the parent DataLoader receives.
    """
    if worker_init_args is None:
        worker_init_args = [tuple() for _ in range(num_workers)]
    if worker_init_kwargs is None:
        worker_init_kwargs = [{} for _ in range(num_workers)]
    if base_seed is None:
        base_seed = gen_seed()

    self.worker_recv_fn = worker_recv_fn
    self.pipe_master = None if worker_recv_fn is None else DataLoaderPipeMaster(num_workers)

    wrapped_init = _InitFunctionWrapper(
        base_seed, worker_init_fn, worker_init_args, worker_init_kwargs,
        self.pipe_master, DataLoaderPipeSlave(worker_recv_fn)
    )
    super().__init__(
        dataset,
        batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
        num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last,
        timeout=timeout, worker_init_fn=wrapped_init, **kwargs
    )