當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python PyTorch global_unstructured用法及代碼示例


本文簡要介紹python語言中 torch.nn.utils.prune.global_unstructured 的用法。

用法:

torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)

參數

  • parameters(可迭代的(模塊,名字)元組) -以全局方式修剪模型的參數,即在決定修剪哪些權重之前聚合所有權重。 module 必須是 nn.Module 類型,並且 name 必須是字符串。

  • pruning_method(函數) -此模塊中的有效修剪函數,或由用戶實現的滿足實施指南並具有 PRUNING_TYPE='unstructured' 的自定義函數。

  • importance_scores(dict) -字典映射(模塊,名稱)元組到相應參數的重要性分數張量。張量應該與參數的形狀相同,用於計算剪枝掩碼。如果未指定或無,將使用該參數代替其重要性分數。

  • kwargs-其他關鍵字參數,例如: amount(int 或 float):要在指定參數中修剪的參數數量。如果是 float ,則應介於 0.0 和 1.0 之間,並表示要修剪的參數的比例。如果是 int ,則表示要修剪的參數的絕對數量。

拋出

TypeError - 如果PRUNING_TYPE != 'unstructured'

通過應用指定的 pruning_method 全局修剪與 parameters 中的所有參數對應的張量。通過以下方式修改模塊:

  1. 添加一個名為 name+'_mask' 的命名緩衝區,該緩衝區對應於通過修剪方法應用於參數 name 的二進製掩碼。

  2. 將參數 name 替換為其修剪版本,而原始(未修剪)參數存儲在名為 name+'_orig' 的新參數中。

注意

由於全局結構化修剪沒有多大意義,除非規範被參數的大小標準化,我們現在將全局修剪的範圍限製為非結構化方法。

例子

>>> net = nn.Sequential(OrderedDict([
        ('first', nn.Linear(10, 4)),
        ('second', nn.Linear(4, 1)),
    ]))
>>> parameters_to_prune = (
        (net.first, 'weight'),
        (net.second, 'weight'),
    )
>>> prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=10,
    )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10, dtype=torch.uint8)

相關用法


注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.prune.global_unstructured。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。