當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch skippable用法及代碼示例


本文簡要介紹python語言中 torch.distributed.pipeline.sync.skip.skippable.skippable 的用法。

用法:

torch.distributed.pipeline.sync.skip.skippable.skippable(stash=(), pop=())

用於定義帶有跳過連接的 nn.Module 的裝飾器。裝飾模塊稱為“skippable”。即使模塊未被 Pipe 包裝,此函數也可以完美運行。

每個跳躍張量都由其名稱管理。在操作跳過張量之前,可跳過的模塊必須通過stash 和/或pop 參數靜態聲明跳過張量的名稱。具有預先聲明名稱的跳過張量可以由 yield stash(name, tensor) 隱藏或由 tensor = yield pop(name) 彈出。

這是一個三層的例子。一個名為“1to3” 的跳躍張量分別在第一層和最後一層被隱藏和彈出:

@skippable(stash=['1to3'])
class Layer1(nn.Module):
    def forward(self, input):
        yield stash('1to3', input)
        return f1(input)

class Layer2(nn.Module):
    def forward(self, input):
        return f2(input)

@skippable(pop=['1to3'])
class Layer3(nn.Module):
    def forward(self, input):
        skip_1to3 = yield pop('1to3')
        return f3(input) + skip_1to3

model = nn.Sequential(Layer1(), Layer2(), Layer3())

一個可跳過的模塊可以存儲或彈出多個跳過張量:

@skippable(stash=['alice', 'bob'], pop=['carol'])
class StashStashPop(nn.Module):
    def forward(self, input):
        yield stash('alice', f_alice(input))
        yield stash('bob', f_bob(input))
        carol = yield pop('carol')
        return input + carol

每個跳過張量必須與一對 stashpop 關聯。 Pipe 在包裝模塊時自動檢查此限製。您還可以通過 verify_skippables() 檢查限製,而不使用 Pipe

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.distributed.pipeline.sync.skip.skippable.skippable。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。