当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python PyTorch global_unstructured用法及代码示例


本文简要介绍python语言中 torch.nn.utils.prune.global_unstructured 的用法。

用法:

torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)

参数

  • parameters(可迭代的(模块,名字)元组) -以全局方式修剪模型的参数,即在决定修剪哪些权重之前聚合所有权重。 module 必须是 nn.Module 类型,并且 name 必须是字符串。

  • pruning_method(函数) -此模块中的有效修剪函数,或由用户实现的满足实施指南并具有 PRUNING_TYPE='unstructured' 的自定义函数。

  • importance_scores(dict) -字典映射(模块,名称)元组到相应参数的重要性分数张量。张量应该与参数的形状相同,用于计算剪枝掩码。如果未指定或无,将使用该参数代替其重要性分数。

  • kwargs-其他关键字参数,例如: amount(int 或 float):要在指定参数中修剪的参数数量。如果是 float ,则应介于 0.0 和 1.0 之间,并表示要修剪的参数的比例。如果是 int ,则表示要修剪的参数的绝对数量。

抛出

TypeError - 如果PRUNING_TYPE != 'unstructured'

通过应用指定的 pruning_method 全局修剪与 parameters 中的所有参数对应的张量。通过以下方式修改模块:

  1. 添加一个名为 name+'_mask' 的命名缓冲区,该缓冲区对应于通过修剪方法应用于参数 name 的二进制掩码。

  2. 将参数 name 替换为其修剪版本,而原始(未修剪)参数存储在名为 name+'_orig' 的新参数中。

注意

由于全局结构化修剪没有多大意义,除非规范被参数的大小标准化,我们现在将全局修剪的范围限制为非结构化方法。

例子

>>> net = nn.Sequential(OrderedDict([
        ('first', nn.Linear(10, 4)),
        ('second', nn.Linear(4, 1)),
    ]))
>>> parameters_to_prune = (
        (net.first, 'weight'),
        (net.second, 'weight'),
    )
>>> prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=10,
    )
>>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0))
tensor(10, dtype=torch.uint8)

相关用法


注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.prune.global_unstructured。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。