當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python PyTorch ln_structured用法及代碼示例

本文簡要介紹python語言中 torch.nn.utils.prune.ln_structured 的用法。

用法:

torch.nn.utils.prune.ln_structured(module, name, amount, n, dim, importance_scores=None)

參數

  • module(torch.nn.Module) -包含要修剪的張量的模塊

  • name(str) -module 中的參數名稱,將對其進行修剪。

  • amount(int或者float) -要修剪的參數數量。如果 float ,應介於 0.0 和 1.0 之間,表示要修剪的參數比例。如果 int ,它表示要修剪的參數的絕對數量。

  • n(int,float,inf,-inf,'fro','nuc') - 請參閱有效條目的文檔以獲取參數ptorch.norm.

  • dim(int) -dim 的索引,我們沿著該索引定義要修剪的通道。

  • importance_scores(torch.Tensor) -用於計算剪枝掩碼的重要性分數(與模塊參數形狀相同)的張量。此張量中的值表示正在修剪的參數中相應元素的重要性。如果未指定或無,將使用模塊參數代替它。

返回

輸入模塊的修改(即修剪)版本

返回類型

模塊(nn.Module)

修剪與module 中名為name 的參數相對應的張量,方法是沿著具有最低L n -norm 的指定dim 移除指定的amount(當前未修剪的)通道。通過以下方式修改模塊(並返回修改後的模塊):

  1. 添加一個名為 name+'_mask' 的命名緩衝區,該緩衝區對應於通過修剪方法應用於參數 name 的二進製掩碼。

  2. 將參數 name 替換為其修剪版本,而原始(未修剪)參數存儲在名為 name+'_orig' 的新參數中。

例子

>>> m = prune.ln_structured(
       nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
    )

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.prune.ln_structured。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。