

Python _utils._get_device_index method code examples

This article collects typical usage examples of the Python method torch.cuda._utils._get_device_index. If you are wondering how to use _utils._get_device_index in Python, or what it looks like in real code, the curated examples below may help. You can also explore other usages of the containing module, torch.cuda._utils.


The following presents 3 code examples of the _utils._get_device_index method, sorted by popularity by default.
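
Before diving into the examples, here is a minimal standalone sketch of what _get_device_index itself does, assuming the helper signature _get_device_index(device, optional=False) found in torch.cuda._utils of the PyTorch versions these projects target: it normalizes an integer index or a torch.device into a plain integer index, and with optional=True a None argument falls back to the current CUDA device.

# A minimal sketch, assuming the signature _get_device_index(device, optional=False)
# from torch.cuda._utils in the PyTorch versions these projects target.
import torch
from torch.cuda._utils import _get_device_index

print(_get_device_index(1, True))                        # an int is returned as-is -> 1
print(_get_device_index(torch.device("cuda", 1), True))  # a torch.device is reduced to its index -> 1

# With optional=True, None falls back to the current CUDA device (requires a GPU).
if torch.cuda.is_available():
    print(_get_device_index(None, True))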

Example 1: __init__

# Required import: from torch.cuda import _utils [as alias]
# Or: from torch.cuda._utils import _get_device_index [as alias]
def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(EncodingParallel, self).__init__()

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(
            "cuda:{}".format(self.device_ids[0]))

        # _check_balance (defined or imported elsewhere in EncodingDataParallel.py)
        # warns when the chosen GPUs differ noticeably in memory size or core count.
        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0]) 
Developer: PistonY, Project: torch-toolbox, Lines of code: 26, Source file: EncodingDataParallel.py
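
For context, the following standalone sketch reproduces the device normalization that the constructor above performs for device_ids, output_device and src_device_obj; it relies only on torch and _get_device_index, not on the rest of EncodingParallel, and does not need a GPU for this step.

# Standalone sketch of the constructor's device normalization, assuming only
# torch.cuda._utils._get_device_index (no GPU required for this step).
import torch
from torch.cuda._utils import _get_device_index

device_ids = [0, torch.device("cuda:1")]    # callers may mix ints and torch.device objects
device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
output_device = _get_device_index(device_ids[0], True)
src_device_obj = torch.device("cuda:{}".format(device_ids[0]))
print(device_ids, output_device, src_device_obj)    # [0, 1] 0 cuda:0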

Example 2: __init__

# Required import: from torch.cuda import _utils [as alias]
# Or: from torch.cuda._utils import _get_device_index [as alias]
# Note: this excerpt also uses chain, i.e. from itertools import chain
def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallelImbalance, self).__init__(
            module, device_ids, output_device, dim)

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]

        if not all(t.is_cuda and t.device.index == device_ids[0]
                   for t in chain(module.parameters(), module.buffers())):
            raise RuntimeError("module must have its parameters and buffers "
                               "on device %d (device_ids[0])" % device_ids[0])

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)

        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0]) 
Developer: microsoft, Project: unilm, Lines of code: 29, Source file: data_parallel.py
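
A hypothetical usage sketch of the class above: DataParallelImbalance refuses to wrap a module whose parameters and buffers are not already on device_ids[0], so the module is moved there first. The class name is assumed to be in scope as defined in the excerpt.

# Hypothetical usage sketch; assumes DataParallelImbalance is defined as in the
# excerpt above and that at least one CUDA device is available.
import torch
import torch.nn as nn

if torch.cuda.is_available():
    device_ids = list(range(torch.cuda.device_count()))
    model = nn.Linear(16, 4).cuda(device_ids[0])   # parameters must live on device_ids[0]
    model = DataParallelImbalance(model, device_ids=device_ids)
    print(model.device_ids, model.output_device)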

Example 3: criterion_parallel_apply

# Required import: from torch.cuda import _utils [as alias]
# Or: from torch.cuda._utils import _get_device_index [as alias]
# Note: this excerpt also relies on threading, torch, get_a_var
# (torch.nn.parallel.parallel_apply) and ExceptionWrapper (torch._utils)
def criterion_parallel_apply(
        modules,
        inputs,
        targets,
        kwargs_tup=None,
        devices=None):
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                output = module(*input, *target, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target, kwargs, device))
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs 
Developer: PistonY, Project: torch-toolbox, Lines of code: 61, Source file: EncodingDataParallel.py
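
To close, a hypothetical usage sketch of criterion_parallel_apply as defined above: one criterion replica, one input chunk and one target chunk per device; when devices=None the device of each replica is inferred from its input.

# Hypothetical usage sketch; assumes criterion_parallel_apply is defined as in
# the excerpt above and that CUDA is available.
import torch
import torch.nn as nn

if torch.cuda.is_available():
    n = torch.cuda.device_count()
    criterions = [nn.MSELoss().cuda(i) for i in range(n)]     # one loss replica per GPU
    inputs = [torch.randn(4, 3).cuda(i) for i in range(n)]    # one input chunk per GPU
    targets = [torch.randn(4, 3).cuda(i) for i in range(n)]   # one target chunk per GPU
    losses = criterion_parallel_apply(criterions, inputs, targets)
    print([loss.item() for loss in losses])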


Note: The torch.cuda._utils._get_device_index examples in this article were compiled by 純淨天空 from open-source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their respective authors, and copyright remains with the original authors. Please consult the corresponding project's license before distributing or reusing the code; do not reproduce without permission.