当前位置: 首页>>代码示例>>Python>>正文


Python amp.init方法代码示例

本文整理汇总了Python中apex.amp.init方法的典型用法代码示例。如果您正苦于以下问题:Python amp.init方法的具体用法?Python amp.init怎么用?Python amp.init使用的例子?那么恭喜您, 这里精选的方法代码示例或许可以为您提供帮助。您也可以进一步了解该方法所在apex.amp的用法示例。


在下文中一共展示了amp.init方法的4个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于系统推荐出更棒的Python代码示例。

示例1: init_distributed

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def init_distributed(use_cuda, backend="nccl", init="slurm", local_rank=-1):
    #try:
    #    mp.set_start_method('spawn')  # spawn, forkserver, and fork
    #except RuntimeError:
    #    pass

    try:
        if local_rank == -1:
            if init == "slurm":
                rank = int(os.environ['SLURM_PROCID'])
                world_size = int(os.environ['SLURM_NTASKS'])
                local_rank = int(os.environ['SLURM_LOCALID'])
                #maser_node = os.environ['SLURM_TOPOLOGY_ADDR']
                #maser_port = '23456'
            elif init == "ompi":
                rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
                world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
                local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

            if use_cuda:
                device = local_rank % torch.cuda.device_count()
                torch.cuda.set_device(device)
                print(f"set cuda device to cuda:{device}")

            master_node = os.environ["MASTER_ADDR"]
            master_port = os.environ["MASTER_PORT"]
            init_method = f"tcp://{master_node}:{master_port}"
            #init_method = "env://"
            dist.init_process_group(backend=backend, init_method=init_method, world_size=world_size, rank=rank)
            print(f"initialized as {rank}/{world_size} via {init_method}")
        else:
            if use_cuda:
                torch.cuda.set_device(local_rank)
                print(f"set cuda device to cuda:{local_rank}")
            dist.init_process_group(backend=backend, init_method="env://")
            print(f"initialized as {dist.get_rank()}/{dist.get_world_size()} via env://")
    except Exception as e:
        print(f"initialized as single process") 
开发者ID:jinserk,项目名称:pytorch-asr,代码行数:40,代码来源:trainer.py

示例2: get_amp_handle

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def get_amp_handle(args):
    if not args.use_cuda:
        args.fp16 = False
    if args.fp16:
        from apex import amp
        amp_handle = amp.init(enabled=True, enable_caching=True, verbose=False)
        return amp_handle
    else:
        return None 
开发者ID:jinserk,项目名称:pytorch-asr,代码行数:11,代码来源:trainer.py

示例3: setUp

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self) 
开发者ID:NVIDIA,项目名称:apex,代码行数:5,代码来源:test_rnn.py

示例4: test_bce_is_float_with_allow_banned

# 需要导入模块: from apex import amp [as 别名]
# 或者: from apex.amp import init [as 别名]
def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion) 
开发者ID:NVIDIA,项目名称:apex,代码行数:7,代码来源:test_basic_casts.py


注:本文中的apex.amp.init方法示例由纯净天空整理自Github/MSDocs等开源代码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。