本文整理匯總了Python中torch.cuda._utils._get_device_index方法的典型用法代碼示例。如果您正苦於以下問題:Python _utils._get_device_index方法的具體用法?Python _utils._get_device_index怎麼用?Python _utils._get_device_index使用的例子?那麼, 這裡精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在模塊torch.cuda._utils的用法示例。
在下文中一共展示了_utils._get_device_index方法的3個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於系統推薦出更棒的Python代碼示例。
示例1: __init__
# 需要導入模塊: from torch.cuda import _utils [as 別名]
# 或者: from torch.cuda._utils import _get_device_index [as 別名]
def __init__(self, module, device_ids=None, output_device=None, dim=0):
    """Wrap ``module`` for data-parallel execution over CUDA devices.

    Args:
        module: the module to wrap; stored as ``self.module``.
        device_ids: devices to use. Each entry is normalized to an integer
            index with ``_get_device_index(x, True)``. Defaults to every
            visible CUDA device (``range(torch.cuda.device_count())``).
        output_device: device for outputs; defaults to ``device_ids[0]``.
            Normalized the same way and stored as ``self.output_device``.
        dim: stored as ``self.dim`` (presumably the scatter/gather
            dimension, as in ``DataParallel`` -- confirm against forward()).

    When CUDA is unavailable, only ``self.module`` and an empty
    ``self.device_ids`` are set, and the function returns early.
    """
    super(EncodingParallel, self).__init__()
    if not torch.cuda.is_available():
        self.module = module
        self.device_ids = []
        return
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]
    self.dim = dim
    self.module = module
    # Normalize ints / torch.device / "cuda:N" strings to plain indices.
    self.device_ids = [_get_device_index(x, True) for x in device_ids]
    self.output_device = _get_device_index(output_device, True)
    # torch.device requires "cuda:<index>"; "cuda <index>" is rejected.
    self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
    _check_balance(self.device_ids)
    if len(self.device_ids) == 1:
        # Use the normalized index so non-int device specs are handled.
        self.module.cuda(self.device_ids[0])
示例2: __init__
# 需要導入模塊: from torch.cuda import _utils [as 別名]
# 或者: from torch.cuda._utils import _get_device_index [as 別名]
def __init__(self, module, device_ids=None, output_device=None, dim=0):
    """Wrap ``module`` like ``DataParallel`` after validating its placement.

    Args:
        module: the module to wrap. All of its parameters and buffers must
            already live on the first CUDA device in ``device_ids``.
        device_ids: devices to use. Each entry is normalized to an integer
            index with ``_get_device_index(x, True)``. Defaults to every
            visible CUDA device.
        output_device: device for outputs; defaults to ``device_ids[0]``.
        dim: stored as ``self.dim`` (presumably the scatter/gather
            dimension, as in ``DataParallel`` -- confirm).

    Raises:
        RuntimeError: if any parameter or buffer of ``module`` is not a CUDA
            tensor on the first device in ``device_ids``.

    When CUDA is unavailable, ``self.module`` is set, ``self.device_ids`` is
    emptied, and the function returns early.
    """
    super(DataParallelImbalance, self).__init__(
        module, device_ids, output_device, dim)
    if not torch.cuda.is_available():
        self.module = module
        self.device_ids = []
        return
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]
    # Normalize before validating. Comparing tensor device indices against a
    # raw torch.device / "cuda:N" entry never matches, and "%d" cannot
    # format it.
    normalized_ids = [_get_device_index(x, True) for x in device_ids]
    primary = normalized_ids[0]
    if not all(t.is_cuda and t.device.index == primary
               for t in chain(module.parameters(), module.buffers())):
        raise RuntimeError("module must have its parameters and buffers "
                           "on device %d (device_ids[0])" % primary)
    self.dim = dim
    self.module = module
    self.device_ids = normalized_ids
    self.output_device = _get_device_index(output_device, True)
    if len(self.device_ids) == 1:
        self.module.cuda(primary)
示例3: criterion_parallel_apply
# 需要導入模塊: from torch.cuda import _utils [as 別名]
# 或者: from torch.cuda._utils import _get_device_index [as 別名]
def criterion_parallel_apply(
        modules,
        inputs,
        targets,
        kwargs_tup=None,
        devices=None):
    """Apply each criterion to its matching (input, target) pair.

    ``modules[i]`` is called as
    ``modules[i](*inputs[i], *targets[i], **kwargs_tup[i])`` inside
    ``torch.cuda.device(devices[i])``. An input or target that is not a list
    or tuple is first wrapped in a 1-tuple. When there is more than one
    module, each call runs in its own thread.

    Args:
        modules: callables to apply, one per replica.
        inputs: one input (or sequence of inputs) per module.
        targets: one target (or sequence of targets) per module.
        kwargs_tup: optional per-module keyword-argument dicts; defaults to
            an empty dict for every module.
        devices: optional per-module devices, normalized with
            ``_get_device_index(x, True)``; ``None`` entries are passed to
            it as well.

    Returns:
        A list with one output per module, in the same order as
        ``modules``. Empty ``modules`` yields ``[]``.

    Raises:
        AssertionError: if the lengths of the argument sequences disagree.
        Any exception raised by a module call is re-raised in the calling
        thread via ``ExceptionWrapper.reraise``.
    """
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    lock = threading.Lock()
    results = {}
    # Grad mode is per-thread state in PyTorch: capture the caller's setting
    # so each worker can restore it.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, inp, tgt, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(inp).get_device()
        try:
            with torch.cuda.device(device):
                if not isinstance(inp, (list, tuple)):
                    inp = (inp,)
                if not isinstance(tgt, (list, tuple)):
                    tgt = (tgt,)
                output = module(*inp, *tgt, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            # Store the exception (with its traceback) so the caller thread
            # can re-raise it after all workers finish.
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        # targets must be zipped in: the worker takes (input, target) pairs.
        threads = [
            threading.Thread(target=_worker,
                             args=(i, module, inp, tgt, kwargs, device))
            for i, (module, inp, tgt, kwargs, device) in
            enumerate(zip(modules, inputs, targets, kwargs_tup, devices))
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    elif modules:
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0],
                devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs