本文整理汇总了Python中apex.amp.init方法的典型用法代码示例。如果您正苦于以下问题:Python amp.init方法的具体用法?Python amp.init怎么用?Python amp.init使用的例子?那么, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在类apex.amp
的用法示例。
在下文中一共展示了amp.init方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。
示例1: init_distributed
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def init_distributed(use_cuda, backend="nccl", init="slurm", local_rank=-1):
#try:
# mp.set_start_method('spawn') # spawn, forkserver, and fork
#except RuntimeError:
# pass
try:
if local_rank == -1:
if init == "slurm":
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NTASKS'])
local_rank = int(os.environ['SLURM_LOCALID'])
#maser_node = os.environ['SLURM_TOPOLOGY_ADDR']
#maser_port = '23456'
elif init == "ompi":
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
if use_cuda:
device = local_rank % torch.cuda.device_count()
torch.cuda.set_device(device)
print(f"set cuda device to cuda:{device}")
master_node = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
init_method = f"tcp://{master_node}:{master_port}"
#init_method = "env://"
dist.init_process_group(backend=backend, init_method=init_method, world_size=world_size, rank=rank)
print(f"initialized as {rank}/{world_size} via {init_method}")
else:
if use_cuda:
torch.cuda.set_device(local_rank)
print(f"set cuda device to cuda:{local_rank}")
dist.init_process_group(backend=backend, init_method="env://")
print(f"initialized as {dist.get_rank()}/{dist.get_world_size()} via env://")
except Exception as e:
print(f"initialized as single process")
示例2: get_amp_handle
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def get_amp_handle(args):
if not args.use_cuda:
args.fp16 = False
if args.fp16:
from apex import amp
amp_handle = amp.init(enabled=True, enable_caching=True, verbose=False)
return amp_handle
else:
return None
示例3: setUp
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
示例4: test_bce_is_float_with_allow_banned
# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def test_bce_is_float_with_allow_banned(self):
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True)
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
self.bce_common(assertion)