当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.raw_ops.BatchMatMulV2用法及代码示例


将两个张量的切片成批相乘。

用法

tf.raw_ops.BatchMatMulV2(
    x, y, adj_x=False, adj_y=False, name=None
)

参数

  • x 一个Tensor。必须是以下类型之一:bfloat16 , half , float32 , float64 , int16 , int32 , int64 , complex64 , complex128。具有形状 [..., r_x, c_x] 的二维或更高版本。
  • y 一个Tensor。必须与 x 具有相同的类型。二维或更高,形状为 [..., r_y, c_y]
  • adj_x 可选的 bool 。默认为 False 。如果 True ,则与 x 的切片相邻。默认为 False
  • adj_y 可选的 bool 。默认为 False 。如果 True ,则与 y 的切片相邻。默认为 False
  • name 操作的名称(可选)。

返回

  • 一个Tensor。具有与 x 相同的类型。

Tensor xy 的所有切片相乘(每个切片都可以视为批次的一个元素),并将各个结果排列在具有相同批次大小的单个输出张量中。通过将 adj_xadj_y 标志设置为 True ,默认情况下为 False ,每个单独的切片都可以在乘法之前可选地连接(连接矩阵意味着转置和共轭)。

输入张量 xy 是二维或更高的,形状为 [..., r_x, c_x][..., r_y, c_y]

输出张量为二维或更高,形状为 [..., r_o, c_o] ,其中:

r_o = c_x if adj_x else r_x
c_o = r_y if adj_y else c_y

它被计算为:

输出[...,:,:] = 矩阵(x[...,:,:]) * 矩阵(y[...,:,:])

注意: BatchMatMulV2支持批量维度的广播。更多关于广播这里.

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.raw_ops.BatchMatMulV2。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。