本文整理汇总了Python中torch.distributed.is_initialized方法的典型用法代码示例。如果您正苦于以下问题：Python distributed.is_initialized方法的具体用法？Python distributed.is_initialized怎么用？Python distributed.is_initialized使用的例子？那么恭喜您，这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在模块torch.distributed的用法示例。
在下文中一共展示了distributed.is_initialized方法的15个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: average_across_processes
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def average_across_processes(t: Union[torch.Tensor, Dict[str, torch.Tensor]]):
    r"""
    Average a tensor, or a dict of tensors, across all processes in the
    default process group, in place.

    After the call every process holds the same mean value. When
    ``torch.distributed`` is unavailable or not initialized, this is a no-op.
    Inputs that are neither a tensor nor a dict are ignored.

    .. note::
        Nested dicts of tensors are not supported.

    Parameters
    ----------
    t: torch.Tensor or Dict[str, torch.Tensor]
        A tensor or dict of tensors to average across processes. Tensors are
        modified in place (the division is in place, so integer tensors will
        fail).
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    # Query the world size once, through the same ``dist`` module used for the
    # all-reduce, so the tensor and dict branches divide by the same value.
    world_size = dist.get_world_size()
    if isinstance(t, torch.Tensor):
        tensors = [t]
    elif isinstance(t, dict):
        tensors = list(t.values())
    else:
        return
    for tensor in tensors:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor /= world_size
示例2: scatter
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def scatter(self, scatter_list, src, size=None, device=None):
    """Scatter a list of tensors from rank ``src`` to all parties.

    Args:
        scatter_list: On ``src``, one tensor per rank; entry ``i`` is sent to
            rank ``i``. On other ranks it is only consulted to infer
            ``size``/``device`` when those are not given.
        src: Rank that owns ``scatter_list`` and sends the pieces.
        size: Shape of the receive buffer on non-``src`` ranks. Defaults to
            the shape of this rank's entry in ``scatter_list``.
        device: Device of the receive buffer on non-``src`` ranks. Defaults to
            the device of this rank's entry in ``scatter_list`` when it can be
            read, otherwise torch's default device.

    Returns:
        The tensor this rank ends up holding. On ``src`` this is the ``.data``
        of its own entry of ``scatter_list``.

    Raises:
        AssertionError: If ``torch.distributed`` has not been initialized.
    """
    assert dist.is_initialized(), "initialize the communicator first"
    rank = self.get_rank()
    if src != rank:
        if size is None:
            size = scatter_list[rank].size()
        if device is None:
            # Best effort: keep the default device when the local entry is
            # missing or not tensor-like, but let unexpected errors propagate.
            try:
                device = scatter_list[rank].device
            except (TypeError, IndexError, KeyError, AttributeError):
                pass
        # NOTE: the receive buffer is always int64, so this assumes ``src``
        # scatters torch.long tensors.
        tensor = torch.empty(size=size, dtype=torch.long, device=device)
        dist.scatter(tensor, [], src, group=self.main_group)
    else:
        scatter_list = [s.data for s in scatter_list]
        tensor = scatter_list[rank]
        dist.scatter(tensor, scatter_list, src, group=self.main_group)
    return tensor
示例3: reduce
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def reduce(self, input, dst, op=ReduceOp.SUM, batched=False):
    """Reduce ``input`` across all parties onto rank ``dst``.

    Args:
        input: A tensor or, when ``batched`` is true, a list of tensors.
        dst: Rank that receives the reduced result.
        op: Reduction operation (``ReduceOp.SUM`` by default).
        batched: When true, launch one asynchronous reduce per list element
            and wait for all of them before returning.

    Returns:
        On ``dst``, the reduced clone (or list of clones) of ``input``; on
        every other rank, ``None``. ``input`` itself is not modified.
    """
    assert dist.is_initialized(), "initialize the communicator first"
    if not batched:
        assert torch.is_tensor(
            input.data
        ), "unbatched input for reduce must be a torch tensor"
        result = input.clone()
        dist.reduce(result.data, dst, op=op, group=self.main_group)
    else:
        assert isinstance(input, list), "batched reduce input must be a list"
        result = [item.clone().data for item in input]
        handles = [
            dist.reduce(buf, dst, op=op, group=self.main_group, async_op=True)
            for buf in result
        ]
        for handle in handles:
            handle.wait()
    if dst != self.get_rank():
        return None
    return result
示例4: all_reduce
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def all_reduce(self, input, op=ReduceOp.SUM, batched=False):
    """Reduce ``input`` across all parties; every party gets the result.

    Args:
        input: A tensor or, when ``batched`` is true, a list of tensors.
        op: Reduction operation (``ReduceOp.SUM`` by default).
        batched: When true, launch one asynchronous all-reduce per list
            element and wait for all of them before returning.

    Returns:
        The reduced clone of ``input`` (a list of clones when ``batched``).
        ``input`` itself is not modified.
    """
    assert dist.is_initialized(), "initialize the communicator first"
    if not batched:
        assert torch.is_tensor(
            input.data
        ), "unbatched input for reduce must be a torch tensor"
        reduced = input.clone()
        dist.all_reduce(reduced.data, op=op, group=self.main_group)
        return reduced

    assert isinstance(input, list), "batched reduce input must be a list"
    reduced = [item.clone() for item in input]
    pending = [
        dist.all_reduce(item.data, op=op, group=self.main_group, async_op=True)
        for item in reduced
    ]
    for work in pending:
        work.wait()
    return reduced
示例5: reduce_mean
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def reduce_mean(tensor):
    """Return the mean of ``tensor`` over all distributed processes.

    Without an initialized process group, the input object itself is returned.
    Otherwise a copy is divided by the world size and then sum-reduced, so the
    caller's tensor is left untouched.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if distributed_ready:
        averaged = tensor.clone()
        averaged.div_(dist.get_world_size())
        dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
        return averaged
    return tensor
示例6: _parse_losses
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def _parse_losses(self, losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary infomation.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
示例7: get_world_size
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_world_size():
    """Return the default group's world size, or 1 outside distributed runs."""
    ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if ready else 1
示例8: get_rank
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_rank():
    """Return this process's rank, or 0 when no process group is active."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
示例9: synchronize
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def synchronize():
    """
    Block until every process reaches this point (a barrier) during
    distributed training.

    Returns immediately when torch.distributed is unavailable or not
    initialized, or when only a single process is running.
    """
    should_wait = (
        dist.is_available()
        and dist.is_initialized()
        and dist.get_world_size() != 1
    )
    if should_wait:
        dist.barrier()
示例10: get_world_size
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_world_size() -> int:
    """Return the number of distributed processes (1 when not distributed)."""
    # Checked in order: availability first, then initialization.
    for ready in (dist.is_available, dist.is_initialized):
        if not ready():
            return 1
    return dist.get_world_size()
示例11: get_world_size
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_world_size():
    """Return the world size of the default process group.

    Returns 1 when ``torch.distributed`` is not available in this PyTorch
    build, or when the default process group has not been initialized.
    """
    # Check availability first: builds compiled without distributed support
    # do not expose the rest of the torch.distributed API.
    if not torch.distributed.is_available():
        return 1
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
示例12: get_rank
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def get_rank():
    """Return this process's rank in the default process group.

    Returns 0 when ``torch.distributed`` is not available in this PyTorch
    build, or when the default process group has not been initialized.
    """
    # Check availability first: builds compiled without distributed support
    # do not expose the rest of the torch.distributed API.
    if not torch.distributed.is_available():
        return 0
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
示例13: is_dist_avail_and_initialized
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    return bool(dist.is_available() and dist.is_initialized())
示例14: setup_logger
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def setup_logger(logpth):
    """Configure root logging to a timestamped file in ``logpth`` and to stderr.

    Non-zero ranks of an initialized process group log at WARNING; every other
    process logs at INFO. Note that ``logging.basicConfig`` does nothing if
    the root logger already has handlers.
    """
    stamp = time.strftime('%Y-%m-%d-%H-%M-%S')
    logfile = osp.join(logpth, 'Deeplab_v3plus-{}.log'.format(stamp))
    is_secondary_rank = dist.is_initialized() and dist.get_rank() != 0
    log_level = logging.WARNING if is_secondary_rank else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(levelname)s %(filename)s(%(lineno)d): %(message)s',
        filename=logfile,
    )
    # Mirror records to the console (stderr) in addition to the file.
    logging.root.addHandler(logging.StreamHandler())
示例15: __init__
# 需要导入模块: from torch import distributed [as 别名]
# 或者: from torch.distributed import is_initialized [as 别名]
def __init__(self, cfg, *args, **kwargs):
    """Store ``cfg`` and build the validation data loader from it.

    Args:
        cfg: Config object. It is passed whole to ``CityScapes``, and
            ``cfg.eval_batchsize`` and ``cfg.eval_n_workers`` are read for
            the loader.
        *args, **kwargs: Not used when building the data loader.
    """
    self.cfg = cfg
    # Shard the validation set across processes only when a process group exists.
    self.distributed = dist.is_initialized()
    ## dataloader
    dsval = CityScapes(cfg, mode='val')
    sampler = None
    if self.distributed:
        # NOTE(review): DistributedSampler shuffles by default and may pad the
        # dataset so every rank gets the same number of samples; confirm this
        # is acceptable for evaluation.
        sampler = torch.utils.data.distributed.DistributedSampler(dsval)
    # shuffle must stay False: DataLoader rejects shuffle=True combined with a sampler.
    self.dl = DataLoader(dsval,
                    batch_size = cfg.eval_batchsize,
                    sampler = sampler,
                    shuffle = False,
                    num_workers = cfg.eval_n_workers,
                    drop_last = False)