本文简要介绍python语言中 torch.nn.utils.parametrize.register_parametrization 的用法。
用法:
torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)module(torch.nn.Module) -注册参数化的模块
tensor_name(str) -在其上注册参数化的参数或缓冲区的名称
parametrization(torch.nn.Module) -要注册的参数化
unsafe(bool) -一个布尔标志,表示参数化是否可以改变张量的 dtype 和形状。默认值:
False警告:注册时不检查参数化的一致性。启用此标志需要您自担风险。ValueError - 如果模块没有参数或名为
tensor_name的缓冲区向模块中的张量添加参数化。
为简单起见,假设
tensor_name="weight"。访问module.weight时,模块将返回参数化版本parametrization(module.weight)。如果原始张量需要梯度,则反向传播将通过parametrization进行微分,优化器将相应地更新张量。模块第一次注册参数化时,此函数将属性
parametrizations添加到类型为ParametrizationList的模块。张量
weight的参数化列表可在module.parametrizations.weight下访问。可以在
module.parametrizations.weight.original下访问原始张量。可以通过在同一属性上注册多个参数化来连接参数化。
注册参数化的训练模式在注册时更新以匹配主机模块的训练模式
参数化参数和缓冲区有一个内置的缓存系统,可以使用上下文管理器
cached()激活。parametrization可以选择实现带有签名的方法def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]当注册第一个参数化以计算原始张量的初始值时,在未参数化的张量上调用此方法。如果不实现此方法,则原始张量将只是未参数化的张量。
如果在张量上注册的所有参数化都实现了
right_inverse,则可以通过分配给它来初始化参数化张量,如下例所示。第一个参数化可能取决于多个输入。这可以通过从
right_inverse返回张量元组来实现(参见下面的RankOne参数化的示例实现)。在这种情况下,不受约束的张量也位于
module.parametrizations.weight下,名称为original0,original1,...注意
如果 unsafe=False(默认),forward 和 right_inverse 方法都将被调用一次以执行许多一致性检查。如果 unsafe=True,则如果张量未参数化,则将调用 right_inverse,否则不会调用任何内容。
注意
在大多数情况下,
right_inverse将是一个函数,例如forward(right_inverse(X)) == X(参见 right inverse )。有时,当参数化不是满射时,放宽它可能是合理的。警告
如果参数化取决于多个输入,
register_parametrization()将注册许多新参数。如果在创建优化器后注册此类参数化,则需要手动将这些新参数添加到优化器中。请参阅torch.Optimizer.add_param_group()。例子
>>> import torch >>> import torch.nn as nn >>> import torch.nn.utils.parametrize as P >>> >>> class Symmetric(nn.Module): >>> def forward(self, X): >>> return X.triu() + X.triu(1).T # Return a symmetric matrix >>> >>> def right_inverse(self, A): >>> return A.triu() >>> >>> m = nn.Linear(5, 5) >>> P.register_parametrization(m, "weight", Symmetric()) >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric True >>> A = torch.rand(5, 5) >>> A = A + A.T # A is now symmetric >>> m.weight = A # Initialize the weight to be the symmetric matrix A >>> print(torch.allclose(m.weight, A)) True>>> class RankOne(nn.Module): >>> def forward(self, x, y): >>> # Form a rank 1 matrix multiplying two vectors >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) >>> >>> def right_inverse(self, Z): >>> # Project Z onto the rank 1 matrices >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) >>> # Return rescaled singular vectors >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt >>> >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) 1
参数:
关键字参数:
抛出:
相关用法
- Python PyTorch register_kl用法及代码示例
- Python PyTorch register_module_forward_pre_hook用法及代码示例
- Python PyTorch register_module_full_backward_hook用法及代码示例
- Python PyTorch register_module_forward_hook用法及代码示例
- Python PyTorch renorm用法及代码示例
- Python PyTorch reshape用法及代码示例
- Python PyTorch real用法及代码示例
- Python PyTorch repeat_interleave用法及代码示例
- Python PyTorch remove用法及代码示例
- Python PyTorch read_vec_flt_ark用法及代码示例
- Python PyTorch read_vec_int_ark用法及代码示例
- Python PyTorch resolve_neg用法及代码示例
- Python PyTorch remainder用法及代码示例
- Python PyTorch remote用法及代码示例
- Python PyTorch remove_spectral_norm用法及代码示例
- Python PyTorch record用法及代码示例
- Python PyTorch remove_weight_norm用法及代码示例
- Python PyTorch retinanet_resnet50_fpn用法及代码示例
- Python PyTorch read_vec_flt_scp用法及代码示例
- Python PyTorch resolve_conj用法及代码示例
- Python PyTorch reciprocal用法及代码示例
- Python PyTorch result_type用法及代码示例
- Python PyTorch replace_pattern用法及代码示例
- Python PyTorch read_mat_scp用法及代码示例
- Python PyTorch read_mat_ark用法及代码示例
注:本文由纯净天空筛选整理自pytorch.org大神的英文原创作品 torch.nn.utils.parametrize.register_parametrization。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
