

Python Gather.apply Method Code Examples

This article collects typical usage examples of the Python method torch.nn.parallel._functions.Gather.apply. If you have been wondering what Gather.apply does, or how to use it in your own code, the curated examples below should help. You can also explore further usage examples of its containing class, torch.nn.parallel._functions.Gather.


Below are 12 code examples of the Gather.apply method, sorted by popularity by default.
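Gather.apply is the autograd function behind torch.nn.parallel.gather: its first two arguments are the destination device (-1 or 'cpu' for the CPU) and the dimension to concatenate along, followed by the tensors to merge. Here is a minimal sketch of a direct call, assuming a CUDA-capable machine; note that _functions is a private PyTorch module, so this is illustrative rather than a stable API.

import torch
from torch.nn.parallel._functions import Gather

if torch.cuda.is_available():
    # Two chunks on GPU 0; with several GPUs, each chunk could live on its
    # own device. Gather's inputs must be CUDA tensors.
    a = torch.ones(2, 3, device='cuda:0')
    b = torch.zeros(2, 3, device='cuda:0')
    merged = Gather.apply(-1, 0, a, b)  # target_device=-1 (the CPU), dim=0
    print(merged.shape, merged.device)  # torch.Size([4, 3]) cpu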

Example 1: dict_gather

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import collections.abc

from torch.autograd import Variable
from torch.nn.parallel._functions import Gather
def dict_gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
      (-1 means the CPU), with dictionary support.
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, Variable):
            # MJY(20180330) HACK:: force nr_dims > 0
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.abc.Mapping):  # was collections.Mapping (removed in Python 3.10)
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs) 
Author: XiaLiPKU, Project: EMANet, Lines: 21, Source: data_parallel.py

Example 2: dict_gather

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import collections.abc

import torch
from torch.nn.parallel._functions import Gather
def dict_gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
      (-1 means the CPU), with dictionary support.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # MJY(20180330) HACK:: force nr_dims > 0
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.abc.Mapping):  # was collections.Mapping (removed in Python 3.10)
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs) 
Author: CSAILVision, Project: semantic-segmentation-pytorch, Lines: 21, Source: data_parallel.py
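Examples 1 and 2 are identical except for how they detect tensors (isinstance(out, Variable) versus torch.is_tensor). A hypothetical usage sketch that works with either variant, assuming a CUDA machine and two replicas that each returned a dict of metrics:

import torch

if torch.cuda.is_available():
    outputs = [
        {'loss': torch.tensor([0.5], device='cuda:0'),
         'acc': torch.tensor([0.9], device='cuda:0')},
        {'loss': torch.tensor([0.7], device='cuda:0'),
         'acc': torch.tensor([0.8], device='cuda:0')},
    ]
    merged = dict_gather(outputs, target_device=-1)  # gather everything onto the CPU
    print(merged['loss'])  # tensor([0.5000, 0.7000])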

Example 3: gather

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import torch
from torch.nn.parallel._functions import Gather
def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
      (-1 means the CPU).
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all((len(out) == len(d) for d in outputs)):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)(((k, gather_map([d[k] for d in outputs]))
                              for k in out))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None 
Author: Cadene, Project: bootstrap.pytorch, Lines: 26, Source: data_parallel.py
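The try/finally that nulls out gather_map is the same idiom PyTorch uses in its own scatter_gather module: the recursive closure holds a reference to its own cell, and clearing the name breaks the cycle so gathered tensors are freed promptly instead of waiting for the cycle collector. A hedged sketch of the dict key check, with hypothetical inputs (CUDA required):

import torch

if torch.cuda.is_available():
    ok = [{'a': torch.ones(1, device='cuda:0')},
          {'a': torch.zeros(1, device='cuda:0')}]
    print(gather(ok, target_device=-1))  # {'a': tensor([1., 0.])}

    bad = ok + [{'a': torch.ones(1, device='cuda:0'),
                 'b': torch.ones(1, device='cuda:0')}]
    try:
        gather(bad, target_device=-1)
    except ValueError as e:
        print(e)  # All dicts must have the same number of keys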

Example 4: dict_gather_v1

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import collections.abc

import six
import torch
from torch.autograd import Variable
from torch.nn.parallel._functions import Gather
def dict_gather_v1(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
      (-1 means the CPU), with dictionary support.
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, Variable) or torch.is_tensor(out):
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.abc.Mapping):  # was collections.Mapping (removed in Python 3.10)
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, six.string_types):
            return outputs
        elif isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
        return outputs
    return gather_map(outputs) 
Author: vacancy, Project: Jacinle, Lines: 23, Source: dict_gather.py

Example 5: tnn_gather

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import itertools

from torch.autograd import Variable
from torch.nn.parallel._functions import Gather

# ScatterList is a list subclass defined elsewhere in this project.
def tnn_gather(outputs, target_device, dim=0):
    r"""
    Gathers variables from different GPUs on a specified device
      (-1 means the CPU).
    """
    def gather_map(outputs):
        if isinstance(outputs, Variable):
            if target_device == -1:
                return outputs.cpu()
            return outputs.cuda(target_device)

        out = outputs[0]
        if isinstance(out, Variable):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None

        if isinstance(out, ScatterList):
            return tuple(map(gather_map, itertools.chain(*outputs)))

        return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs) 
Author: CharlesShang, Project: Detectron-PYTORCH, Lines: 24, Source: data_parallel.py

Example 6: scatter

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.utils.rnn import PackedSequence

# packed_sequence_scatter is a helper defined elsewhere in this project.
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        if isinstance(obj, PackedSequence):
            return packed_sequence_scatter(obj, target_gpus)
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None 
Author: mapillary, Project: seamseg, Lines: 31, Source: scatter_gather.py
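A hypothetical call to the scatter above, assuming at least two visible CUDA devices: a bare tensor is sliced into roughly equal chunks along dim 0, one chunk per GPU.

import torch

if torch.cuda.device_count() >= 2:
    batch = torch.arange(8.0).reshape(8, 1)
    chunks = scatter(batch, target_gpus=[0, 1])
    print([c.device for c in chunks])  # chunks live on cuda:0 and cuda:1
    print([c.shape for c in chunks])   # [torch.Size([4, 1]), torch.Size([4, 1])]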

Example 7: gather

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import torch
from torch.nn.parallel._functions import Gather
from torch.nn.utils.rnn import PackedSequence

# packed_sequence_gather is a helper defined elsewhere in this project.
def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
      (-1 means the CPU).
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all((len(out) == len(d) for d in outputs)):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)(((k, gather_map([d[k] for d in outputs]))
                              for k in out))
        if isinstance(out, PackedSequence):
            return packed_sequence_gather(outputs, target_device)
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        return gather_map(outputs)
    finally:
        gather_map = None 
Author: mapillary, Project: seamseg, Lines: 29, Source: scatter_gather.py

Example 8: gather_res

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
from torch.nn.parallel._functions import Gather
def gather_res(outputs, target_device, dim=0):
    """
    Assuming the signatures are the same across results!
    """
    out = outputs[0]
    args = {field: Gather.apply(target_device, dim, *[getattr(o, field) for o in outputs])
            for field, v in out.__dict__.items() if v is not None}
    return type(out)(**args) 
Author: rowanz, Project: neural-motifs, Lines: 10, Source: object_detector.py
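gather_res merges per-GPU result objects field by field: every non-None attribute is gathered with Gather.apply, and a new object of the same type is built from the merged fields. A sketch with a made-up Result class (hypothetical; requires CUDA):

import torch

class Result:
    def __init__(self, boxes, scores):
        self.boxes, self.scores = boxes, scores

if torch.cuda.is_available():
    r0 = Result(torch.rand(3, 4, device='cuda:0'), torch.rand(3, device='cuda:0'))
    r1 = Result(torch.rand(2, 4, device='cuda:0'), torch.rand(2, device='cuda:0'))
    merged = gather_res([r0, r1], target_device=-1)
    print(merged.boxes.shape)  # torch.Size([5, 4]), concatenated on the CPU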

Example 9: apply

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
def apply(self, feed_dict, key):
        # Abstract hook of the collate helper: subclasses write their collated
        # values into feed_dict under key (see values.apply(result, key) in Example 12).
        raise NotImplementedError()
Author: vacancy, Project: Jacinle, Lines: 4, Source: collate_v3.py

Example 10: _stack_raw

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import torch
from torch.nn.parallel._functions import Gather

# VarLengthCollateV3Mode and the attributes self.mode, self.gather_device,
# and self.gather_dim are defined by the surrounding collate class.
def _stack_raw(self, values, out, maybe_cuda, is_concat=False):
        if self.mode is VarLengthCollateV3Mode.GATHER and maybe_cuda:
            if values[0].dim() == 0:
                values = [o.unsqueeze(0) for o in values]
            return Gather.apply(self.gather_device, self.gather_dim, *values)
        else:
            if is_concat:
                return torch.cat(values, 0, out=out)
            else:
                return torch.stack(values, 0, out=out) 
Author: vacancy, Project: Jacinle, Lines: 12, Source: collate_v3.py
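The is_concat flag decides whether the values get a new batch dimension or are joined along an existing one; the difference in miniature:

import torch

vals = [torch.ones(3), torch.zeros(3)]
print(torch.stack(vals, 0).shape)  # torch.Size([2, 3]): new batch dimension
print(torch.cat(vals, 0).shape)    # torch.Size([6]): joined along dim 0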

Example 11: scatter

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
from math import ceil

import torch
from torch.autograd import Variable
from torch.nn.parallel._functions import Scatter

# ScatterList is a list subclass defined elsewhere in this project.
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices variables into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not variables. Does not
    support Tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, Variable):
            return Scatter.apply(target_gpus, None, dim, obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        if isinstance(obj, ScatterList):
            # Split the list into ceil(len(obj) / len(target_gpus))-sized
            # chunks, one chunk per target GPU.
            chunk_size = int(ceil(float(len(obj)) / float(len(target_gpus))))
            return [obj[i*chunk_size: (i+1)*chunk_size] for i in range(len(target_gpus))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Anything else is replicated by reference, once per target GPU.
        return [obj for _ in target_gpus]

    return scatter_map(inputs) 
Author: CharlesShang, Project: Detectron-PYTORCH, Lines: 31, Source: data_parallel.py
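For the ScatterList branch only list slicing happens; no tensors move between devices. A worked example of the chunk arithmetic, using a stand-in list subclass since the real ScatterList lives in the project's own code:

from math import ceil

class ScatterList(list):  # stand-in for the project's own class
    pass

obj, target_gpus = ScatterList('abcde'), [0, 1]
chunk_size = int(ceil(float(len(obj)) / float(len(target_gpus))))  # ceil(5/2) = 3
print([obj[i*chunk_size:(i+1)*chunk_size] for i in range(len(target_gpus))])
# [['a', 'b', 'c'], ['d', 'e']]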

Example 12: __call__

# Required import: from torch.nn.parallel._functions import Gather [as alias]
# Or: from torch.nn.parallel._functions.Gather import apply [as alias]
import collections.abc
import re

import torch
from six import string_types

# DataLayoutType, numpy_type_map, and _VarLengthCollateV3Stack are defined
# elsewhere in this project; self._stack and self.layout belong to the
# surrounding collate class.
def __call__(self, batch, flatten_key=None, layout_spec=None):
        if flatten_key is not None and flatten_key in self.layout:
            layout_spec = self.layout[flatten_key]

        if layout_spec is not None and layout_spec.type is DataLayoutType.SKIP:
            return batch

        error_msg = "Batch must contain tensors, numbers, dicts or lists; found {}."
        elem_type = type(batch[0])
        if layout_spec is not None:
            assert (
                torch.is_tensor(batch[0]) or
                (elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' and elem_type.__name__ != 'string_')
            ), 'Invalid layout type for: {}.'.format(flatten_key)

        if torch.is_tensor(batch[0]):
            return self._stack(batch, layout_spec, maybe_cuda=True)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            elem = batch[0]
            if elem_type.__name__ == 'ndarray':
                # array of string classes and object
                if re.search('[SaUO]', elem.dtype.str) is not None:
                    raise TypeError(error_msg.format(elem.dtype))
                return self._stack([torch.from_numpy(b) for b in batch], layout_spec, maybe_cuda=False)
            if elem.shape == ():  # scalars
                py_type = float if elem.dtype.name.startswith('float') else int
                return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))

        elif isinstance(batch[0], int):
            return torch.LongTensor(batch)
        elif isinstance(batch[0], float):
            return torch.DoubleTensor(batch)
        elif isinstance(batch[0], string_types):
            return batch

        elif isinstance(batch[0], collections.abc.Mapping):  # was collections.Mapping (removed in Python 3.10)
            result = dict()
            for key in batch[0]:
                values = [d[key] for d in batch]
                next_key = key if flatten_key is None else f'{flatten_key}.{key}'
                values = self(values, flatten_key=next_key, layout_spec=layout_spec)
                if isinstance(values, _VarLengthCollateV3Stack):
                    values.apply(result, key)
                else:
                    result[key] = values
            return result
        elif isinstance(batch[0], collections.abc.Sequence):
            transposed = zip(*batch)
            # Add .{index} only if it's inside a dict already.
            return [
                self(samples, flatten_key=None if flatten_key is None else f'{flatten_key}.{i}',
                     layout_spec=layout_spec)
                for i, samples in enumerate(transposed)
            ]

        raise TypeError(error_msg.format(type(batch[0])))
Author: vacancy, Project: Jacinle, Lines: 59, Source: collate_v3.py
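The overall contract mirrors PyTorch's built-in default_collate: a list of per-sample structures goes in, the same structure with batched tensors comes out; what this class adds are the per-key layout specs and the optional Gather.apply path for CUDA tensors. For comparison, the built-in behavior:

import torch
from torch.utils.data.dataloader import default_collate

batch = [{'x': torch.ones(2), 'y': 1}, {'x': torch.zeros(2), 'y': 2}]
out = default_collate(batch)
print(out['x'].shape)  # torch.Size([2, 2])
print(out['y'])        # tensor([1, 2])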


Note: The torch.nn.parallel._functions.Gather.apply examples in this article were compiled by 纯净天空 from open-source code and documentation hosted on GitHub, MSDocs, and similar platforms. The snippets come from open-source projects contributed by their respective developers, and copyright remains with the original authors; consult each project's License before distributing or using the code. Do not repost without permission.