

Python container_abcs.Mapping Method Code Examples

This article collects typical usage examples of the Python method torch._six.container_abcs.Mapping. If you are wondering what container_abcs.Mapping does, how to call it, or want to see it used in real code, the curated examples below should help. You can also explore further usage examples from the torch._six.container_abcs module.


Below are 10 code examples of the container_abcs.Mapping method, ordered roughly by popularity.
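Note: torch._six was a Python 2/3 compatibility shim, and its container_abcs alias was removed in newer PyTorch releases (around 1.9). To run the examples below on both old and new versions, a small compatibility import is usually needed; a minimal sketch:

# Compatibility shim: fall back to the standard library's collections.abc,
# which provides the same Mapping/Sequence/Iterable ABCs.
try:
    from torch._six import container_abcs  # older PyTorch
except ImportError:
    import collections.abc as container_abcs  # newer PyTorch / plain Python

# container_abcs.Mapping now resolves on either version.
assert issubclass(dict, container_abcs.Mapping)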

Example 1: collate

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def collate(self, batch):
        elem = batch[0]
        if isinstance(elem, Data):
            return Batch.from_data_list(batch, self.follow_batch)
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int_classes):
            return torch.tensor(batch)
        elif isinstance(elem, string_classes):
            return batch
        elif isinstance(elem, container_abcs.Mapping):
            return {key: self.collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            return type(elem)(*(self.collate(s) for s in zip(*batch)))
        elif isinstance(elem, container_abcs.Sequence):
            return [self.collate(s) for s in zip(*batch)]

        raise TypeError('DataLoader found invalid type: {}'.format(type(elem))) 
Author: rusty1s | Project: pytorch_geometric | Lines: 22 | Source: dataloader.py
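To see what the container_abcs.Mapping branch does in isolation, here is a minimal sketch using PyTorch's stock default_collate, which applies the same per-key recursion (the sample values are hypothetical):

import torch
from torch.utils.data.dataloader import default_collate

# A batch of dict samples: each key is collated independently.
batch = [
    {"x": torch.ones(3), "y": 0.5},
    {"x": torch.zeros(3), "y": 0.7},
]
out = default_collate(batch)
print(out["x"].shape)  # torch.Size([2, 3])
print(out["y"])        # tensor([0.5000, 0.7000], dtype=torch.float64)

Note that the collate method above deliberately collates floats as torch.float, whereas stock default_collate produces float64.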

Example 2: recursive_copy_to_gpu

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def recursive_copy_to_gpu(value: Any, non_blocking: bool = True) -> Any:
    """
    Recursively searches lists, tuples, dicts and copies tensors to GPU if
    possible. Non-tensor values are passed as-is in the result.
    Note:  These are all copies, so if there are two objects that reference
    the same object, then after this call, there will be two different objects
    referenced on the GPU.
    """
    if hasattr(value, "cuda"):
        return value.cuda(non_blocking=non_blocking)
    elif isinstance(value, list) or isinstance(value, tuple):
        gpu_val = []
        for val in value:
            gpu_val.append(recursive_copy_to_gpu(val, non_blocking=non_blocking))

        return gpu_val if isinstance(value, list) else tuple(gpu_val)
    elif isinstance(value, container_abcs.Mapping):
        gpu_val = {}
        for key, val in value.items():
            gpu_val[key] = recursive_copy_to_gpu(val, non_blocking=non_blocking)

        return gpu_val

    return value 
Author: facebookresearch | Project: ClassyVision | Lines: 26 | Source: util.py
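A usage sketch, assuming recursive_copy_to_gpu above is in scope; the payload contents are hypothetical:

import torch

# Nested container mixing tensors and plain values.
payload = {"inputs": [torch.randn(2, 3), torch.randn(2, 3)], "meta": "sample"}

if torch.cuda.is_available():
    gpu_payload = recursive_copy_to_gpu(payload)
    print(gpu_payload["inputs"][0].device)  # cuda:0
    print(gpu_payload["meta"])              # non-tensor values pass through unchanged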

Example 3: recursive_to

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def recursive_to(item, device):
    # language=rst
    """
    Recursively transfers everything contained in item to the target
    device.

    :param item: An individual tensor or container of tensors.
    :param device: ``torch.device`` pointing to ``"cuda"`` or ``"cpu"``.

    :return: A version of the item that has been sent to a device.
    """

    if isinstance(item, torch.Tensor):
        return item.to(device)
    elif isinstance(item, (string_classes, int, float, bool)):
        return item
    elif isinstance(item, container_abcs.Mapping):
        return {key: recursive_to(item[key], device) for key in item}
    elif isinstance(item, tuple) and hasattr(item, "_fields"):
        return type(item)(*(recursive_to(i, device) for i in item))
    elif isinstance(item, container_abcs.Sequence):
        return [recursive_to(i, device) for i in item]
    else:
        raise NotImplementedError(f"Target type {type(item)} not supported.") 
Author: BindsNET | Project: bindsnet | Lines: 26 | Source: base_pipeline.py
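A usage sketch, assuming recursive_to above is in scope:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Arbitrary nesting of dicts, lists, tensors, and scalars is handled.
item = {"spikes": torch.zeros(10, 4), "labels": [torch.tensor(1), torch.tensor(0)]}
moved = recursive_to(item, device)
print(moved["spikes"].device, moved["labels"][0].device)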

Example 4: fast_batch_collator

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def fast_batch_collator(batched_inputs):
    """
    A simple batch collator for most common reid tasks
    """
    elem = batched_inputs[0]
    if isinstance(elem, torch.Tensor):
        out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype)
        for i, tensor in enumerate(batched_inputs):
            out[i] += tensor
        return out

    elif isinstance(elem, container_abcs.Mapping):
        return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}

    elif isinstance(elem, float):
        return torch.tensor(batched_inputs, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batched_inputs)
    elif isinstance(elem, string_classes):
        return batched_inputs 
Author: JDAI-CV | Project: fast-reid | Lines: 22 | Source: build.py
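A usage sketch with hypothetical reid-style samples, assuming fast_batch_collator above is in scope. Note that the tensor branch pre-allocates one zeroed output tensor and accumulates into it, avoiding the temporary list a torch.stack call would build:

import torch

batched_inputs = [
    {"images": torch.rand(3, 8, 8), "targets": 0},
    {"images": torch.rand(3, 8, 8), "targets": 1},
]
out = fast_batch_collator(batched_inputs)
print(out["images"].shape)  # torch.Size([2, 3, 8, 8])
print(out["targets"])       # tensor([0, 1])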

Example 5: concatenate_cache

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def concatenate_cache(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        return torch.cat(batch, 0, out=out)  # the main difference is here
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))
            return concatenate_cache([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: concatenate_cache([d[key] for d in batch])
                for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):
        return type(batch[0])(*(concatenate_cache(samples)
                                for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):  # also some diffs here
        # just unpack
        return [s_ for s in batch for s_ in s]

    raise TypeError((error_msg_fmt.format(type(batch[0])))) 
Author: facebookresearch | Project: c3dpo_nrsfm | Lines: 36 | Source: cache_preds.py
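The main difference from default_collate is that tensors are concatenated along dim 0 instead of being stacked into a new batch dimension, and sequences are flattened rather than transposed, which suits merging per-batch prediction caches of unequal length. A sketch, assuming concatenate_cache above is in scope:

import torch

chunks = [torch.randn(4, 2), torch.randn(6, 2)]  # per-batch caches
merged = concatenate_cache(chunks)
print(merged.shape)  # torch.Size([10, 2]) -- cat along dim 0, not stack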

Example 6: default_collate

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def default_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size"""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        return torch.stack(batch, 0)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):  # pragma: no cover
        elem = batch[0]
        if elem_type.__name__ == "ndarray":
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith("float") else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):  # pragma: no cover
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):  # pragma: no cover
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):  # pragma: no cover
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):  # pragma: no cover
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):  # pragma: no cover
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0])))) 
Author: OpenMined | Project: PySyft | Lines: 33 | Source: dataloader.py
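The Sequence branch transposes the batch before recursing, so a batch of equal-length lists becomes a list of per-position collated tensors. A sketch, assuming the default_collate above is in scope:

import torch

batch = [
    [torch.tensor([1.0]), torch.tensor([2.0])],  # sample 1
    [torch.tensor([3.0]), torch.tensor([4.0])],  # sample 2
]
out = default_collate(batch)
# out is a list of two stacked tensors, one per position:
# [tensor([[1.], [3.]]), tensor([[2.], [4.]])]
print(out[0].shape)  # torch.Size([2, 1])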

Example 7: update

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def update(self, buffers):
        r"""Update the :class:`~torch.nn.BufferDict` with the key-value pairs from a
        mapping or an iterable, overwriting existing keys.

        .. note::
            If :attr:`buffers` is an ``OrderedDict``, a :class:`~torch.nn.BufferDict`,
            or an iterable of key-value pairs, the order of new elements in it is
            preserved.

        Arguments:
            buffers (iterable): a mapping (dictionary) from string to
                :class:`~torch.Tensor`, or an iterable of
                key-value pairs of type (string, :class:`~torch.Tensor`)
        """
        if not isinstance(buffers, container_abcs.Iterable):
            raise TypeError(
                "BuffersDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(buffers).__name__
            )

        if isinstance(buffers, container_abcs.Mapping):
            if isinstance(buffers, (OrderedDict, BufferDict)):
                for key, buffer in buffers.items():
                    self[key] = buffer
            else:
                for key, buffer in sorted(buffers.items()):
                    self[key] = buffer
        else:
            for j, p in enumerate(buffers):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " should be Iterable; is" + type(p).__name__
                    )
                if not len(p) == 2:
                    raise ValueError(
                        "BufferDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                self[p[0]] = p[1] 
Author: pytorch | Project: botorch | Lines: 42 | Source: torch.py
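A usage sketch; the import path below is an assumption based on where this BufferDict lives in botorch:

import torch
from botorch.utils.torch import BufferDict  # assumed import path

bd = BufferDict({"mean": torch.zeros(3)})
# Plain dicts are inserted in sorted key order; OrderedDict/BufferDict
# arguments keep their own order.
bd.update({"scale": torch.ones(3), "amplitude": torch.tensor(2.0)})
print(list(bd.keys()))  # ['mean', 'amplitude', 'scale']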

Example 8: time_aware_collate

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def time_aware_collate(batch):
    # language=rst
    """
    Puts each data field into a tensor with dimensions ``[time, batch size, ...]``

    Interpretation of dimensions being input:
    -  0 dim (,) - (1, batch_size, 1)
    -  1 dim (time,) - (time, batch_size, 1)
    - >2 dim (time, n_0, ...) - (time, batch_size, n_0, ...)
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        # catch 0 and 1 dimension cases and view as specified
        if elem.dim() == 0:
            batch = [x.view((1, 1)) for x in batch]
        elif elem.dim() == 1:
            batch = [x.view((x.shape[0], 1)) for x in batch]

        out = None
        if safe_worker_check():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 1, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        elem = batch[0]
        if elem_type.__name__ == "ndarray":
            # array of string classes and object
            if (
                pytorch_collate.np_str_obj_array_pattern.search(elem.dtype.str)
                is not None
            ):
                raise TypeError(
                    pytorch_collate.default_collate_err_msg_format.format(elem.dtype)
                )

            return time_aware_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: time_aware_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(time_aware_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [time_aware_collate(samples) for samples in transposed]

    raise TypeError(pytorch_collate.default_collate_err_msg_format.format(elem_type)) 
Author: BindsNET | Project: bindsnet | Lines: 63 | Source: collate.py
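A sketch of the time-first layout, assuming time_aware_collate above is in scope:

import torch

# Two samples, each a 1-D spike train of length 5.
batch = [torch.zeros(5), torch.ones(5)]
out = time_aware_collate(batch)
print(out.shape)  # torch.Size([5, 2, 1]) -- (time, batch size, 1)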

Example 9: _collate_else

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def _collate_else(batch, collate_func):
    """
    Handles recursion in the else case for these special collate functions

    This duplicates all non-tensor cases from `torch_data.dataloader.default_collate`
    This also contains support for collating slices.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], slice):
        batch = default_collate([{
            'start': sl.start,
            'stop': sl.stop,
            'step': 1 if sl.step is None else sl.step
        } for sl in batch])
        return batch
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        # Hack the mapping collation implementation to print error info
        if _DEBUG:
            collated = {}
            try:
                for key in batch[0]:
                    collated[key] = collate_func([d[key] for d in batch])
            except Exception:
                print('\n!!Error collating key = {!r}\n'.format(key))
                raise
            return collated
        else:
            return {key: collate_func([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [collate_func(samples) for samples in transposed]
    else:
        raise TypeError((error_msg.format(type(batch[0])))) 
Author: Erotemic | Project: netharn | Lines: 56 | Source: collate.py
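The slice branch is what stock default_collate lacks: a batch of slices becomes a dict of collated start/stop/step fields. A sketch, assuming _collate_else and default_collate are in scope; the slice values are hypothetical:

batch = [slice(0, 4), slice(2, 8, 2)]
out = _collate_else(batch, default_collate)
# {'start': tensor([0, 2]), 'stop': tensor([4, 8]), 'step': tensor([1, 2])}
print(out["step"])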

Example 10: move

# Required import: from torch._six import container_abcs [as alias]
# Or: from torch._six.container_abcs import Mapping [as alias]
def move(xpu, data, **kwargs):
        """
        Moves the model onto the primary GPU or CPU.

        If the data is nested in a container (e.g. a dict or list) then this
        function is applied recursively to all values in the container.

        Note:
            This works by calling the `.to` method, which works in-place for
            torch Modules, but is not in-place for raw Tensors.

        Args:
            data (torch.Module | torch.Tensor | Collection):
                raw data or a collection containing raw data.
            **kwargs : forwarded to `data.cuda`

        Returns:
            torch.Tensor: the tensor with a dtype for this device

        Example:
            >>> data = torch.FloatTensor([0])
            >>> if torch.cuda.is_available():
            >>>     xpu = XPU.coerce('gpu')
            >>>     assert isinstance(xpu.move(data), torch.cuda.FloatTensor)
            >>> xpu = XPU.coerce('cpu')
            >>> assert isinstance(xpu.move(data), torch.FloatTensor)
            >>> assert isinstance(xpu.move([data])[0], torch.FloatTensor)
            >>> assert isinstance(xpu.move({0: data})[0], torch.FloatTensor)
            >>> assert isinstance(xpu.move({data}), set)
        """
        try:
            if xpu.is_gpu():
                return data.to(xpu._main_device_id, **kwargs)
            else:
                return data.to('cpu')
        except AttributeError:
            # Recursive move
            if isinstance(data, container_abcs.Mapping):
                cls = data.__class__
                return cls((k, xpu.move(v)) for k, v in data.items())
            elif isinstance(data, (container_abcs.Sequence, container_abcs.Set)):
                cls = data.__class__
                return cls(xpu.move(v) for v in data)
            else:
                raise TypeError('Unknown type {}'.format(type(data))) 
Author: Erotemic | Project: netharn | Lines: 47 | Source: device.py


Note: the torch._six.container_abcs.Mapping examples above were collected from open-source projects hosted on GitHub, MSDocs, and similar platforms. Copyright of each snippet remains with its original authors; consult the corresponding project's license before redistributing or reusing the code.