當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch autocast用法及代碼示例

本文簡要介紹python語言中 torch.autocast 的用法。

用法:

class torch.autocast(device_type, enabled=True, **kwargs)

參數

  • device_type(string,必需的) -是否使用‘cuda’或‘cpu’設備

  • enabled(bool,可選的,默認=真) -是否應在區域中啟用自動投射。

  • dtype(torch_dpython:類型,可選的) -是使用 torch.float16 還是 torch.bfloat16。

  • cache_enabled(bool,可選的,默認=真) -是否應該啟用 autocast 中的權重緩存。

autocast 的實例用作上下文管理器或裝飾器,允許腳本區域以混合精度運行。

在這些區域中,ops 在 autocast 選擇的 op-specific dtype 中運行,以提高性能同時保持準確性。有關詳細信息,請參閱 Autocast Op 參考。

進入autocast-enabled 區域時,張量可以是任何類型。使用自動投射時,不應在模型或輸入上調用 half()bfloat16()

autocast 應僅包含網絡的前向傳遞,包括損失計算。不建議在自動施法下向後傳遞。後向操作的運行類型與自動轉換用於相應前向操作的類型相同。

CUDA 設備示例:

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()

有關更複雜場景(例如梯度懲罰、多個模型/損失、自定義 autograd 函數)中的用法(以及梯度縮放),請參閱自動混合精度示例。

autocast 也可以用作裝飾器,例如,在模型的 forward 方法上:

class AutocastModel(nn.Module):
    ...
    @autocast()
    def forward(self, input):
        ...

在 autocast-enabled 區域中產生的浮點張量可能是 float16 。返回到 autocast-disabled 區域後,將它們與不同 dtype 的浮點張量一起使用可能會導致類型不匹配錯誤。如果是這樣,請將自動投射區域中生成的張量轉換回 float32(或其他 dtype,如果需要)。如果來自自動投射區域的張量已經是 float32 ,則投射是 no-op,並且不會產生額外的開銷。 CUDA 示例:

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

with autocast():
    # torch.mm is on autocast's list of ops that should run in float16.
    # Inputs are float32, but the op runs in float16 and produces float16 output.
    # No manual casts are required.
    e_float16 = torch.mm(a_float32, b_float32)
    # Also handles mixed input types
    f_float16 = torch.mm(d_float32, e_float16)

# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_float16.float())

中央處理器示例:

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
c_float32 = torch.rand((8, 8), device="cpu")
d_float32 = torch.rand((8, 8), device="cpu")

with autocast(dtype=torch.bfloat16, device_type="cpu"):
    # torch.mm is on autocast's list of ops that should run in bfloat16.
    # Inputs are float32, but the op runs in bfloat16 and produces bfloat16 output.
    # No manual casts are required.
    e_bfloat16 = torch.mm(a_float32, b_float32)
    # Also handles mixed input types
    f_bfloat16 = torch.mm(d_float32, e_bfloat16)

# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_bfloat16.float())

autocast-enabled 區域中的類型不匹配錯誤是一個錯誤;如果這是您觀察到的,請提出問題。

autocast(enabled=False) 子區域可以嵌套在 autocast-enabled 區域中。例如,如果您想強製子區域在特定的 dtype 中運行,則本地禁用自動廣播可能很有用。禁用自動轉換使您可以顯式控製執行類型。在子區域中,來自周圍區域的輸入應在使用前轉換為dtype

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

with autocast():
    e_float16 = torch.mm(a_float32, b_float32)
    with autocast(enabled=False):
        # Calls e_float16.float() to ensure float32 execution
        # (necessary because e_float16 was created in an autocasted region)
        f_float32 = torch.mm(c_float32, e_float16.float())

    # No manual casts are required when re-entering the autocast-enabled region.
    # torch.mm again runs in float16 and produces float16 output, regardless of input types.
    g_float16 = torch.mm(d_float32, f_float32)

自動轉換狀態是線程本地的。如果您希望在新線程中啟用它,則必須在該線程中調用上下文管理器或裝飾器。當每個進程使用多個 GPU 時,這會影響 torch.nn.DataParallel torch.nn.parallel.DistributedDataParallel (請參閱使用多個 GPU)。

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.autocast。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。