當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.raw_ops.BatchMatMulV3用法及代碼示例


將兩個張量的切片成批相乘。

用法

tf.raw_ops.BatchMatMulV3(
    x, y, Tout, adj_x=False, adj_y=False, name=None
)

參數

  • x 一個Tensor。必須是以下類型之一:bfloat16 , half , float32 , float64 , uint8 , int8 , int16 , int32 , int64 , complex64 , complex128。具有形狀 [..., r_x, c_x] 的二維或更高版本。
  • y 一個Tensor。必須是以下類型之一:bfloat16 , half , float32 , float64 , uint8 , int8 , int16 , int32 , int64 , complex64 , complex128。具有形狀 [..., r_y, c_y] 的二維或更高版本。
  • Tout A tf.dtypes.DType從:tf.bfloat16, tf.half, tf.float32, tf.float64, tf.int16, tf.int32, tf.int64, tf.complex64, tf.complex128.如果未指定,則 Tout 與輸入類型相同。
  • adj_x 可選的 bool 。默認為 False 。如果 True ,則與 x 的切片相鄰。默認為 False
  • adj_y 可選的 bool 。默認為 False 。如果 True ,則與 y 的切片相鄰。默認為 False
  • name 操作的名稱(可選)。

返回

  • Tensor 類型為 Tout

Tensor xy 的所有切片相乘(每個切片都可以視為批次的一個元素),並將各個結果排列在具有相同批次大小的單個輸出張量中。通過將 adj_xadj_y 標誌設置為 True ,默認情況下為 False ,每個單獨的切片都可以在乘法之前可選地連接(連接矩陣意味著轉置和共軛)。

輸入張量 xy 是二維或更高的,形狀為 [..., r_x, c_x][..., r_y, c_y]

輸出張量為二維或更高,形狀為 [..., r_o, c_o] ,其中:

r_o = c_x if adj_x else r_x
c_o = r_y if adj_y else c_y

它被計算為:

輸出[...,:,:] = 矩陣(x[...,:,:]) * 矩陣(y[...,:,:])

注意: BatchMatMulV3支持批量維度的廣播。更多關於廣播這裏.

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.raw_ops.BatchMatMulV3。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。