當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.raw_ops.SparseMatrixSparseMatMul用法及代碼示例


Sparse-matrix-multiplies 兩個 CSR 矩陣 ab

用法

tf.raw_ops.SparseMatrixSparseMatMul(
    a, b, type, transpose_a=False, transpose_b=False, adjoint_a=False,
    adjoint_b=False, name=None
)

參數

  • a Tensor 類型為 variant 。一個 CSRSparseMatrix。
  • b Tensor 類型為 variant 。一個 CSRSparseMatrix。
  • type tf.DType 來自:tf.float32, tf.float64, tf.complex64, tf.complex128
  • transpose_a 可選的 bool 。默認為 False 。指示是否應轉置a
  • transpose_b 可選的 bool 。默認為 False 。指示是否應轉置b
  • adjoint_a 可選的 bool 。默認為 False 。指示a 是否應為conjugate-transposed。
  • adjoint_b 可選的 bool 。默認為 False 。指示b 是否應為conjugate-transposed。
  • name 操作的名稱(可選)。

返回

  • Tensor 類型為 variant

執行稀疏矩陣 a 與稀疏矩陣 b 的矩陣乘法;返回稀疏矩陣 a * b ,除非 ab 被轉置或鄰接。

每個矩陣可以根據布爾參數transpose_a , adjoint_a , transpose_badjoint_b 進行轉置或鄰接(共軛和轉置)。最多 transpose_aadjoint_a 之一可能為 True。類似地,最多transpose_badjoint_b之一可能為True。

輸入必須具有兼容的形狀。也就是說,a 的內部尺寸必須等於 b 的外部尺寸。此要求根據ab 是轉置還是鄰接進行調整。

type 參數表示矩陣元素的類型。 ab 必須具有相同的類型。支持的類型是:float32 , float64 , complex64complex128

ab 必須具有相同的等級。不支持廣播。如果它們的等級為 3,則 ab 中的每批 2D CSRSparseMatrices 必須具有相同的密集形狀。

稀疏矩陣乘積可能有數字(非結構)零。

零。

使用示例:

from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

    a_indices = np.array([[0, 0], [2, 3], [2, 4], [3, 0]])
    a_values = np.array([1.0, 5.0, -1.0, -2.0], np.float32)
    a_dense_shape = [4, 5]

    b_indices = np.array([[0, 0], [3, 0], [3, 1]])
    b_values = np.array([2.0, 7.0, 8.0], np.float32)
    b_dense_shape = [5, 3]

    with tf.Session() as sess:
      # Define (COO format) Sparse Tensors over Numpy arrays
      a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape)
      b_st = tf.sparse.SparseTensor(b_indices, b_values, b_dense_shape)

      # Convert SparseTensors to CSR SparseMatrix
      a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
          a_st.indices, a_st.values, a_st.dense_shape)
      b_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
          b_st.indices, b_st.values, b_st.dense_shape)

      # Compute the CSR SparseMatrix matrix multiplication
      c_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
          a=a_sm, b=b_sm, type=tf.float32)

      # Convert the CSR SparseMatrix product to a dense Tensor
      c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          c_sm, tf.float32)
      # Evaluate the dense Tensor value
      c_sm_dense_value = sess.run(c_sm_dense)

c_sm_dense_value 存儲密集矩陣乘積:

[[  2.   0.   0.]
     [  0.   0.   0.]
     [ 35.  40.   0.]
     [ -4.   0.   0.]]

一個:A CSRSparseMatrix 。 b:A CSRSparseMatrixa 具有相同類型和等級。 type: ab 的類型。 transpose_a:如果為真,a 在乘法前轉置。 transpose_b:如果為真,b 在乘法前轉置。 adjoint_a:如果為真,a 在乘法前伴隨。 adjoint_b:如果為真,b 在乘法之前伴隨。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.raw_ops.SparseMatrixSparseMatMul。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。