当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch fp16_compress_wrapper用法及代码示例


本文简要介绍python语言中 torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper 的用法。

用法:

torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)

此包装器将给定 DDP 通信挂钩的输入梯度张量转换为 half-precision 浮点格式 (torch.float16),并将给定挂钩的结果张量转换回输入数据类型,例如 float32

因此,fp16_compress_hook 等价于 fp16_compress_wrapper(allreduce_hook)

例子:

>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
>>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。