本文簡要介紹python語言中 torch.svd
的用法。
用法:
torch.svd(input, some=True, compute_uv=True, *, out=None)
out(tuple,可選的) -張量的輸出元組
計算矩陣或矩陣批次
input
的奇異值分解。奇異值分解表示為命名元組(U, S, V)
,例如input
。其中 是實數輸入的V
的轉置,以及複雜輸入的V
的共軛轉置。如果input
是一批矩陣,則U
、S
和V
也使用與input
相同的批量維度進行批處理。如果
some
是True
(默認),則該方法返回簡化的奇異值分解。在這種情況下,如果input
的最後兩個維度是m
和n
,則返回的U
和V
矩陣將僅包含min(n, m)
正交列。如果
compute_uv
是False
,則返回的U
和V
將分別是形狀為(m, m)
和(n, n)
的 zero-filled 矩陣,並且與input
的設備相同。當compute_uv
為False
時,參數some
無效。支持
input
的 float、double、cfloat 和 cdouble 數據類型。U
和V
的 dtypes 與input
的相同。S
將始終為實值,即使input
很複雜。警告
torch.svd()
已棄用,取而代之的是torch.linalg.svd()
,並將在未來的 PyTorch 版本中刪除。U, S, V = torch.svd(A, some=some, compute_uv=True)
(默認)應替換為U, S, Vh = torch.linalg.svd(A, full_matrices=not some) V = Vh.transpose(-2, -1).conj()
_, S, _ = torch.svd(A, some=some, compute_uv=False)
應該替換為S = torch.svdvals(A)
注意
與
torch.linalg.svd()
的區別:some
與torch.linalg.svd()
的full_matrices
相反。請注意,兩者的默認值都是True
,因此默認行為實際上是相反的。torch.svd()
返回V
,而torch.linalg.svd()
返回Vh
,即 。如果
compute_uv
是False
,則torch.svd()
返回U
和Vh
的 zero-filled 張量,而torch.linalg.svd()
返回空張量。
注意
奇異值按降序返回。如果
input
是一批矩陣,則以降序返回該批中每個矩陣的奇異值。注意
如果
compute_uv
是True
,則S
張量隻能用於計算梯度。注意
當
some
為False
時,U[…, :, min(m, n):]
和V[…, :, min(m, n):]
上的梯度將在後向傳遞中被忽略,因為這些向量可以是相應子空間的任意基。注意
CPU 上
torch.linalg.svd()
的實現使用 LAPACK 的例程?gesdd
(一種分治算法)而不是?gesvd
以提高速度。類似地,在 GPU 上,它在 CUDA 10.1.243 及更高版本上使用 cuSOLVER 的例程gesvdj
和gesvdjBatched
,在早期版本的 CUDA 上使用 MAGMA 的例程gesdd
。注意
返回的
U
將不連續。矩陣(或矩陣批次)將表示為列主矩陣(即Fortran-contiguous)。警告
關於
U
和V
的梯度隻有在輸入不具有零或重複奇異值時才會是有限的。警告
如果任何兩個奇異值之間的距離接近於零,則相對於
U
和V
的梯度將在數值上不穩定,因為它們取決於 。當矩陣具有小的奇異值時也會發生同樣的情況,因為這些梯度也取決於S⁻¹
。警告
對於 complex-valued
input
,奇異值分解不是唯一的,因為U
和V
可以在每一列上乘以任意相位因子 。當input
具有重複的奇異值時也會發生同樣的情況,其中可以將U
和V
中的跨越子空間的列乘以旋轉矩陣和 the resulting vectors will span the same subspace 。不同的平台,如 NumPy,或不同設備類型上的輸入,可能會產生不同的U
和V
張量。例子:
>>> a = torch.randn(5, 3) >>> a tensor([[ 0.2364, -0.7752, 0.6372], [ 1.7201, 0.7394, -0.0504], [-0.3371, -1.0584, 0.5296], [ 0.3550, -0.4022, 1.5569], [ 0.2445, -0.0158, 1.1414]]) >>> u, s, v = torch.svd(a) >>> u tensor([[ 0.4027, 0.0287, 0.5434], [-0.1946, 0.8833, 0.3679], [ 0.4296, -0.2890, 0.5261], [ 0.6604, 0.2717, -0.2618], [ 0.4234, 0.2481, -0.4733]]) >>> s tensor([2.3289, 2.0315, 0.7806]) >>> v tensor([[-0.0199, 0.8766, 0.4809], [-0.5080, 0.4054, -0.7600], [ 0.8611, 0.2594, -0.4373]]) >>> torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t())) tensor(8.6531e-07) >>> a_big = torch.randn(7, 5, 3) >>> u, s, v = torch.svd(a_big) >>> torch.dist(a_big, torch.matmul(torch.matmul(u, torch.diag_embed(s)), v.transpose(-2, -1))) tensor(2.6503e-06)
參數:
關鍵字參數:
相關用法
- Python PyTorch svdvals用法及代碼示例
- Python PyTorch svd用法及代碼示例
- Python PyTorch saved_tensors_hooks用法及代碼示例
- Python PyTorch sqrt用法及代碼示例
- Python PyTorch skippable用法及代碼示例
- Python PyTorch squeeze用法及代碼示例
- Python PyTorch square用法及代碼示例
- Python PyTorch save_on_cpu用法及代碼示例
- Python PyTorch scatter_object_list用法及代碼示例
- Python PyTorch skip_init用法及代碼示例
- Python PyTorch simple_space_split用法及代碼示例
- Python PyTorch sum用法及代碼示例
- Python PyTorch sub用法及代碼示例
- Python PyTorch sparse_csr_tensor用法及代碼示例
- Python PyTorch sentencepiece_numericalizer用法及代碼示例
- Python PyTorch symeig用法及代碼示例
- Python PyTorch sinh用法及代碼示例
- Python PyTorch sinc用法及代碼示例
- Python PyTorch std_mean用法及代碼示例
- Python PyTorch spectral_norm用法及代碼示例
- Python PyTorch slogdet用法及代碼示例
- Python PyTorch symbolic_trace用法及代碼示例
- Python PyTorch shutdown用法及代碼示例
- Python PyTorch sgn用法及代碼示例
- Python PyTorch set_flush_denormal用法及代碼示例
注:本文由純淨天空篩選整理自pytorch.org大神的英文原創作品 torch.svd。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。