狄利克雷分布。
繼承自:Distribution
用法
tf.compat.v1.distributions.Dirichlet(
concentration, validate_args=False, allow_nan_stats=True,
name='Dirichlet'
)
參數
-
concentration
正浮點Tensor
表示類的平均出現次數;又名"alpha"。暗示self.dtype
和self.batch_shape
,self.event_shape
,即如果concentration.shape = [N1, N2, ..., Nm, k]
則batch_shape = [N1, N2, ..., Nm]
和event_shape = [k]
。 -
validate_args
Pythonbool
,默認False
。盡管可能會降低運行時性能,但檢查True
分發參數的有效性時。當False
無效輸入可能會默默呈現不正確的輸出。 -
allow_nan_stats
Pythonbool
,默認True
。當True
時,統計信息(例如,均值、眾數、方差)使用值“NaN
”來指示結果未定義。當False
時,如果一個或多個統計數據的批處理成員未定義,則會引發異常。 -
name
Pythonstr
名稱以此類創建的 Ops 為前綴。
屬性
-
allow_nan_stats
Pythonbool
說明未定義統計信息時的行為。統計數據在有意義時返回 +/- 無窮大。例如,柯西分布的方差是無窮大的。但是,有時統計數據是未定義的,例如,如果分布的 pdf 在分布的支持範圍內沒有達到最大值,則模式是未定義的。如果均值未定義,則根據定義,方差未定義。例如: df = 1 的 Student's T 的平均值是未定義的(沒有明確的方式說它是 + 或 - 無窮大),因此方差 = E[(X - mean)**2] 也是未定義的。
-
batch_shape
來自單個事件索引的單個樣本的形狀作為TensorShape
.可能部分定義或未知。
批次維度是該分布的獨立、不同參數化的索引。
-
concentration
濃度參數;該坐標的預期計數。 -
dtype
Tensor
的DType
由此Distribution
處理。 -
event_shape
單個批次的單個樣品的形狀作為TensorShape
.可能部分定義或未知。
-
name
此Distribution
創建的所有操作前的名稱。 -
parameters
用於實例化此Distribution
的參數字典。 -
reparameterization_type
說明如何重新參數化分布中的樣本。目前這是靜態實例
distributions.FULLY_REPARAMETERIZED
或distributions.NOT_REPARAMETERIZED
之一。 -
total_concentration
last dim of 濃度參數的總和。 -
validate_args
Pythonbool
表示啟用了可能昂貴的檢查。
狄利克雷分布是在 (k-1)-單純形上定義的,使用正的、長度為 k
的向量 concentration
(k > 1
)。 Dirichlet 與 k = 2
時的 Beta 分布相同。
數學細節
Dirichlet 是開放的 (k-1)
-simplex 上的分布,即
S^{k-1} = { (x_0, ..., x_{k-1}) in R^k:sum_j x_j = 1 and all_j x_j > 0 }.
概率密度函數 (pdf) 是,
pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)
其中:
x in S^{k-1}
,即(k-1)
-simplex,concentration = alpha = [alpha_0, ..., alpha_{k-1}]
,alpha_j > 0
,Z
是歸一化常數,也就是多元 beta 函數,並且,Gamma
是伽瑪函數。
concentration
表示類出現的平均總計數,即
concentration = alpha = mean * total_concentration
其中 S^{k-1}
和 total_concentration
中的 mean
是一個正實數,表示平均總計數。
分布參數在所有函數中自動廣播;有關詳細信息,請參見示例。
警告:由於有限的精度,樣本的某些分量可能為零。當某些濃度非常小時,這種情況會更頻繁地發生。確保在計算密度之前將樣本四舍五入到np.finfo(dtype).tiny
。
該分布的樣本被重新參數化(路徑可微)。導數是使用 (Figurnov et al., 2018) 中說明的方法計算的。
例子
import tensorflow_probability as tfp
tfd = tfp.distributions
# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape:[4, 5, 3]
# x has one sample, one batch, three classes:
x = [.2, .3, .5] # shape:[3]
dist.prob(x) # shape:[]
# x has two samples from one batch:
x = [[.1, .4, .5],
[.2, .3, .5]]
dist.prob(x) # shape:[2]
# alpha will be broadcast to shape [5, 7, 3] to match x.
x = [[...]] # shape:[5, 7, 3]
dist.prob(x) # shape:[5, 7]
# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
[4, 5, 6]] # shape:[2, 3]
dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape:[4, 5, 2, 3]
x = [.2, .3, .5]
# x will be broadcast as [[.2, .3, .5],
# [.2, .3, .5]],
# thus matching batch_shape [2, 3].
dist.prob(x) # shape:[2]
計算樣本的梯度 w.r.t.參數:
alpha = tf.constant([1.0, 2.0, 3.0])
dist = tfd.Dirichlet(alpha)
samples = dist.sample(5) # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
grads = tf.gradients(loss, alpha)
參考:
隱式重新參數化梯度:Figurnov 等人,2018 (pdf)
相關用法
- Python tf.compat.v1.distributions.Dirichlet.covariance用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.stddev用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.kl_divergence用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.cross_entropy用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.log_survival_function用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.survival_function用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.quantile用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.cdf用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.survival_function用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.log_cdf用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.log_survival_function用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.log_cdf用法及代碼示例
- Python tf.compat.v1.distributions.Dirichlet.variance用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.covariance用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.cdf用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.variance用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.kl_divergence用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.stddev用法及代碼示例
- Python tf.compat.v1.distributions.DirichletMultinomial.cross_entropy用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.distributions.Dirichlet。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。