當前位置: 首頁>>代碼示例>>Python>>正文


Python amp.init方法代碼示例

本文整理匯總了Python中apex.amp.init方法的典型用法代碼示例。如果您正苦於以下問題:Python amp.init方法的具體用法?Python amp.init怎麽用?Python amp.init使用的例子?那麽, 這裏精選的方法代碼示例或許可以為您提供幫助。您也可以進一步了解該方法所在apex.amp的用法示例。


在下文中一共展示了amp.init方法的4個代碼示例,這些例子默認根據受歡迎程度排序。您可以為喜歡或者感覺有用的代碼點讚,您的評價將有助於係統推薦出更棒的Python代碼示例。

示例1: init_distributed

# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def init_distributed(use_cuda, backend="nccl", init="slurm", local_rank=-1):
    #try:
    #    mp.set_start_method('spawn')  # spawn, forkserver, and fork
    #except RuntimeError:
    #    pass

    try:
        if local_rank == -1:
            if init == "slurm":
                rank = int(os.environ['SLURM_PROCID'])
                world_size = int(os.environ['SLURM_NTASKS'])
                local_rank = int(os.environ['SLURM_LOCALID'])
                #maser_node = os.environ['SLURM_TOPOLOGY_ADDR']
                #maser_port = '23456'
            elif init == "ompi":
                rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
                world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
                local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

            if use_cuda:
                device = local_rank % torch.cuda.device_count()
                torch.cuda.set_device(device)
                print(f"set cuda device to cuda:{device}")

            master_node = os.environ["MASTER_ADDR"]
            master_port = os.environ["MASTER_PORT"]
            init_method = f"tcp://{master_node}:{master_port}"
            #init_method = "env://"
            dist.init_process_group(backend=backend, init_method=init_method, world_size=world_size, rank=rank)
            print(f"initialized as {rank}/{world_size} via {init_method}")
        else:
            if use_cuda:
                torch.cuda.set_device(local_rank)
                print(f"set cuda device to cuda:{local_rank}")
            dist.init_process_group(backend=backend, init_method="env://")
            print(f"initialized as {dist.get_rank()}/{dist.get_world_size()} via env://")
    except Exception as e:
        print(f"initialized as single process") 
開發者ID:jinserk,項目名稱:pytorch-asr,代碼行數:40,代碼來源:trainer.py

示例2: get_amp_handle

# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def get_amp_handle(args):
    if not args.use_cuda:
        args.fp16 = False
    if args.fp16:
        from apex import amp
        amp_handle = amp.init(enabled=True, enable_caching=True, verbose=False)
        return amp_handle
    else:
        return None 
開發者ID:jinserk,項目名稱:pytorch-asr,代碼行數:11,代碼來源:trainer.py

示例3: setUp

# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self) 
開發者ID:NVIDIA,項目名稱:apex,代碼行數:5,代碼來源:test_rnn.py

示例4: test_bce_is_float_with_allow_banned

# 需要導入模塊: from apex import amp [as 別名]
# 或者: from apex.amp import init [as 別名]
def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion) 
開發者ID:NVIDIA,項目名稱:apex,代碼行數:7,代碼來源:test_basic_casts.py


注:本文中的apex.amp.init方法示例由純淨天空整理自Github/MSDocs等開源代碼及文檔管理平台,相關代碼片段篩選自各路編程大神貢獻的開源項目,源碼版權歸原作者所有,傳播和使用請參考對應項目的License;未經允許,請勿轉載。