当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch skippable用法及代码示例


本文简要介绍python语言中 torch.distributed.pipeline.sync.skip.skippable.skippable 的用法。

用法:

torch.distributed.pipeline.sync.skip.skippable.skippable(stash=(), pop=())

用于定义带有跳过连接的 nn.Module 的装饰器。装饰模块称为“skippable”。即使模块未被 Pipe 包装,此函数也可以完美运行。

每个跳跃张量都由其名称管理。在操作跳过张量之前,可跳过的模块必须通过stash 和/或pop 参数静态声明跳过张量的名称。具有预先声明名称的跳过张量可以由 yield stash(name, tensor) 隐藏或由 tensor = yield pop(name) 弹出。

这是一个三层的例子。一个名为“1to3” 的跳跃张量分别在第一层和最后一层被隐藏和弹出:

@skippable(stash=['1to3'])
class Layer1(nn.Module):
    def forward(self, input):
        yield stash('1to3', input)
        return f1(input)

class Layer2(nn.Module):
    def forward(self, input):
        return f2(input)

@skippable(pop=['1to3'])
class Layer3(nn.Module):
    def forward(self, input):
        skip_1to3 = yield pop('1to3')
        return f3(input) + skip_1to3

model = nn.Sequential(Layer1(), Layer2(), Layer3())

一个可跳过的模块可以存储或弹出多个跳过张量:

@skippable(stash=['alice', 'bob'], pop=['carol'])
class StashStashPop(nn.Module):
    def forward(self, input):
        yield stash('alice', f_alice(input))
        yield stash('bob', f_bob(input))
        carol = yield pop('carol')
        return input + carol

每个跳过张量必须与一对 stashpop 关联。 Pipe 在包装模块时自动检查此限制。您还可以通过 verify_skippables() 检查限制,而不使用 Pipe

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.distributed.pipeline.sync.skip.skippable.skippable。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。