本文整理匯總了Python中torch._six.container_abcs.Mapping方法的典型用法代碼示例。如果您正苦於以下問題：Python container_abcs.Mapping方法的具體用法？Python container_abcs.Mapping怎麼用？Python container_abcs.Mapping使用的例子？那麼，這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊torch._six.container_abcs的用法示例。
注意：torch._six.container_abcs 只是標準庫 collections.abc 的別名，較新版本的 PyTorch 已不再提供它；在新代碼中請改用 `import collections.abc as container_abcs`。
在下文中一共展示了container_abcs.Mapping方法的10個代碼示例，這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚，您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: collate
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def collate(self, batch):
    """Recursively merge a list of samples into one mini-batch.

    The type of the first sample selects the merge strategy:

    * ``Data`` objects are combined with ``Batch.from_data_list``, forwarding
      ``self.follow_batch``;
    * tensors are handed to ``default_collate``;
    * Python floats become a ``torch.float`` tensor;
    * integers (``int_classes``) are passed straight to ``torch.tensor``;
    * strings (``string_classes``) are returned unchanged, i.e. the list itself;
    * mappings are collated key by key, using the keys of the first sample;
    * namedtuples are collated field by field and rebuilt as the same type;
    * any other sequence is transposed and collated position by position,
      producing a list.

    Raises:
        TypeError: if the first sample matches none of the types above.
    """
    first = batch[0]

    if isinstance(first, Data):
        return Batch.from_data_list(batch, self.follow_batch)
    if isinstance(first, torch.Tensor):
        return default_collate(batch)
    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float)
    if isinstance(first, int_classes):
        return torch.tensor(batch)
    if isinstance(first, string_classes):
        return batch
    if isinstance(first, container_abcs.Mapping):
        merged = {}
        for name in first:
            merged[name] = self.collate([sample[name] for sample in batch])
        return merged

    is_namedtuple = isinstance(first, tuple) and hasattr(first, '_fields')
    if is_namedtuple:
        # Transpose the samples, collate each field, then rebuild the same tuple type.
        fields = [self.collate(column) for column in zip(*batch)]
        return type(first)(*fields)
    if isinstance(first, container_abcs.Sequence):
        return [self.collate(column) for column in zip(*batch)]

    raise TypeError('DataLoader found invalid type: {}'.format(type(first)))
示例2: recursive_copy_to_gpu
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def recursive_copy_to_gpu(value: Any, non_blocking: bool = True) -> Any:
    """
    Recursively search lists, tuples and mappings and copy tensors to the GPU.

    Any object exposing a ``cuda`` attribute (such as a ``torch.Tensor``) is
    copied with ``value.cuda(non_blocking=non_blocking)``. Lists stay lists,
    tuples stay tuples (namedtuples keep their own type), and any ``Mapping``
    is rebuilt as a plain ``dict``. Every other value is passed through
    unchanged.

    Note: These are all copies, so if there are two objects that reference
    the same object, then after this call, there will be two different
    objects referenced on the GPU.

    Args:
        value: a tensor or an arbitrarily nested list/tuple/mapping of values.
        non_blocking: boolean flag forwarded to every ``.cuda()`` call.

    Returns:
        A structure mirroring ``value`` in which every object that has a
        ``cuda`` attribute has been replaced by the result of ``.cuda()``.
    """
    if hasattr(value, "cuda"):
        return value.cuda(non_blocking=non_blocking)
    if isinstance(value, (list, tuple)):
        copied = [recursive_copy_to_gpu(v, non_blocking=non_blocking) for v in value]
        if isinstance(value, list):
            return copied
        if hasattr(value, "_fields"):
            # namedtuple: its constructor takes the fields positionally, and
            # rebuilding it keeps attribute access working for callers.
            return type(value)(*copied)
        return tuple(copied)
    if isinstance(value, container_abcs.Mapping):
        return {
            key: recursive_copy_to_gpu(val, non_blocking=non_blocking)
            for key, val in value.items()
        }
    return value
示例3: recursive_to
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def recursive_to(item, device):
    # language=rst
    """
    Recursively move every tensor found in ``item`` to ``device``.

    :param item: A tensor, a string/int/float/bool, or an arbitrarily nested
        mapping / namedtuple / sequence of those.
    :param device: ``torch.device`` pointing to ``"cuda"`` or ``"cpu"``.
    :return: A copy of ``item`` in which every tensor was moved with
        ``.to(device)``. Strings, ints, floats and bools are returned as-is,
        mappings become plain ``dict`` objects, namedtuples keep their type,
        and every other sequence becomes a ``list``.
    :raises NotImplementedError: if a value of any other type is encountered.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device)

    if isinstance(item, (string_classes, int, float, bool)):
        return item

    if isinstance(item, container_abcs.Mapping):
        moved = {}
        for key in item:
            moved[key] = recursive_to(item[key], device)
        return moved

    is_namedtuple = isinstance(item, tuple) and hasattr(item, "_fields")
    if is_namedtuple:
        return type(item)(*[recursive_to(field, device) for field in item])

    if isinstance(item, container_abcs.Sequence):
        return list(map(lambda element: recursive_to(element, device), item))

    raise NotImplementedError(f"Target type {type(item)} not supported.")
示例4: fast_batch_collator
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def fast_batch_collator(batched_inputs):
    """
    A simple batch collator for most common reid tasks.

    The type of the first sample decides how the batch is merged:

    * tensors are stacked along a new leading batch dimension, keeping the
      inputs' dtype and device;
    * mappings are collated recursively, key by key, using the keys of the
      first sample;
    * floats become a ``torch.float64`` tensor;
    * integers (``int_classes``) are passed to ``torch.tensor``;
    * strings (``string_classes``) are returned unchanged as the input list.

    Args:
        batched_inputs: a non-empty list of samples sharing one structure.

    Returns:
        The collated batch.

    Raises:
        TypeError: if the first sample has an unsupported type.
    """
    elem = batched_inputs[0]
    if isinstance(elem, torch.Tensor):
        # torch.stack keeps the inputs' device; filling a torch.zeros buffer
        # always allocated the result on the CPU and failed for CUDA inputs.
        return torch.stack(batched_inputs, 0)
    elif isinstance(elem, container_abcs.Mapping):
        return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
    elif isinstance(elem, float):
        return torch.tensor(batched_inputs, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batched_inputs)
    elif isinstance(elem, string_classes):
        return batched_inputs
    # Fail loudly like the other collate functions instead of silently
    # returning None (which used to drop the field from the batch).
    raise TypeError(
        "fast_batch_collator: unsupported element type {}".format(type(elem))
    )
示例5: concatenate_cache
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def concatenate_cache(batch):
    r"""Merge a list of samples by *concatenating* them along dimension 0.

    Unlike the standard collate this adds no new batch dimension: tensors are
    joined with ``torch.cat`` and sequences are flattened one level into a
    single list. Everything else follows the usual collate rules:

    * numpy arrays (non-string, non-object dtype) are converted with
      ``torch.from_numpy`` and concatenated; numpy scalars go through
      ``numpy_type_map``;
    * floats become a ``torch.float64`` tensor, integers go to ``torch.tensor``;
    * strings are returned unchanged;
    * mappings are merged key by key, using the keys of the first sample;
    * namedtuples are merged field by field and rebuilt as the same type.

    Raises:
        TypeError: for numpy arrays of string/object dtype and for any
            unsupported sample type.
    """
    first = batch[0]
    first_type = type(first)

    if isinstance(first, torch.Tensor):
        # The main difference from default_collate: concatenate, don't stack.
        return torch.cat(batch, 0, out=None)

    if first_type.__module__ == 'numpy' and first_type.__name__ not in ('str_', 'string_'):
        if first_type.__name__ == 'ndarray':
            # Arrays of strings or Python objects cannot become tensors.
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(first.dtype))
            return concatenate_cache([torch.from_numpy(arr) for arr in batch])
        if first.shape == ():  # numpy scalars
            to_python = float if first.dtype.name.startswith('float') else int
            return numpy_type_map[first.dtype.name](list(map(to_python, batch)))
    elif isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(first, int_classes):
        return torch.tensor(batch)
    elif isinstance(first, string_classes):
        return batch
    elif isinstance(first, container_abcs.Mapping):
        return {name: concatenate_cache([sample[name] for sample in batch]) for name in first}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):
        columns = zip(*batch)
        return first_type(*(concatenate_cache(column) for column in columns))
    elif isinstance(first, container_abcs.Sequence):
        # Sequences are flattened one level instead of being transposed.
        return [item for sample in batch for item in sample]

    raise TypeError(error_msg_fmt.format(first_type))
示例6: default_collate
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def default_collate(batch):
    """Stack each data field of ``batch`` along a new outer batch dimension.

    Dispatches on the type of the first sample:

    * tensors: ``torch.stack`` along dim 0;
    * numpy arrays: converted with ``torch.from_numpy`` and stacked; numpy
      scalars: converted through ``numpy_type_map``;
    * ints (``int_classes``): ``torch.LongTensor``; floats: ``torch.DoubleTensor``;
    * strings: returned unchanged;
    * mappings: collated key by key, using the keys of the first sample;
    * other sequences: transposed and collated position by position into a list.

    Raises:
        TypeError: if the first sample's type is not supported.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    sample = batch[0]
    sample_type = type(sample)

    if isinstance(sample, torch.Tensor):
        return torch.stack(batch, 0)

    if (
        sample_type.__module__ == "numpy"
        and sample_type.__name__ not in ("str_", "string_")
    ):  # pragma: no cover
        if sample_type.__name__ == "ndarray":
            return torch.stack([torch.from_numpy(arr) for arr in batch], 0)
        if sample.shape == ():  # numpy scalars
            as_python = float if sample.dtype.name.startswith("float") else int
            return numpy_type_map[sample.dtype.name](list(map(as_python, batch)))
    elif isinstance(sample, int_classes):  # pragma: no cover
        return torch.LongTensor(batch)
    elif isinstance(sample, float):  # pragma: no cover
        return torch.DoubleTensor(batch)
    elif isinstance(sample, string_classes):  # pragma: no cover
        return batch
    elif isinstance(sample, container_abcs.Mapping):  # pragma: no cover
        return {key: default_collate([entry[key] for entry in batch]) for key in sample}
    elif isinstance(sample, container_abcs.Sequence):  # pragma: no cover
        return [default_collate(column) for column in zip(*batch)]

    raise TypeError(error_msg.format(sample_type))
示例7: update
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def update(self, buffers):
    r"""Update the :class:`~torch.nn.BufferDict` with the key-value pairs from a
    mapping or an iterable, overwriting existing keys.

    .. note::
        If :attr:`buffers` is an ``OrderedDict``, a :class:`~torch.nn.BufferDict`,
        or an iterable of key-value pairs, the order of new elements in it is
        preserved. Any other mapping is inserted in sorted key order.

    Arguments:
        buffers (iterable): a mapping (dictionary) from string to
            :class:`~torch.Tensor`, or an iterable of
            key-value pairs of type (string, :class:`~torch.Tensor`)

    Raises:
        TypeError: if ``buffers`` is not iterable, or if an element of a
            key/value sequence is not iterable.
        ValueError: if an element of a key/value sequence does not have
            exactly two items.
    """
    if not isinstance(buffers, container_abcs.Iterable):
        raise TypeError(
            "BufferDict.update should be called with an "
            "iterable of key/value pairs, but got " + type(buffers).__name__
        )

    if isinstance(buffers, container_abcs.Mapping):
        items = buffers.items()
        if not isinstance(buffers, (OrderedDict, BufferDict)):
            # Mappings without a meaningful order are inserted sorted by key.
            items = sorted(items)
        for key, buffer in items:
            self[key] = buffer
        return

    for j, p in enumerate(buffers):
        if not isinstance(p, container_abcs.Iterable):
            raise TypeError(
                "BufferDict update sequence element "
                "#" + str(j) + " should be Iterable; is " + type(p).__name__
            )
        if len(p) != 2:
            raise ValueError(
                "BufferDict update sequence element "
                "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
            )
        self[p[0]] = p[1]
示例8: time_aware_collate
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def time_aware_collate(batch):
    # language=rst
    """
    Puts each data field into a tensor with dimensions ``[time, batch size, ...]``

    Interpretation of dimensions being input:

    - 0 dim (,) - (1, batch_size, 1)
    - 1 dim (time,) - (time, batch_size, 1)
    - >=2 dim (time, n_0, ...) - (time, batch_size, n_0, ...)

    Non-tensor samples follow the default collate rules: numpy arrays are
    converted with ``torch.as_tensor`` and collated again, numpy scalars
    become one tensor, floats become a ``torch.float64`` tensor, ints are
    passed to ``torch.tensor``, strings are returned unchanged, mappings and
    namedtuples are collated per key / field, and other sequences are
    transposed and collated position by position into a list.

    :param batch: list of samples; the type of ``batch[0]`` selects the branch.
    :return: the collated batch.
    :raises TypeError: for numpy arrays with a string/object dtype and for any
        unsupported sample type.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        # catch 0 and 1 dimension cases and view as specified
        if elem.dim() == 0:
            # scalar -> (1, 1); stacking on dim 1 gives (1, batch_size, 1)
            batch = [x.view((1, 1)) for x in batch]
        elif elem.dim() == 1:
            # (time,) -> (time, 1); stacking on dim 1 gives (time, batch_size, 1)
            batch = [x.view((x.shape[0], 1)) for x in batch]
        out = None
        if safe_worker_check():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            # (numel is summed over the already-reshaped batch; `elem` is still
            # the original first sample and only supplies the storage/tensor type.)
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        # Stack on dim 1 so that time remains the leading dimension.
        return torch.stack(batch, 1, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        elem = batch[0]
        if elem_type.__name__ == "ndarray":
            # array of string classes and object
            # NOTE(review): pytorch_collate is presumably
            # torch.utils.data._utils.collate — confirm against the imports.
            if (
                pytorch_collate.np_str_obj_array_pattern.search(elem.dtype.str)
                is not None
            ):
                raise TypeError(
                    pytorch_collate.default_collate_err_msg_format.format(elem.dtype)
                )
            # Convert to tensors and recurse so the tensor branch above applies.
            return time_aware_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other numpy type falls through to the TypeError at the end.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        # Keys are taken from the first sample only.
        return {key: time_aware_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(time_aware_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # Transpose so each position across the samples is collated together.
        transposed = zip(*batch)
        return [time_aware_collate(samples) for samples in transposed]
    raise TypeError(pytorch_collate.default_collate_err_msg_format.format(elem_type))
示例9: _collate_else
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def _collate_else(batch, collate_func):
    """
    Handle the recursion in the "else" case of the special collate functions.

    This duplicates all non-tensor cases from
    `torch_data.dataloader.default_collate` and also adds support for
    collating slices. Nested containers (mappings, namedtuples and other
    sequences) are collated recursively with ``collate_func``, so the caller's
    special handling also applies to their contents.

    Args:
        batch (list): samples to collate; the type of ``batch[0]`` selects
            the branch.
        collate_func (callable): collate function used for the recursion.

    Returns:
        The collated batch.

    Raises:
        TypeError: for numpy arrays with a string/object dtype and for any
            unsupported sample type (including numpy types that are neither
            arrays nor scalars).
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
        # Other numpy types fall through to the TypeError below instead of
        # silently returning None.
    elif isinstance(batch[0], slice):
        # Slices are collated as a dict of their start/stop/step fields,
        # with a missing step normalized to 1.
        batch = default_collate([{
            'start': sl.start,
            'stop': sl.stop,
            'step': 1 if sl.step is None else sl.step
        } for sl in batch])
        return batch
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        # Hack the mapping collation implementation to print error info
        if _DEBUG:
            collated = {}
            # Pre-bind `key` so the error report below cannot itself fail with
            # UnboundLocalError if iterating the mapping raises.
            key = None
            try:
                for key in batch[0]:
                    collated[key] = collate_func([d[key] for d in batch])
            except Exception:
                print('\n!!Error collating key = {!r}\n'.format(key))
                raise
            return collated
        else:
            return {key: collate_func([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        # Recurse with collate_func (not default_collate) so namedtuple fields
        # get the same special handling as mapping values and sequence items.
        return type(batch[0])(*(collate_func(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [collate_func(samples) for samples in transposed]
    raise TypeError((error_msg.format(type(batch[0]))))
示例10: move
# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def move(xpu, data, **kwargs):
    """
    Moves the model onto the primary GPU or CPU.

    If the data is nested in a container (e.g. a dict or list) then this
    function is applied recursively to all values in the container, and
    ``kwargs`` are forwarded to every recursive call. Strings and bytes
    inside containers are returned unchanged, and namedtuples keep their type.

    Note:
        This works by calling the `.to` method, which works inplace for
        torch Modules, but is not inplace for raw Tensors.

    Args:
        data (torch.Module | torch.Tensor | Collection):
            raw data or a collection containing raw data.
        **kwargs : forwarded to `data.to` when moving to the GPU (not used
            when moving to the CPU).

    Returns:
        torch.Tensor: the tensor with a dtype for this device

    Raises:
        TypeError: if ``data`` has no ``.to`` method and is not a string,
            mapping, sequence or set.

    Example:
        >>> data = torch.FloatTensor([0])
        >>> if torch.cuda.is_available():
        >>>     xpu = XPU.coerce('gpu')
        >>>     assert isinstance(xpu.move(data), torch.cuda.FloatTensor)
        >>> xpu = XPU.coerce('cpu')
        >>> assert isinstance(xpu.move(data), torch.FloatTensor)
        >>> assert isinstance(xpu.move([data])[0], torch.FloatTensor)
        >>> assert isinstance(xpu.move({0: data})[0], torch.FloatTensor)
        >>> assert isinstance(xpu.move({data}), set)
    """
    # Only the attribute lookup decides "movable vs. container"; an
    # AttributeError raised *inside* `.to` now propagates instead of being
    # misread as "this is a container".
    to = getattr(data, 'to', None)
    if to is not None:
        if xpu.is_gpu():
            return to(xpu._main_device_id, **kwargs)
        return to('cpu')

    # Recursive move
    if isinstance(data, (str, bytes)):
        # A str is a Sequence of one-character strs, so recursing into it
        # never terminates; leave text untouched.
        return data
    if isinstance(data, container_abcs.Mapping):
        cls = data.__class__
        return cls((k, xpu.move(v, **kwargs)) for k, v in data.items())
    if isinstance(data, tuple) and hasattr(data, '_fields'):
        # namedtuples take their fields positionally, not as one iterable.
        return data.__class__(*(xpu.move(v, **kwargs) for v in data))
    if isinstance(data, (container_abcs.Sequence, container_abcs.Set)):
        cls = data.__class__
        return cls(xpu.move(v, **kwargs) for v in data)
    raise TypeError('Unknown type {}'.format(type(data)))