本文整理匯總了Python中apex.amp.init方法的典型用法代碼示例。如果您正苦於以下問題:Python amp.init方法的具體用法?Python amp.init怎麽用?Python amp.init使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在類apex.amp
的用法示例。
在下文中一共展示了amp.init方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。
示例1: init_distributed
# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def init_distributed(use_cuda, backend="nccl", init="slurm", local_rank=-1):
#try:
# mp.set_start_method('spawn') # spawn, forkserver, and fork
#except RuntimeError:
# pass
try:
if local_rank == -1:
if init == "slurm":
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NTASKS'])
local_rank = int(os.environ['SLURM_LOCALID'])
#maser_node = os.environ['SLURM_TOPOLOGY_ADDR']
#maser_port = '23456'
elif init == "ompi":
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
if use_cuda:
device = local_rank % torch.cuda.device_count()
torch.cuda.set_device(device)
print(f"set cuda device to cuda:{device}")
master_node = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
init_method = f"tcp://{master_node}:{master_port}"
#init_method = "env://"
dist.init_process_group(backend=backend, init_method=init_method, world_size=world_size, rank=rank)
print(f"initialized as {rank}/{world_size} via {init_method}")
else:
if use_cuda:
torch.cuda.set_device(local_rank)
print(f"set cuda device to cuda:{local_rank}")
dist.init_process_group(backend=backend, init_method="env://")
print(f"initialized as {dist.get_rank()}/{dist.get_world_size()} via env://")
except Exception as e:
print(f"initialized as single process")
示例2: get_amp_handle
# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def get_amp_handle(args):
if not args.use_cuda:
args.fp16 = False
if args.fp16:
from apex import amp
amp_handle = amp.init(enabled=True, enable_caching=True, verbose=False)
return amp_handle
else:
return None
示例3: setUp
# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
示例4: test_bce_is_float_with_allow_banned
# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def test_bce_is_float_with_allow_banned(self):
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True)
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
self.bce_common(assertion)