本文簡要介紹python語言中 torch.nn.utils.parametrize.register_parametrization 的用法。
用法:
torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)module(torch.nn.Module) -注冊參數化的模塊
tensor_name(str) -在其上注冊參數化的參數或緩衝區的名稱
parametrization(torch.nn.Module) -要注冊的參數化
unsafe(bool) -一個布爾標誌,表示參數化是否可以改變張量的 dtype 和形狀。默認值:
False警告:注冊時不檢查參數化的一致性。啟用此標誌需要您自擔風險。ValueError - 如果模塊沒有參數或名為
tensor_name的緩衝區向模塊中的張量添加參數化。
為簡單起見,假設
tensor_name="weight"。訪問module.weight時,模塊將返回參數化版本parametrization(module.weight)。如果原始張量需要梯度,則反向傳播將通過parametrization進行微分,優化器將相應地更新張量。模塊第一次注冊參數化時,此函數將屬性
parametrizations添加到類型為ParametrizationList的模塊。張量
weight的參數化列表可在module.parametrizations.weight下訪問。可以在
module.parametrizations.weight.original下訪問原始張量。可以通過在同一屬性上注冊多個參數化來連接參數化。
注冊參數化的訓練模式在注冊時更新以匹配主機模塊的訓練模式
參數化參數和緩衝區有一個內置的緩存係統,可以使用上下文管理器
cached()激活。parametrization可以選擇實現帶有簽名的方法def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]當注冊第一個參數化以計算原始張量的初始值時,在未參數化的張量上調用此方法。如果不實現此方法,則原始張量將隻是未參數化的張量。
如果在張量上注冊的所有參數化都實現了
right_inverse,則可以通過分配給它來初始化參數化張量,如下例所示。第一個參數化可能取決於多個輸入。這可以通過從
right_inverse返回張量元組來實現(參見下麵的RankOne參數化的示例實現)。在這種情況下,不受約束的張量也位於
module.parametrizations.weight下,名稱為original0,original1,...注意
如果 unsafe=False(默認),forward 和 right_inverse 方法都將被調用一次以執行許多一致性檢查。如果 unsafe=True,則如果張量未參數化,則將調用 right_inverse,否則不會調用任何內容。
注意
在大多數情況下,
right_inverse將是一個函數,例如forward(right_inverse(X)) == X(參見 right inverse )。有時,當參數化不是滿射時,放寬它可能是合理的。警告
如果參數化取決於多個輸入,
register_parametrization()將注冊許多新參數。如果在創建優化器後注冊此類參數化,則需要手動將這些新參數添加到優化器中。請參閱torch.Optimizer.add_param_group()。例子
>>> import torch >>> import torch.nn as nn >>> import torch.nn.utils.parametrize as P >>> >>> class Symmetric(nn.Module): >>> def forward(self, X): >>> return X.triu() + X.triu(1).T # Return a symmetric matrix >>> >>> def right_inverse(self, A): >>> return A.triu() >>> >>> m = nn.Linear(5, 5) >>> P.register_parametrization(m, "weight", Symmetric()) >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric True >>> A = torch.rand(5, 5) >>> A = A + A.T # A is now symmetric >>> m.weight = A # Initialize the weight to be the symmetric matrix A >>> print(torch.allclose(m.weight, A)) True>>> class RankOne(nn.Module): >>> def forward(self, x, y): >>> # Form a rank 1 matrix multiplying two vectors >>> return x.unsqueeze(-1) @ y.unsqueeze(-2) >>> >>> def right_inverse(self, Z): >>> # Project Z onto the rank 1 matrices >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False) >>> # Return rescaled singular vectors >>> s0_sqrt = S[0].sqrt().unsqueeze(-1) >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt >>> >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) 1
參數:
關鍵字參數:
拋出:
相關用法
- Python PyTorch register_kl用法及代碼示例
- Python PyTorch register_module_forward_pre_hook用法及代碼示例
- Python PyTorch register_module_full_backward_hook用法及代碼示例
- Python PyTorch register_module_forward_hook用法及代碼示例
- Python PyTorch renorm用法及代碼示例
- Python PyTorch reshape用法及代碼示例
- Python PyTorch real用法及代碼示例
- Python PyTorch repeat_interleave用法及代碼示例
- Python PyTorch remove用法及代碼示例
- Python PyTorch read_vec_flt_ark用法及代碼示例
- Python PyTorch read_vec_int_ark用法及代碼示例
- Python PyTorch resolve_neg用法及代碼示例
- Python PyTorch remainder用法及代碼示例
- Python PyTorch remote用法及代碼示例
- Python PyTorch remove_spectral_norm用法及代碼示例
- Python PyTorch record用法及代碼示例
- Python PyTorch remove_weight_norm用法及代碼示例
- Python PyTorch retinanet_resnet50_fpn用法及代碼示例
- Python PyTorch read_vec_flt_scp用法及代碼示例
- Python PyTorch resolve_conj用法及代碼示例
- Python PyTorch reciprocal用法及代碼示例
- Python PyTorch result_type用法及代碼示例
- Python PyTorch replace_pattern用法及代碼示例
- Python PyTorch read_mat_scp用法及代碼示例
- Python PyTorch read_mat_ark用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.nn.utils.parametrize.register_parametrization。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。
