当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.raw_ops.SparseMatrixSparseMatMul用法及代码示例


Sparse-matrix-multiplies 两个 CSR 矩阵 ab

用法

tf.raw_ops.SparseMatrixSparseMatMul(
    a, b, type, transpose_a=False, transpose_b=False, adjoint_a=False,
    adjoint_b=False, name=None
)

参数

  • a Tensor 类型为 variant 。一个 CSRSparseMatrix。
  • b Tensor 类型为 variant 。一个 CSRSparseMatrix。
  • type tf.DType 来自:tf.float32, tf.float64, tf.complex64, tf.complex128
  • transpose_a 可选的 bool 。默认为 False 。指示是否应转置a
  • transpose_b 可选的 bool 。默认为 False 。指示是否应转置b
  • adjoint_a 可选的 bool 。默认为 False 。指示a 是否应为conjugate-transposed。
  • adjoint_b 可选的 bool 。默认为 False 。指示b 是否应为conjugate-transposed。
  • name 操作的名称(可选)。

返回

  • Tensor 类型为 variant

执行稀疏矩阵 a 与稀疏矩阵 b 的矩阵乘法;返回稀疏矩阵 a * b ,除非 ab 被转置或邻接。

每个矩阵可以根据布尔参数transpose_a , adjoint_a , transpose_badjoint_b 进行转置或邻接(共轭和转置)。最多 transpose_aadjoint_a 之一可能为 True。类似地,最多transpose_badjoint_b之一可能为True。

输入必须具有兼容的形状。也就是说,a 的内部尺寸必须等于 b 的外部尺寸。此要求根据ab 是转置还是邻接进行调整。

type 参数表示矩阵元素的类型。 ab 必须具有相同的类型。支持的类型是:float32 , float64 , complex64complex128

ab 必须具有相同的等级。不支持广播。如果它们的等级为 3,则 ab 中的每批 2D CSRSparseMatrices 必须具有相同的密集形状。

稀疏矩阵乘积可能有数字(非结构)零。

零。

使用示例:

from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

    a_indices = np.array([[0, 0], [2, 3], [2, 4], [3, 0]])
    a_values = np.array([1.0, 5.0, -1.0, -2.0], np.float32)
    a_dense_shape = [4, 5]

    b_indices = np.array([[0, 0], [3, 0], [3, 1]])
    b_values = np.array([2.0, 7.0, 8.0], np.float32)
    b_dense_shape = [5, 3]

    with tf.Session() as sess:
      # Define (COO format) Sparse Tensors over Numpy arrays
      a_st = tf.sparse.SparseTensor(a_indices, a_values, a_dense_shape)
      b_st = tf.sparse.SparseTensor(b_indices, b_values, b_dense_shape)

      # Convert SparseTensors to CSR SparseMatrix
      a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
          a_st.indices, a_st.values, a_st.dense_shape)
      b_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
          b_st.indices, b_st.values, b_st.dense_shape)

      # Compute the CSR SparseMatrix matrix multiplication
      c_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
          a=a_sm, b=b_sm, type=tf.float32)

      # Convert the CSR SparseMatrix product to a dense Tensor
      c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
          c_sm, tf.float32)
      # Evaluate the dense Tensor value
      c_sm_dense_value = sess.run(c_sm_dense)

c_sm_dense_value 存储密集矩阵乘积:

[[  2.   0.   0.]
     [  0.   0.   0.]
     [ 35.  40.   0.]
     [ -4.   0.   0.]]

一个:A CSRSparseMatrix 。 b:A CSRSparseMatrixa 具有相同类型和等级。 type: ab 的类型。 transpose_a:如果为真,a 在乘法前转置。 transpose_b:如果为真,b 在乘法前转置。 adjoint_a:如果为真,a 在乘法前伴随。 adjoint_b:如果为真,b 在乘法之前伴随。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.raw_ops.SparseMatrixSparseMatMul。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。