當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.einsum用法及代碼示例

指定 index 和外積上的張量收縮。

用法

tf.einsum(
    equation, *inputs, **kwargs
)

參數

  • equation 說明收縮的 str,格式與 numpy.einsum 相同。
  • *inputs 合同的輸入(每個都是 Tensor ),其形狀應與 equation 一致。
  • **kwargs
    • 優化:使用opt_einsum查找收縮路徑的優化策略。必須是 'greedy', 'optimal'、'branch-2'、'branch-all' 或 'auto'。 (可選,默認值:'greedy')。
    • name:操作的名稱(可選)。

返回

  • 收縮的 Tensor ,形狀由 equation 確定。

拋出

  • ValueError 如果
    • equation 的格式不正確,
    • 輸入數量或其形狀與 equation 不一致。

Einsum 允許通過定義其元素計算來定義張量。此計算由 equation 定義,這是一種基於愛因斯坦求和的簡寫形式。例如,考慮將兩個矩陣 A 和 B 相乘以形成矩陣 C。C 的元素由下式給出:

或者

C[i,k] = sum_j A[i,j] * B[j,k]

對應的 einsum equation 為:

ij,jk->ik

通常,要將逐元素方程轉換為 equation 字符串,請使用以下過程(括號中提供的矩陣乘法示例的中間字符串):

  1. 刪除變量名、括號和逗號,(ik = sum_j ij * jk)
  2. 用","替換"*",(ik = sum_j ij , jk)
  3. 刪除求和符號,和 (ik = ij, jk)
  4. 將輸出向右移動,同時將 "=" 替換為 "->"。 (ij,jk->ik)

注意:如果未指定輸出索引,則對重複索引求和。所以 ij,jk->ik 可以簡化為 ij,jk

許多常見的操作都可以用這種方式表示。例如:

矩陣乘法

m0 = tf.random.normal(shape=[2, 3])
m1 = tf.random.normal(shape=[3, 5])
e = tf.einsum('ij,jk->ik', m0, m1)
# output[i,k] = sum_j m0[i,j] * m1[j, k]
print(e.shape)
(2, 5)

如果未指定輸出索引,則對重複索引求和。

e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
print(e.shape)
(2, 5)

點積

u = tf.random.normal(shape=[5])
v = tf.random.normal(shape=[5])
e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
print(e.shape)
()

外積

u = tf.random.normal(shape=[3])
v = tf.random.normal(shape=[5])
e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
print(e.shape)
(3, 5)

轉置

m = tf.ones(2,3)
e = tf.einsum('ij->ji', m0)  # output[j,i] = m0[i,j]
print(e.shape)
(3, 2)

診斷

m = tf.reshape(tf.range(9), [3,3])
diag = tf.einsum('ii->i', m)
print(diag.shape)
(3,)

追蹤

# Repeated indices are summed.
trace = tf.einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]
assert trace == sum(diag)
print(trace.shape)
()

批量矩陣乘法

s = tf.random.normal(shape=[7,5,3])
t = tf.random.normal(shape=[7,3,2])
e = tf.einsum('bij,bjk->bik', s, t)
# output[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
print(e.shape)
(7, 5, 2)

此方法不支持named-axes 上的廣播。具有匹配標簽的所有軸應具有相同的長度。如果您有長度為 1 的軸,請使用 tf.squeseze tf.reshape 來消除它們。

要編寫與輸入中的索引數量無關的代碼,請使用省略號。省略號是“此處適合的任何其他索引”的占位符。

例如,要執行 NumPy 樣式的broadcasting-batch-matrix 乘法,其中矩陣乘法作用於輸入的最後兩個軸,請使用:

s = tf.random.normal(shape=[11, 7, 5, 3])
t = tf.random.normal(shape=[11, 7, 3, 2])
e =  tf.einsum('...ij,...jk->...ik', s, t)
print(e.shape)
(11, 7, 5, 2)

Einsum 將在省略號覆蓋的軸上廣播。

s = tf.random.normal(shape=[11, 1, 5, 3])
t = tf.random.normal(shape=[1, 7, 3, 2])
e =  tf.einsum('...ij,...jk->...ik', s, t)
print(e.shape)
(11, 7, 5, 2)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.einsum。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。