本文簡要介紹 python 語言中 scipy.stats.wasserstein_distance
的用法。
用法:
scipy.stats.wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None)#
計算兩個離散分布之間的 Wasserstein-1 距離。
Wasserstein 距離,也稱為地球移動器距離或最佳運輸距離,是兩個概率分布之間的相似性度量。在離散情況下,Wasserstein 距離可以理解為將一種分布轉換為另一種分布的最優傳輸計劃的成本。成本計算為移動的概率質量量與其移動距離的乘積。簡短直觀的介紹可以在[2]中找到。
- u_values: 一維或二維數組
來自概率分布或概率分布的支持(所有可能值的集合)的樣本。沿軸 0 的每個元素都是觀察值或可能值。如果是二維的,軸 1 表示分布的維數;即,每一行都是一個向量觀察值或可能值。
- v_values: 一維或二維數組
第二個發行版的樣本或支持。
- u_weights, v_weights: 一維數組,可選
與樣本相對應的權重或計數或與支持值相對應的概率質量。元素之和必須為正且有限。如果未指定,則為每個值分配相同的權重。
- distance: 浮點數
分布之間的計算距離。
參數 ::
返回 ::
注意:
給定兩個概率質量函數 和 ,分布之間的第一個 Wasserstein 距離為:
其中 是 的(概率)分布集,其邊際分別為第一個和第二個因子的 和 。對於給定值 , 給出 在位置 的概率,對於 也是如此。
在一維情況下,令 和 分別表示 和 的CDF,這個距離也等於:
有關這兩個定義的等效性的證明,請參閱[3]。
在更一般(更高維)和離散的情況下,它也稱為最優傳輸問題或蒙日問題。令有限點集 和 分別表示概率質量函數 和 的支持集。蒙日問題可以表示如下:
令 表示運輸計劃, 表示距離矩陣,
函數表示矢量化函數,通過垂直堆疊矩陣的列將矩陣轉換為列向量。傳送計劃 是矩陣 ,其中 是表示從 傳送到 的概率質量的量的正值。對 的行求和應給出源分布 : 對所有 均成立,對 的列求和應給出目標分布 : 對所有均成立 。距離矩陣 是矩陣 ,其中 。
給定 、 、 ,通過以 作為約束,以 作為最小化目標(成本之和),Monge 問題可以轉化為線性規劃問題,其中矩陣 有表格
通過求解上述線性規劃問題的對偶形式(解為 ),Wasserstein 距離 可以計算為 。
上述解決方案的靈感來自於 Vincent Herrmann 的博客 [5]。有關更徹底的解釋,請參閱 [4] 。
輸入分布可以是經驗的,因此來自其值是函數的有效輸入的樣本,或者它們可以被視為廣義函數,在這種情況下,它們是位於指定值的狄拉克 delta 函數的加權和。
參考:
[1]“Wasserstein metric”、https://en.wikipedia.org/wiki/Wasserstein_metric
[2]Lili Weng,“什麽是 Wasserstein 距離?”,Lil’log,https://lilianweng.github.io/posts/2017-08-20-gan/#what-is-wasserstein-distance。
[3]Ramdas, Garcia, Cuturi “關於 Wasserstein 的兩個樣本檢驗和相關的非參數檢驗係列”(2015 年)。 arXiv:1509.02237 。
[4]佩雷、加布裏埃爾和馬可·庫圖裏。 “計算最優運輸。”經濟與統計研究中心工作論文2017-86(2017)。
[5]赫爾曼、文森特. “Wasserstein GAN 和 Kantorovich-Rubinstein 對偶性”。https://vincentherrmann.github.io/blog/wasserstein/.
例子:
>>> from scipy.stats import wasserstein_distance >>> wasserstein_distance([0, 1, 3], [5, 6, 8]) 5.0 >>> wasserstein_distance([0, 1], [0, 1], [3, 1], [2, 2]) 0.25 >>> wasserstein_distance([3.4, 3.9, 7.5, 7.8], [4.5, 1.4], ... [1.4, 0.9, 3.1, 7.2], [3.2, 3.5]) 4.0781331438047861
計算兩個三維樣本之間的 Wasserstein 距離,每個樣本都有兩個觀測值。
>>> wasserstein_distance([[0, 2, 3], [1, 2, 5]], [[3, 2, 3], [4, 2, 5]]) 3.0
分別使用三個和兩個加權觀測值計算兩個二維分布之間的 Wasserstein 距離。
>>> wasserstein_distance([[0, 2.75], [2, 209.3], [0, 0]], ... [[0.2, 0.322], [4.5, 25.1808]], ... [0.4, 5.2, 0.114], [0.8, 1.5]) 174.15840245217169
相關用法
- Python SciPy stats.wald用法及代碼示例
- Python SciPy stats.wilcoxon用法及代碼示例
- Python SciPy stats.weightedtau用法及代碼示例
- Python SciPy stats.wrapcauchy用法及代碼示例
- Python SciPy stats.weibull_min用法及代碼示例
- Python SciPy stats.weibull_max用法及代碼示例
- Python SciPy stats.wishart用法及代碼示例
- Python SciPy stats.anderson用法及代碼示例
- Python SciPy stats.iqr用法及代碼示例
- Python SciPy stats.genpareto用法及代碼示例
- Python SciPy stats.skewnorm用法及代碼示例
- Python SciPy stats.cosine用法及代碼示例
- Python SciPy stats.norminvgauss用法及代碼示例
- Python SciPy stats.directional_stats用法及代碼示例
- Python SciPy stats.invwishart用法及代碼示例
- Python SciPy stats.bartlett用法及代碼示例
- Python SciPy stats.levy_stable用法及代碼示例
- Python SciPy stats.page_trend_test用法及代碼示例
- Python SciPy stats.itemfreq用法及代碼示例
- Python SciPy stats.exponpow用法及代碼示例
- Python SciPy stats.gumbel_l用法及代碼示例
- Python SciPy stats.chisquare用法及代碼示例
- Python SciPy stats.semicircular用法及代碼示例
- Python SciPy stats.gzscore用法及代碼示例
- Python SciPy stats.gompertz用法及代碼示例
注:本文由純淨天空篩選整理自scipy.org大神的英文原創作品 scipy.stats.wasserstein_distance。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。