當前位置: 首頁>>代碼示例>>Python>>正文


Python container_abcs.Mapping方法代碼示例

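本文整理匯總了Python中torch._six.container_abcs.Mapping方法的典型用法代碼示例。如果您正苦於以下問題:Python container_abcs.Mapping方法的具體用法?Python container_abcs.Mapping怎麼用?Python container_abcs.Mapping使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在torch._six.container_abcs的用法示例。注意:container_abcs.Mapping 實際上是抽象基類(torch._six.container_abcs 即標準庫 collections.abc 的別名),而 torch._six 模組已在 PyTorch 2.0 中移除,新版本中可改用 `import collections.abc as container_abcs`。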


在下文中一共展示了container_abcs.Mapping方法的10個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: collate

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def collate(self, batch):
    """Recursively merge a list of samples into one batch.

    The type of the first sample decides how the whole batch is merged:

    * ``Data`` objects are combined with ``Batch.from_data_list``, passing
      ``self.follow_batch`` along;
    * tensors are delegated to ``default_collate``;
    * Python floats become a ``torch.float`` tensor, ints a tensor whose
      dtype ``torch.tensor`` infers;
    * strings are returned unchanged as the input list;
    * mappings are merged per key (keys taken from the first sample);
    * namedtuples are merged per field and rebuilt with the same type;
    * other sequences are merged per position into a list.

    Raises:
        TypeError: if the first sample's type is none of the above.
    """
    first = batch[0]

    if isinstance(first, Data):
        return Batch.from_data_list(batch, self.follow_batch)
    if isinstance(first, torch.Tensor):
        return default_collate(batch)
    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float)
    if isinstance(first, int_classes):
        return torch.tensor(batch)
    if isinstance(first, string_classes):
        return batch
    if isinstance(first, container_abcs.Mapping):
        return {key: self.collate([sample[key] for sample in batch]) for key in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # namedtuple: collate field-wise, then rebuild positionally
        merged_fields = (self.collate(column) for column in zip(*batch))
        return type(first)(*merged_fields)
    if isinstance(first, container_abcs.Sequence):
        return [self.collate(column) for column in zip(*batch)]

    raise TypeError('DataLoader found invalid type: {}'.format(type(first)))
開發者ID:rusty1s,項目名稱:pytorch_geometric,代碼行數:22,代碼來源:dataloader.py

示例2: recursive_copy_to_gpu

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def recursive_copy_to_gpu(value: Any, non_blocking: bool = True) -> Any:
    """
    Recursively searches lists, tuples, dicts and copies tensors to GPU if
    possible. Non-tensor values are passed as-is in the result.

    Anything exposing a ``cuda`` attribute is copied with
    ``value.cuda(non_blocking=non_blocking)``. Lists stay lists, tuples stay
    tuples, and namedtuples keep their namedtuple type. Mappings are rebuilt
    as plain dicts.

    Note:  These are all copies, so if there are two objects that reference
    the same object, then after this call, there will be two different objects
    referenced on the GPU.

    Args:
        value: a tensor, or an arbitrarily nested list/tuple/mapping of them.
        non_blocking: forwarded to every ``.cuda`` call.

    Returns:
        A structure mirroring ``value`` with its ``cuda``-capable leaves copied.
    """
    if hasattr(value, "cuda"):
        return value.cuda(non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        gpu_val = [recursive_copy_to_gpu(val, non_blocking=non_blocking) for val in value]
        if isinstance(value, list):
            return gpu_val
        if hasattr(value, "_fields"):
            # namedtuple: rebuild positionally so the field names survive
            return type(value)(*gpu_val)
        return tuple(gpu_val)

    if isinstance(value, container_abcs.Mapping):
        return {
            key: recursive_copy_to_gpu(val, non_blocking=non_blocking)
            for key, val in value.items()
        }

    return value
開發者ID:facebookresearch,項目名稱:ClassyVision,代碼行數:26,代碼來源:util.py

示例3: recursive_to

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def recursive_to(item, device):
    # language=rst
    """
    Recursively transfers everything contained in item to the target
    device.

    Tensors are moved with ``Tensor.to``. Strings, numbers and ``None`` are
    returned unchanged. Mappings are rebuilt as dicts. Namedtuples and plain
    tuples keep their type. Any other sequence becomes a list.

    :param item: An individual tensor or container of tensors.
    :param device: ``torch.device`` pointing to ``"cuda"`` or ``"cpu"``.

    :return: A version of the item that has been sent to a device.

    :raises NotImplementedError: If a leaf value has an unsupported type.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device)
    if item is None or isinstance(item, (string_classes, int, float, bool)):
        # Optional (None) fields are passed through like other scalars.
        return item
    if isinstance(item, container_abcs.Mapping):
        return {key: recursive_to(item[key], device) for key in item}
    if isinstance(item, tuple):
        moved = (recursive_to(i, device) for i in item)
        if hasattr(item, "_fields"):  # namedtuple
            return type(item)(*moved)
        # Keep plain tuples as tuples rather than degrading them to lists.
        return tuple(moved)
    if isinstance(item, container_abcs.Sequence):
        return [recursive_to(i, device) for i in item]

    raise NotImplementedError(f"Target type {type(item)} not supported.")
開發者ID:BindsNET,項目名稱:bindsnet,代碼行數:26,代碼來源:base_pipeline.py

示例4: fast_batch_collator

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def fast_batch_collator(batched_inputs):
    """
    A simple batch collator for most common reid tasks.

    Dispatches on the type of the first element:

    * tensors are copied into a new tensor of shape
      ``(len(batched_inputs), *elem.size())`` with the first element's dtype,
      allocated on the first element's device;
    * mappings are collated recursively per key (keys taken from the first
      element);
    * Python floats become a float64 tensor, ints a tensor whose dtype
      ``torch.tensor`` infers;
    * strings are returned unchanged as the input list.

    Raises:
        TypeError: if the first element's type is none of the above.
    """
    elem = batched_inputs[0]
    if isinstance(elem, torch.Tensor):
        # Allocate on the inputs' device so GPU tensors can be collated too.
        out = torch.zeros(
            (len(batched_inputs), *elem.size()), dtype=elem.dtype, device=elem.device
        )
        for i, tensor in enumerate(batched_inputs):
            out[i] += tensor
        return out

    elif isinstance(elem, container_abcs.Mapping):
        return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}

    elif isinstance(elem, float):
        return torch.tensor(batched_inputs, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batched_inputs)
    elif isinstance(elem, string_classes):
        return batched_inputs

    # Fail loudly instead of silently returning None for unsupported types.
    raise TypeError(
        "fast_batch_collator: unsupported element type {}".format(type(elem))
    )
開發者ID:JDAI-CV,項目名稱:fast-reid,代碼行數:22,代碼來源:build.py

示例5: concatenate_cache

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def concatenate_cache(batch):
    r"""Collate cached samples by concatenating along the first dimension.

    Unlike a stacking collate, tensors are joined with ``torch.cat`` along
    dim 0. Sequences are flattened one level, so ``[[a, b], [c]]`` becomes
    ``[a, b, c]``. Mappings and namedtuples are handled recursively per
    key / field. Python scalars become tensors, and strings are returned
    as-is.

    Raises:
        TypeError: for numpy arrays with string/object dtypes, or for any
            unsupported element type.
    """
    head = batch[0]
    head_type = type(head)

    if isinstance(head, torch.Tensor):
        # the main difference from a stacking collate: cat, not stack
        return torch.cat(batch, 0)
    elif head_type.__module__ == 'numpy' and head_type.__name__ not in ('str_', 'string_'):
        if head_type.__name__ == 'ndarray':
            # reject arrays of string classes and objects
            if np_str_obj_array_pattern.search(head.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(head.dtype))
            return concatenate_cache([torch.from_numpy(arr) for arr in batch])
        if head.shape == ():  # numpy scalars
            cast = float if head.dtype.name.startswith('float') else int
            return numpy_type_map[head.dtype.name]([cast(v) for v in batch])
    elif isinstance(head, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(head, int_classes):
        return torch.tensor(batch)
    elif isinstance(head, string_classes):
        return batch
    elif isinstance(head, container_abcs.Mapping):
        return {key: concatenate_cache([sample[key] for sample in batch]) for key in head}
    elif isinstance(head, tuple) and hasattr(head, '_fields'):
        return head_type(*(concatenate_cache(field) for field in zip(*batch)))
    elif isinstance(head, container_abcs.Sequence):
        # flatten one level instead of transposing
        return [item for sample in batch for item in sample]

    raise TypeError(error_msg_fmt.format(head_type))
開發者ID:facebookresearch,項目名稱:c3dpo_nrsfm,代碼行數:36,代碼來源:cache_preds.py

示例6: default_collate

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def default_collate(batch):
    """Collate a list of samples into a batch whose outer dimension is the batch size.

    Dispatches on the type of the first sample:

    * tensors and numpy arrays are stacked along a new dim 0;
    * numpy scalars go through ``numpy_type_map`` for their dtype name;
    * Python ints become a ``torch.LongTensor``, floats a ``torch.DoubleTensor``;
    * strings are returned unchanged as the input list;
    * mappings are collated per key, and other sequences per position
      (transposed), recursively.

    Raises:
        TypeError: if the first sample's type is not supported.
    """
    first = batch[0]
    first_type = type(first)
    is_numpy_value = (
        first_type.__module__ == "numpy"
        and first_type.__name__ not in ("str_", "string_")
    )

    if isinstance(first, torch.Tensor):
        return torch.stack(batch, 0)

    if is_numpy_value:  # pragma: no cover
        if first_type.__name__ == "ndarray":
            return torch.stack([torch.from_numpy(arr) for arr in batch], 0)
        if first.shape == ():  # numpy scalars
            to_python = float if first.dtype.name.startswith("float") else int
            return numpy_type_map[first.dtype.name]([to_python(v) for v in batch])
    elif isinstance(first, int_classes):  # pragma: no cover
        return torch.LongTensor(batch)
    elif isinstance(first, float):  # pragma: no cover
        return torch.DoubleTensor(batch)
    elif isinstance(first, string_classes):  # pragma: no cover
        return batch
    elif isinstance(first, container_abcs.Mapping):  # pragma: no cover
        return {key: default_collate([sample[key] for sample in batch]) for key in first}
    elif isinstance(first, container_abcs.Sequence):  # pragma: no cover
        return [default_collate(column) for column in zip(*batch)]

    raise TypeError(
        "batch must contain tensors, numbers, dicts or lists; found {}".format(first_type)
    )
開發者ID:OpenMined,項目名稱:PySyft,代碼行數:33,代碼來源:dataloader.py

示例7: update

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def update(self, buffers):
    r"""Update the :class:`BufferDict` with the key-value pairs from a
    mapping or an iterable, overwriting existing keys.

    .. note::
        If :attr:`buffers` is an ``OrderedDict``, a :class:`BufferDict`,
        or an iterable of key-value pairs, the order of new elements in it is
        preserved. Entries of any other mapping are inserted in sorted
        key order.

    Arguments:
        buffers (iterable): a mapping (dictionary) from string to
            :class:`~torch.Tensor`, or an iterable of
            key-value pairs of type (string, :class:`~torch.Tensor`)

    Raises:
        TypeError: if ``buffers`` is not iterable, or (for non-mappings) one
            of its elements is not iterable.
        ValueError: if an element of a non-mapping ``buffers`` does not have
            exactly two items.
    """
    if not isinstance(buffers, container_abcs.Iterable):
        raise TypeError(
            "BufferDict.update should be called with an "
            "iterable of key/value pairs, but got " + type(buffers).__name__
        )

    if isinstance(buffers, container_abcs.Mapping):
        items = buffers.items()
        if not isinstance(buffers, (OrderedDict, BufferDict)):
            # Plain mappings are inserted in sorted key order (see note above).
            items = sorted(items)
        for key, buffer in items:
            self[key] = buffer
        return

    for j, p in enumerate(buffers):
        if not isinstance(p, container_abcs.Iterable):
            raise TypeError(
                "BufferDict update sequence element "
                "#" + str(j) + " should be Iterable; is " + type(p).__name__
            )
        if len(p) != 2:
            raise ValueError(
                "BufferDict update sequence element "
                "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
            )
        self[p[0]] = p[1]
開發者ID:pytorch,項目名稱:botorch,代碼行數:42,代碼來源:torch.py

示例8: time_aware_collate

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def time_aware_collate(batch):
    # language=rst
    """
    Puts each data field into a tensor with dimensions ``[time, batch size, ...]``

    Interpretation of dimensions being input:
    -  0 dim (,) - (1, batch_size, 1)
    -  1 dim (time,) - (time, batch_size, 1)
    - >=2 dim (time, n_0, ...) - (time, batch_size, n_0, ...)

    Non-tensor elements are handled as follows. Numpy arrays are converted
    with ``torch.as_tensor`` and collated like tensors. Numpy scalars and
    Python floats/ints become 1-D tensors. Strings are returned as the input
    list. Mappings, namedtuples and other sequences are collated recursively
    per key / field / position.

    :param batch: List of samples. The type (and, for tensors, the rank) of
        ``batch[0]`` decides how the whole batch is collated. Samples are
        presumably all of the same type/rank; verify against callers.
    :return: The collated batch.
    :raises TypeError: For numpy arrays whose dtype matches
        ``pytorch_collate.np_str_obj_array_pattern`` (string/object arrays),
        or for element types not handled above.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        # catch 0 and 1 dimension cases and view as specified
        # (only elem's rank is checked, so every sample is assumed to share it)
        if elem.dim() == 0:
            batch = [x.view((1, 1)) for x in batch]
        elif elem.dim() == 1:
            batch = [x.view((x.shape[0], 1)) for x in batch]

        out = None
        if safe_worker_check():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            # NOTE: `elem` still refers to the original (un-viewed) batch[0];
            # here it only serves to create the new shared storage and tensor.
            # The views above do not change numel, so the summed size matches.
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        # Stack along dim 1 so the leading (time) dimension stays first.
        return torch.stack(batch, 1, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        elem = batch[0]
        if elem_type.__name__ == "ndarray":
            # array of string classes and object
            if (
                pytorch_collate.np_str_obj_array_pattern.search(elem.dtype.str)
                is not None
            ):
                raise TypeError(
                    pytorch_collate.default_collate_err_msg_format.format(elem.dtype)
                )

            # Convert to tensors and reuse the tensor path above.
            return time_aware_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other numpy object falls through to the TypeError at the end.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: time_aware_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(time_aware_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [time_aware_collate(samples) for samples in transposed]

    raise TypeError(pytorch_collate.default_collate_err_msg_format.format(elem_type))
開發者ID:BindsNET,項目名稱:bindsnet,代碼行數:63,代碼來源:collate.py

示例9: _collate_else

# 需要導入模塊: from torch._six import container_abcs [as 別名]
# 或者: from torch._six.container_abcs import Mapping [as 別名]
def _collate_else(batch, collate_func):
    """
    Handles recursion in the else case for these special collate functions

    This is duplicates all non-tensor cases from `torch_data.dataloader.default_collate`
    This also contains support for collating slices.
    """
    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], slice):
        batch = default_collate([{
            'start': sl.start,
            'stop': sl.stop,
            'step': 1 if sl.step is None else sl.step
        } for sl in batch])
        return batch
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        # Hack the mapping collation implementation to print error info
        if _DEBUG:
            collated = {}
            try:
                for key in batch[0]:
                    collated[key] = collate_func([d[key] for d in batch])
            except Exception:
                print('\n!!Error collating key = {!r}\n'.format(key))
                raise
            return collated
        else:
            return {key: collate_func([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        transposed = zip(*batch)
        return [collate_func(samples) for samples in transposed]
    else:
        raise TypeError((error_msg.format(type(batch[0])))) 
開發者ID:Erotemic,項目名稱:netharn,代碼行數:56,代碼來源:collate.py

示例10: move

# Required import: from torch._six import container_abcs [as alias]
# or: from torch._six.container_abcs import Mapping [as alias]
def move(xpu, data, **kwargs):
    """
    Moves the model or data onto the primary GPU or CPU.

    If the data is nested in a container (e.g. a dict, list, tuple or set)
    then this function is applied recursively to all values in the container,
    and the container is rebuilt with the same class (namedtuples included).

    Note:
        This works by calling the `.to` method, which works inplace for
        torch Modules, but is not inplace for raw Tensors.

    Args:
        data (torch.Module | torch.Tensor | Collection):
            raw data or a collection containing raw data.
        **kwargs : forwarded to `data.to` when moving onto a GPU (not used for
            the CPU move), and passed along to the recursive calls made for
            items inside containers.

    Returns:
        The moved object. Containers are rebuilt with their original class.

    Raises:
        TypeError: if `data` (or a nested item) has no `.to` method and is not
            a mapping, sequence or set. Strings and bytes are rejected this way
            too, since recursing into their characters would never terminate.

    Example:
        >>> data = torch.FloatTensor([0])
        >>> if torch.cuda.is_available():
        >>>     xpu = XPU.coerce('gpu')
        >>>     assert isinstance(xpu.move(data), torch.cuda.FloatTensor)
        >>> xpu = XPU.coerce('cpu')
        >>> assert isinstance(xpu.move(data), torch.FloatTensor)
        >>> assert isinstance(xpu.move([data])[0], torch.FloatTensor)
        >>> assert isinstance(xpu.move({0: data})[0], torch.FloatTensor)
        >>> assert isinstance(xpu.move({data}), set)
    """
    # Only the attribute lookup is guarded, so an AttributeError raised
    # *inside* `.to` propagates instead of being mistaken for "not movable".
    try:
        to = data.to
    except AttributeError:
        pass
    else:
        if xpu.is_gpu():
            return to(xpu._main_device_id, **kwargs)
        return to('cpu')

    # Recursive move
    cls = data.__class__
    if isinstance(data, (str, bytes)):
        # str/bytes are Sequences whose items are again str/bytes; recursing
        # into them never terminates, so treat them as unsupported leaves.
        raise TypeError('Unknown type {}'.format(type(data)))
    if isinstance(data, container_abcs.Mapping):
        return cls((k, xpu.move(v, **kwargs)) for k, v in data.items())
    if isinstance(data, tuple) and hasattr(data, '_fields'):
        # namedtuple constructors take fields positionally, not an iterable
        return cls(*(xpu.move(v, **kwargs) for v in data))
    if isinstance(data, (container_abcs.Sequence, container_abcs.Set)):
        return cls(xpu.move(v, **kwargs) for v in data)
    raise TypeError('Unknown type {}'.format(type(data)))
開發者ID:Erotemic,項目名稱:netharn,代碼行數:47,代碼來源:device.py


注:本文中的torch._six.container_abcs.Mapping方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。