本文简要介绍python语言中 torch.nn.utils.prune.global_unstructured
的用法。
用法:
torch.nn.utils.prune.global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs)
parameters(可迭代的(模块,名字)元组) -以全局方式修剪模型的参数,即在决定修剪哪些权重之前聚合所有权重。 module 必须是
nn.Module
类型,并且 name 必须是字符串。pruning_method(函数) -此模块中的有效修剪函数,或由用户实现的满足实施指南并具有
PRUNING_TYPE='unstructured'
的自定义函数。importance_scores(dict) -字典映射(模块,名称)元组到相应参数的重要性分数张量。张量应该与参数的形状相同,用于计算剪枝掩码。如果未指定或无,将使用该参数代替其重要性分数。
kwargs-其他关键字参数,例如: amount(int 或 float):要在指定参数中修剪的参数数量。如果是
float
,则应介于 0.0 和 1.0 之间,并表示要修剪的参数的比例。如果是int
,则表示要修剪的参数的绝对数量。
TypeError - 如果
PRUNING_TYPE != 'unstructured'
通过应用指定的
pruning_method
全局修剪与parameters
中的所有参数对应的张量。通过以下方式修改模块:添加一个名为
name+'_mask'
的命名缓冲区,该缓冲区对应于通过修剪方法应用于参数name
的二进制掩码。将参数
name
替换为其修剪版本,而原始(未修剪)参数存储在名为name+'_orig'
的新参数中。
注意
由于全局结构化修剪没有多大意义,除非规范被参数的大小标准化,我们现在将全局修剪的范围限制为非结构化方法。
例子
>>> net = nn.Sequential(OrderedDict([ ('first', nn.Linear(10, 4)), ('second', nn.Linear(4, 1)), ])) >>> parameters_to_prune = ( (net.first, 'weight'), (net.second, 'weight'), ) >>> prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=10, ) >>> print(sum(torch.nn.utils.parameters_to_vector(net.buffers()) == 0)) tensor(10, dtype=torch.uint8)
参数:
抛出:
相关用法
- Python PyTorch gumbel_softmax用法及代码示例
- Python PyTorch get_tokenizer用法及代码示例
- Python PyTorch gammainc用法及代码示例
- Python PyTorch gradient用法及代码示例
- Python PyTorch gammaincc用法及代码示例
- Python PyTorch greedy_partition用法及代码示例
- Python PyTorch gammaln用法及代码示例
- Python PyTorch get_gradients用法及代码示例
- Python PyTorch get_ignored_functions用法及代码示例
- Python PyTorch get_default_dtype用法及代码示例
- Python PyTorch gt用法及代码示例
- Python PyTorch gather用法及代码示例
- Python PyTorch gcd用法及代码示例
- Python PyTorch get_graph_node_names用法及代码示例
- Python PyTorch get_testing_overrides用法及代码示例
- Python PyTorch generate_sp_model用法及代码示例
- Python PyTorch gather_object用法及代码示例
- Python PyTorch ge用法及代码示例
- Python PyTorch frexp用法及代码示例
- Python PyTorch jvp用法及代码示例
- Python PyTorch cholesky用法及代码示例
- Python PyTorch vdot用法及代码示例
- Python PyTorch ELU用法及代码示例
- Python PyTorch ScaledDotProduct.__init__用法及代码示例
- Python PyTorch saved_tensors_hooks用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.prune.global_unstructured。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。