当前位置: 首页>>编程示例 >>用法及示例精选 >>正文


Python tf.einsum用法及代码示例

指定 index 和外积上的张量收缩。

用法

tf.einsum(
    equation, *inputs, **kwargs
)

参数

  • equation 说明收缩的 str,格式与 numpy.einsum 相同。
  • *inputs 合同的输入(每个都是 Tensor ),其形状应与 equation 一致。
  • **kwargs
    • 优化:使用opt_einsum查找收缩路径的优化策略。必须是 'greedy', 'optimal'、'branch-2'、'branch-all' 或 'auto'。 (可选,默认值:'greedy')。
    • name:操作的名称(可选)。

返回

  • 收缩的 Tensor ,形状由 equation 确定。

抛出

  • ValueError 如果
    • equation 的格式不正确,
    • 输入数量或其形状与 equation 不一致。

Einsum 允许通过定义其元素计算来定义张量。此计算由 equation 定义,这是一种基于爱因斯坦求和的简写形式。例如,考虑将两个矩阵 A 和 B 相乘以形成矩阵 C。C 的元素由下式给出:

或者

C[i,k] = sum_j A[i,j] * B[j,k]

对应的 einsum equation 为:

ij,jk->ik

通常,要将逐元素方程转换为 equation 字符串,请使用以下过程(括号中提供的矩阵乘法示例的中间字符串):

  1. 删除变量名、括号和逗号,(ik = sum_j ij * jk)
  2. 用","替换"*",(ik = sum_j ij , jk)
  3. 删除求和符号,和 (ik = ij, jk)
  4. 将输出向右移动,同时将 "=" 替换为 "->"。 (ij,jk->ik)

注意:如果未指定输出索引,则对重复索引求和。所以 ij,jk->ik 可以简化为 ij,jk

许多常见的操作都可以用这种方式表示。例如:

矩阵乘法

m0 = tf.random.normal(shape=[2, 3])
m1 = tf.random.normal(shape=[3, 5])
e = tf.einsum('ij,jk->ik', m0, m1)
# output[i,k] = sum_j m0[i,j] * m1[j, k]
print(e.shape)
(2, 5)

如果未指定输出索引,则对重复索引求和。

e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
print(e.shape)
(2, 5)

点积

u = tf.random.normal(shape=[5])
v = tf.random.normal(shape=[5])
e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
print(e.shape)
()

外积

u = tf.random.normal(shape=[3])
v = tf.random.normal(shape=[5])
e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
print(e.shape)
(3, 5)

转置

m = tf.ones(2,3)
e = tf.einsum('ij->ji', m0)  # output[j,i] = m0[i,j]
print(e.shape)
(3, 2)

诊断

m = tf.reshape(tf.range(9), [3,3])
diag = tf.einsum('ii->i', m)
print(diag.shape)
(3,)

追踪

# Repeated indices are summed.
trace = tf.einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]
assert trace == sum(diag)
print(trace.shape)
()

批量矩阵乘法

s = tf.random.normal(shape=[7,5,3])
t = tf.random.normal(shape=[7,3,2])
e = tf.einsum('bij,bjk->bik', s, t)
# output[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
print(e.shape)
(7, 5, 2)

此方法不支持named-axes 上的广播。具有匹配标签的所有轴应具有相同的长度。如果您有长度为 1 的轴,请使用 tf.squeseze tf.reshape 来消除它们。

要编写与输入中的索引数量无关的代码,请使用省略号。省略号是“此处适合的任何其他索引”的占位符。

例如,要执行 NumPy 样式的broadcasting-batch-matrix 乘法,其中矩阵乘法作用于输入的最后两个轴,请使用:

s = tf.random.normal(shape=[11, 7, 5, 3])
t = tf.random.normal(shape=[11, 7, 3, 2])
e =  tf.einsum('...ij,...jk->...ik', s, t)
print(e.shape)
(11, 7, 5, 2)

Einsum 将在省略号覆盖的轴上广播。

s = tf.random.normal(shape=[11, 1, 5, 3])
t = tf.random.normal(shape=[1, 7, 3, 2])
e =  tf.einsum('...ij,...jk->...ik', s, t)
print(e.shape)
(11, 7, 5, 2)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.einsum。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。