當前位置: 首頁>>代碼示例>>Python>>正文


Python dataloader.numpy_type_map方法代碼示例

本文整理匯總了Python中torch.utils.data.dataloader.numpy_type_map的典型用法代碼示例。注意:numpy_type_map 並不是方法,而是一個字典,以 NumPy dtype 名稱(例如 'float64')為鍵,值為對應的 torch 張量類型;下面的示例都以 numpy_type_map[elem.dtype.name](...) 的形式,把一組 NumPy 標量轉換成張量。如果您想了解 dataloader.numpy_type_map 的具體用法,或者想看使用它的實際例子,那麼這裡精選的代碼示例或許可以為您提供幫助。您也可以進一步了解它所在的 torch.utils.data.dataloader 模塊的用法示例。


在下文中一共展示了dataloader.numpy_type_map的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。

示例1: gather

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import numpy_type_map [as 別名]
def gather(outputs, target_device, dim=0):
    r"""
    Gathers variables from different GPUs on a specified device
      (-1 means the CPU).

    ``outputs`` holds one entry per device, all with the same nested
    structure. That structure is walked recursively and each group of
    corresponding leaves is merged:

    * ``Variable``/tensor leaves are gathered with ``Gather`` along ``dim``;
    * ``None`` leaves stay ``None``;
    * mappings are gathered key by key (keys taken from the first entry);
    * numpy arrays are concatenated along ``dim``; numpy scalars become a
      1-D tensor built through ``numpy_type_map``;
    * Python ints / floats become ``LongTensor`` / ``DoubleTensor`` Variables;
    * strings are not merged: the per-device values are returned as-is;
    * other sequences (lists, tuples, namedtuples) are gathered element-wise
      and rebuilt with the type of the first entry.

    Args:
        outputs: per-device outputs to merge.
        target_device: device handed to ``Gather`` (-1 means the CPU).
        dim: dimension along which tensors and numpy arrays are joined.

    Raises:
        TypeError: for numpy arrays of string/object dtype, or for leaf
            types not listed above.
    """
    import collections.abc

    error_msg = "outputs must contain tensors, numbers, dicts or lists; found {}"

    def gather_map(outputs):
        out = outputs[0]
        elem_type = type(out)
        if isinstance(out, Variable):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, collections.abc.Mapping):
            return {key: gather_map([d[key] for d in outputs]) for key in out}
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            elem = out
            if elem_type.__name__ == 'ndarray':
                # Arrays of bytes/str/unicode/object dtype cannot become tensors.
                if re.search('[SaUO]', elem.dtype.str) is not None:
                    raise TypeError(error_msg.format(elem.dtype))

                return Variable(torch.from_numpy(np.concatenate(outputs, dim)))
            if elem.shape == ():  # scalars
                py_type = float if elem.dtype.name.startswith('float') else int
                return Variable(numpy_type_map[elem.dtype.name](list(map(py_type, outputs))))
        elif isinstance(out, int_classes):
            return Variable(torch.LongTensor(outputs))
        elif isinstance(out, float):
            return Variable(torch.DoubleTensor(outputs))
        elif isinstance(out, string_classes):
            # Must be tested before the generic Sequence branch: strings are
            # Sequences, and zipping them recurses forever on 1-char strings.
            return outputs
        elif isinstance(out, collections.abc.Sequence):
            gathered = map(gather_map, zip(*outputs))
            if isinstance(out, tuple) and hasattr(out, '_fields'):
                # namedtuples take their fields positionally, not one iterable
                return elem_type(*gathered)
            return elem_type(gathered)

        raise TypeError(error_msg.format(elem_type))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None
開發者ID:roytseng-tw,項目名稱:Detectron.pytorch,代碼行數:47,代碼來源:scatter_gather.py

示例2: concat_collate

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import numpy_type_map [as 別名]
def concat_collate(batch):
    """
    Puts each data field into a tensor, concatenating along the first dimension.
    This is different to the default pytorch collate that stacks samples
    (adding a new batch dimension) rather than concatenating them.

    Supported elements: tensors, numpy arrays and numpy scalars, Python ints
    and floats, strings, and mappings/sequences nesting any of these. Mappings
    and sequences are collated field by field, recursively.

    :param batch: the input batch to be collated; every element is expected to
        have the same type/structure as ``batch[0]``.
    :return: the collated batch -- a tensor, a dict or list of collated fields,
        or the original strings unchanged.
    :raises TypeError: for numpy arrays of string/object dtype, or for element
        types not listed above.
    """
    import collections.abc

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.cat(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # Arrays of bytes/str/unicode/object dtype cannot become tensors.
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.cat([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(elem, collections.abc.Mapping):
        return {key: concat_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        # Must precede the Sequence branch: strings are Sequences too.
        return batch
    elif isinstance(elem, collections.abc.Sequence):
        transposed = zip(*batch)
        return [concat_collate(samples) for samples in transposed]

    raise TypeError(error_msg.format(elem_type))
開發者ID:aimagelab,項目名稱:novelty-detection,代碼行數:47,代碼來源:utils.py

示例3: collate

# 需要導入模塊: from torch.utils.data import dataloader [as 別名]
# 或者: from torch.utils.data.dataloader import numpy_type_map [as 別名]
def collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Tensors and numpy arrays are stacked along a new leading dimension; numpy
    scalars and Python numbers become 1-D tensors; strings are returned
    unchanged; mappings and sequences are collated field by field, recursively.

    The ``'instance_mask'`` entry of a mapping is special-cased because its
    leading (object-count) dimension may differ between samples: each mask
    (a numpy array or a tensor) is zero-padded up to the largest object count
    in the batch, and the result is a tensor of the default floating dtype with
    shape ``(len(batch), max_obj, *mask.shape[1:])``.

    Args:
        batch: list of samples, each with the same type/structure as
            ``batch[0]``.

    Returns:
        The collated batch.

    Raises:
        TypeError: for numpy arrays of string/object dtype, or for element
            types not listed above.
    """
    import collections.abc

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # Arrays of bytes/str/unicode/object dtype cannot become tensors.
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(elem, collections.abc.Mapping):
        res = {key: collate([d[key] for d in batch])
               for key in elem if key != 'instance_mask'}
        if 'instance_mask' in elem:
            # Per-sample object counts differ, so masks cannot simply be
            # stacked: zero-pad each to the largest count first.
            masks = [torch.as_tensor(d['instance_mask']) for d in batch]
            max_obj = max(m.shape[0] for m in masks)
            instance_mask = torch.zeros(len(batch), max_obj, *masks[0].shape[1:])
            for i, mask in enumerate(masks):
                instance_mask[i, :mask.shape[0]] = mask
            res['instance_mask'] = instance_mask
        return res
    elif isinstance(elem, int_classes):
        return torch.LongTensor(batch)
    elif isinstance(elem, float):
        return torch.DoubleTensor(batch)
    elif isinstance(elem, string_classes):
        # Must precede the Sequence branch: strings are Sequences too.
        return batch
    elif isinstance(elem, collections.abc.Sequence):
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]

    raise TypeError(error_msg.format(elem_type))
開發者ID:CaoWGG,項目名稱:CenterNet-CondInst,代碼行數:49,代碼來源:utils.py


注:本文中的torch.utils.data.dataloader.numpy_type_map方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。