當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch custom_from_mask用法及代碼示例


本文簡要介紹python語言中 torch.nn.utils.prune.custom_from_mask 的用法。

用法:

torch.nn.utils.prune.custom_from_mask(module, name, mask)

參數

  • module(torch.nn.Module) -包含要修剪的張量的模塊

  • name(str) -module 中的參數名稱,將對其進行修剪。

  • mask(Tensor) -要應用於參數的二進製掩碼。

返回

輸入模塊的修改(即修剪)版本

返回類型

模塊(nn.Module)

通過在 mask 中應用預先計算的掩碼,修剪與 module 中名為 name 的參數對應的張量。通過以下方式修改模塊(並返回修改後的模塊):

  1. 添加一個名為 name+'_mask' 的命名緩衝區,該緩衝區對應於通過修剪方法應用於參數 name 的二進製掩碼。

  2. 將參數 name 替換為其修剪版本,而原始(未修剪)參數存儲在名為 name+'_orig' 的新參數中。

例子

>>> m = prune.custom_from_mask(
        nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
    )
>>> print(m.bias_mask)
tensor([0., 1., 0.])

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.prune.custom_from_mask。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。