当前位置: 首页>>代码示例>>Python>>正文


Python _six.string_classes方法代码示例

本文整理汇总了 Python 中 torch._six.string_classes 的典型用法代码示例。如果您正苦于以下问题:Python _six.string_classes 的具体用法?Python _six.string_classes 怎么用?Python _six.string_classes 使用的例子?那么恭喜您,这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解其所在模块 torch._six 的用法示例。


在下文中一共展示了 _six.string_classes 的 12 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的 Python 代码示例。

示例1: collate

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def collate(self, batch):
    """Merge a list of samples into one batch, dispatching on the first sample.

    Only ``batch[0]`` is inspected to choose the strategy:

    * ``Data`` samples are merged with ``Batch.from_data_list``, which also
      receives ``self.follow_batch``.
    * Tensors are delegated to ``default_collate``.
    * Floats become a ``torch.float`` tensor; integers become a tensor with
      the dtype inferred by ``torch.tensor``.
    * Strings are returned as the unmodified input list.
    * Mappings, namedtuples and other sequences are collated recursively,
      field by field.

    Raises:
        TypeError: if the first sample matches none of the cases above.
    """
    first = batch[0]

    if isinstance(first, Data):
        return Batch.from_data_list(batch, self.follow_batch)
    if isinstance(first, torch.Tensor):
        return default_collate(batch)
    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float)
    if isinstance(first, int_classes):
        return torch.tensor(batch)
    if isinstance(first, string_classes):
        return batch
    if isinstance(first, container_abcs.Mapping):
        # The keys of the first sample decide which fields get collated.
        merged = {}
        for key in first:
            merged[key] = self.collate([sample[key] for sample in batch])
        return merged
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # namedtuple: collate column-wise, then rebuild the same tuple type.
        columns = [self.collate(column) for column in zip(*batch)]
        return type(first)(*columns)
    if isinstance(first, container_abcs.Sequence):
        return [self.collate(column) for column in zip(*batch)]

    raise TypeError('DataLoader found invalid type: {}'.format(type(first)))
开发者ID:rusty1s,项目名称:pytorch_geometric,代码行数:22,代码来源:dataloader.py

示例2: recursive_to

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def recursive_to(item, device):
    # language=rst
    """
    Walk through ``item`` and move every tensor it contains to ``device``.

    Strings, ints, floats and bools are returned unchanged. Mappings come
    back as plain dicts, namedtuples are rebuilt with their own type and
    any other sequence comes back as a list.

    :param item: An individual tensor or container of tensors.
    :param device: Target passed straight to ``Tensor.to``.

    :return: A version of the item whose tensors have been sent to ``device``.
    :raises NotImplementedError: If ``item`` is of an unsupported type.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device)

    # Strings and plain scalars have nothing to transfer.
    if isinstance(item, (string_classes, int, float, bool)):
        return item

    if isinstance(item, container_abcs.Mapping):
        moved = {}
        for key in item:
            moved[key] = recursive_to(item[key], device)
        return moved

    if isinstance(item, tuple) and hasattr(item, "_fields"):
        # namedtuple: rebuild with the same type, one argument per field.
        return type(item)(*[recursive_to(member, device) for member in item])

    if isinstance(item, container_abcs.Sequence):
        return [recursive_to(member, device) for member in item]

    raise NotImplementedError(f"Target type {type(item)} not supported.")
开发者ID:BindsNET,项目名称:bindsnet,代码行数:26,代码来源:base_pipeline.py

示例3: fast_batch_collator

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def fast_batch_collator(batched_inputs):
    """
    Collate a list of samples into one batch for most common reid tasks.

    The strategy is chosen from the type of the first sample: tensors are
    accumulated into a freshly allocated zero tensor with a leading batch
    dimension, mappings are collated key by key, floats become a
    ``torch.float64`` tensor, integers become a tensor, and strings are
    returned as the unmodified input list. Any other type yields ``None``.
    """
    first = batched_inputs[0]

    if isinstance(first, torch.Tensor):
        batch_shape = (len(batched_inputs), *first.size())
        stacked = torch.zeros(batch_shape, dtype=first.dtype)
        for row, sample in enumerate(batched_inputs):
            stacked[row] += sample
        return stacked

    if isinstance(first, container_abcs.Mapping):
        collated = {}
        for key in first:
            collated[key] = fast_batch_collator([entry[key] for entry in batched_inputs])
        return collated

    if isinstance(first, float):
        return torch.tensor(batched_inputs, dtype=torch.float64)
    if isinstance(first, int_classes):
        return torch.tensor(batched_inputs)
    if isinstance(first, string_classes):
        return batched_inputs

    # Unsupported element types are not rejected; the result is None.
    return None
开发者ID:JDAI-CV,项目名称:fast-reid,代码行数:22,代码来源:build.py

示例4: applier

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def applier(value, fn):
    """Apply ``fn`` to the tensor-like leaves of a possibly nested value.

    Dispatch order:

    * ``torch.Tensor``: ``fn(value)`` is returned.
    * Strings and ``np.ndarray``: returned unchanged.
    * Any other object exposing a ``to`` attribute (custom batch classes):
      ``fn(value)`` is returned.
    * Mappings: a dict is returned with ``applier`` run on both keys and
      values.
    * namedtuples: rebuilt with the same type, one argument per field.
    * Other iterables: rebuilt by passing a generator of processed items to
      ``type(value)``.
    * Everything else: returned unchanged.

    Args:
        value: The object (or container of objects) to process.
        fn: Callable invoked on tensors and on objects with a ``to`` attribute.

    Returns:
        The processed value, structured as described above.
    """
    if isinstance(value, torch.Tensor):
        return fn(value)
    elif isinstance(value, string_classes):
        return value
    elif isinstance(value, np.ndarray):
        return value
    elif hasattr(value, "to"):  # Allow handling of custom batch classes
        return fn(value)
    elif isinstance(value, container_abcs.Mapping):
        return {applier(k, fn): applier(v, fn) for k, v in value.items()}
    elif isinstance(value, tuple) and hasattr(value, "_fields"):
        # A namedtuple constructor takes one positional argument per field,
        # so it cannot be rebuilt from a single generator like list or tuple
        # can; unpack the processed fields instead.
        return type(value)(*(applier(v, fn) for v in value))
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(applier(v, fn) for v in value)
    else:
        # Unrecognized leaves (e.g. a plain int or float) pass through
        # silently. A warning was considered here but judged likely to be
        # more annoying than useful. Per the original note: custom data
        # classes should provide a .to(dtype) method that converts their
        # floating-point Tensors, and are themselves responsible for having
        # their Tensors on cuda already.
        return value
开发者ID:NVIDIA,项目名称:apex,代码行数:25,代码来源:_initialize.py

示例5: convert_tensor

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def convert_tensor(input_, device=None):
    """Move every tensor inside ``input_`` to ``device``.

    Args:
        input_: A tensor, a string, or an arbitrarily nested mapping or
            sequence of those.
        device: Target forwarded to ``Tensor.to(device=...)``. ``None``
            (the default) leaves tensors where they are.

    Returns:
        The tensor itself (moved if a device was given), the string
        unchanged, a dict for mappings, or a list for other sequences.

    Raises:
        TypeError: if a value is not a tensor, string, mapping or sequence.
    """
    # Function-scope import: the ABC aliases on the ``collections`` package
    # itself (``collections.Mapping`` etc.) were removed in Python 3.10.
    from collections.abc import Mapping, Sequence

    if torch.is_tensor(input_):
        # Compare against None rather than truthiness so that the device
        # ordinal 0 is not mistaken for "no device given".
        if device is not None:
            input_ = input_.to(device=device)
        return input_
    elif isinstance(input_, Mapping):
        # Checked ahead of strings: a string is never a mapping, so the
        # order between these two branches does not affect dispatch.
        return {k: convert_tensor(sample, device=device) for k, sample in input_.items()}
    elif isinstance(input_, string_classes):
        # Strings are sequences too; return them before the generic branch.
        return input_
    elif isinstance(input_, Sequence):
        return [convert_tensor(sample, device=device) for sample in input_]
    else:
        raise TypeError(("input must contain tensors, dicts or lists; found {}"
                         .format(type(input_))))
开发者ID:hrhodin,项目名称:UnsupervisedGeometryAwareRepresentationLearning,代码行数:16,代码来源:_utils.py

示例6: concatenate_cache

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def concatenate_cache(batch):
    r"""Merge a list of samples by concatenation, keyed on the first sample.

    Tensors are joined with ``torch.cat`` along dimension 0 (no new batch
    dimension is added), numpy arrays are converted to tensors and handled
    the same way, and plain sequences are flattened by one level. Numpy
    scalars, floats, ints, strings, mappings and namedtuples are handled by
    their own branches below.

    Raises ``TypeError`` for numpy arrays whose dtype string matches
    ``np_str_obj_array_pattern`` and for unsupported sample types.
    """
    head = batch[0]
    head_type = type(head)

    if isinstance(head, torch.Tensor):
        return torch.cat(batch, 0, out=None)  # the main difference is here

    if head_type.__module__ == 'numpy' and head_type.__name__ not in ('str_', 'string_'):
        if head_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(head.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(head.dtype))
            return concatenate_cache([torch.from_numpy(arr) for arr in batch])
        if head.shape == ():  # scalars
            to_python = float if head.dtype.name.startswith('float') else int
            return numpy_type_map[head.dtype.name]([to_python(value) for value in batch])
    elif isinstance(head, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(head, int_classes):
        return torch.tensor(batch)
    elif isinstance(head, string_classes):
        return batch
    elif isinstance(head, container_abcs.Mapping):
        merged = {}
        for key in head:
            merged[key] = concatenate_cache([sample[key] for sample in batch])
        return merged
    elif isinstance(head, tuple) and hasattr(head, '_fields'):
        columns = [concatenate_cache(samples) for samples in zip(*batch)]
        return type(head)(*columns)
    elif isinstance(head, container_abcs.Sequence):  # also some diffs here
        # just unpack
        flattened = []
        for sequence in batch:
            for member in sequence:
                flattened.append(member)
        return flattened

    raise TypeError((error_msg_fmt.format(type(head))))
开发者ID:facebookresearch,项目名称:c3dpo_nrsfm,代码行数:36,代码来源:cache_preds.py

示例7: default_collate

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def default_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size.

    The strategy is chosen from the type of ``batch[0]``:

    * tensors are stacked along a new dimension 0 (written directly into
      shared memory when ``_use_shared_memory`` is set);
    * numpy arrays are converted and stacked, numpy scalars are converted
      through ``numpy_type_map``;
    * ints become a ``LongTensor``, floats a ``DoubleTensor``;
    * strings are returned as the unmodified input list;
    * mappings and sequences are collated recursively.

    Raises:
        TypeError: for numpy arrays whose dtype string contains one of
            ``S``, ``a``, ``U`` or ``O``, and for unsupported sample types.
    """
    # Function-scope import: the ABC aliases on the ``collections`` package
    # itself (``collections.Mapping`` etc.) were removed in Python 3.10.
    from collections.abc import Mapping, Sequence

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)
    if torch.is_tensor(elem):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(elem))))
开发者ID:XiaLiPKU,项目名称:EMANet,代码行数:41,代码来源:dataloader.py

示例8: pin_memory_batch

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def pin_memory_batch(batch):
    """Call ``pin_memory()`` on every tensor inside ``batch``.

    Tensors are replaced by the result of ``Tensor.pin_memory()``. Mappings
    come back as dicts and non-string sequences as lists, with their members
    processed recursively. Strings and any other value are returned
    unchanged.
    """
    # Function-scope import: the ABC aliases on the ``collections`` package
    # itself (``collections.Mapping`` etc.) were removed in Python 3.10.
    from collections.abc import Mapping, Sequence

    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, Mapping):
        # Checked ahead of strings: a string is never a mapping, so the
        # order between these two branches does not affect dispatch.
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, string_classes):
        # Strings are sequences too; return them before the generic branch.
        return batch
    elif isinstance(batch, Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch
开发者ID:XiaLiPKU,项目名称:EMANet,代码行数:13,代码来源:dataloader.py

示例9: default_collate

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def default_collate(batch):
    """Collate a list of samples into batched values, keyed on the first sample.

    Tensors and numpy arrays are stacked along a new dimension 0, numpy
    scalars go through ``numpy_type_map``, ints become a ``LongTensor``,
    floats a ``DoubleTensor``, strings are returned as the unmodified input
    list, and mappings and sequences are collated recursively. Anything
    else raises ``TypeError``.
    """

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    first = batch[0]
    first_type = type(first)

    if isinstance(first, torch.Tensor):
        return torch.stack(batch, 0)

    if (
        first_type.__module__ == "numpy"
        and first_type.__name__ not in ("str_", "string_")
    ):  # pragma: no cover
        if first_type.__name__ == "ndarray":
            return torch.stack([torch.from_numpy(arr) for arr in batch], 0)
        if first.shape == ():  # scalars
            to_python = float if first.dtype.name.startswith("float") else int
            return numpy_type_map[first.dtype.name]([to_python(value) for value in batch])
    elif isinstance(first, int_classes):  # pragma: no cover
        return torch.LongTensor(batch)
    elif isinstance(first, float):  # pragma: no cover
        return torch.DoubleTensor(batch)
    elif isinstance(first, string_classes):  # pragma: no cover
        return batch
    elif isinstance(first, container_abcs.Mapping):  # pragma: no cover
        collated = {}
        for key in first:
            collated[key] = default_collate([sample[key] for sample in batch])
        return collated
    elif isinstance(first, container_abcs.Sequence):  # pragma: no cover
        return [default_collate(column) for column in zip(*batch)]

    raise TypeError((error_msg.format(type(first))))
开发者ID:OpenMined,项目名称:PySyft,代码行数:33,代码来源:dataloader.py

示例10: mt_collate

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def mt_collate(batch):
    """Collate a list of samples into batched values, keyed on the first sample.

    Tensors and numpy arrays are stacked along a new dimension 0, numpy
    scalars go through ``__numpy_type_map``, ints become a ``LongTensor``,
    floats a ``DoubleTensor``, strings are returned as the unmodified input
    list, and mappings and sequences are collated recursively. Samples of
    any other type are not rejected: ``batch`` is returned as received.

    Raises:
        TypeError: for numpy arrays whose dtype string contains one of
            ``S``, ``a``, ``U`` or ``O``.
    """
    # Function-scope import: the ABC aliases on the ``collections`` package
    # itself (``collections.Mapping`` etc.) were removed in Python 3.10.
    from collections.abc import Mapping, Sequence

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)
    if torch.is_tensor(elem):
        return torch.stack(batch, 0)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return __numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, Mapping):
        return {key: mt_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, Sequence):
        transposed = zip(*batch)
        return [mt_collate(samples) for samples in transposed]

    # Unknown sample types fall through untouched rather than raising.
    return batch
开发者ID:perone,项目名称:medicaltorch,代码行数:32,代码来源:datasets.py

示例11: apply_to_type

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def apply_to_type(input_, input_type, func):
    """Apply a function on an object of `input_type`, or on a mapping or sequence of such objects.

    Args:
        input_: An instance of ``input_type``, a string, or an arbitrarily
            nested mapping or sequence of those.
        input_type: Type (or tuple of types) whose instances are passed to
            ``func``.
        func: Callable applied to every matching object.

    Returns:
        ``func(input_)`` for matching objects, the string unchanged, a dict
        for mappings, or a list for other sequences.

    Raises:
        TypeError: if a value is neither of ``input_type`` nor a string,
            mapping or sequence.
    """
    # Function-scope import: the ABC aliases on the ``collections`` package
    # itself (``collections.Mapping`` etc.) were removed in Python 3.10.
    from collections.abc import Mapping, Sequence

    if isinstance(input_, input_type):
        return func(input_)
    elif isinstance(input_, Mapping):
        # Checked ahead of strings: a string is never a mapping, so the
        # order between these two branches does not affect dispatch.
        return {k: apply_to_type(sample, input_type, func) for k, sample in input_.items()}
    elif isinstance(input_, string_classes):
        # Strings are sequences too; return them before the generic branch.
        return input_
    elif isinstance(input_, Sequence):
        return [apply_to_type(sample, input_type, func) for sample in input_]
    else:
        raise TypeError(("input must contain {}, dicts or lists; found {}"
                         .format(input_type, type(input_))))
开发者ID:leokarlin,项目名称:LaSO,代码行数:16,代码来源:utils.py

示例12: collate

# 需要导入模块: from torch import _six [as 别名]
# 或者: from torch._six import string_classes [as 别名]
def collate(batch, *, root=True):
    """Collate a list of samples, batching only tensors and numpy values.

    An empty ``batch`` is returned as is. Otherwise the first sample decides:
    tensors and numpy values are handed to ``default_collate``; ints, floats,
    strings and ``CameraIntrinsics`` objects are left as the unmodified input
    list; a mapping is collated key by key at the root level and left
    untouched below it; each member of a sequence is collated on its own
    with ``root=False``.

    Raises ``TypeError`` when the first sample is of an unsupported type.
    """
    if len(batch) == 0:
        return batch

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    first = batch[0]
    first_type = type(first)

    if torch.is_tensor(first):
        return default_collate(batch)
    if first_type.__module__ == 'numpy' and first_type.__name__ not in ('str_', 'string_'):
        return default_collate(batch)

    # Plain scalars, strings and camera intrinsics stay a Python list.
    if (
        isinstance(first, int_classes)
        or isinstance(first, float)
        or isinstance(first, string_classes)
        or isinstance(first, CameraIntrinsics)
    ):
        return batch

    if isinstance(first, Mapping):
        if not root:
            return batch
        return {key: collate([sample[key] for sample in batch], root=False) for key in first}

    if isinstance(first, Sequence):
        return [collate(member, root=False) for member in batch]

    raise TypeError((error_msg.format(first_type)))
开发者ID:anibali,项目名称:margipose,代码行数:32,代码来源:__init__.py


注:本文中的torch._six.string_classes方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。