當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python mxnet.test_utils.rand_sparse_ndarray用法及代碼示例


用法:

mxnet.test_utils.rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None, data_init=None, rsp_indices=None, modifier_func=None, shuffle_csr_indices=False, ctx=None)

參數

  • shape(list or tuple) -
  • stype(str) - 有效值:“csr” 或 “row_sparse”
  • density(float, optional) - 應該在 0 和 1 之間
  • distribution(str, optional) - 有效值:“uniform” or “powerlaw”
  • dtype(numpy.dtype, optional) - 默認值為無

返回

返回類型

CSRNDArray 或 RowSparseNDArray 類型的結果

生成一個隨機稀疏 ndarray。返回 ndarray、value(np) 和 indices(np)

例子

下麵是一個以 csr 作為 stype 的冪律分布示例。它使用形狀和密度計算 nnz。它用呈 index 增長的元素數量填充 ndarray。如果有足夠的unused_nnzs,則第 n+1 行的 nnz 將是第 n 行的兩倍。否則,剩餘的unused_nnzs 將用於第 n+1 行如果列數太小並且我們已經達到列大小,它將填充所有後續行中的所有後續列,直到達到所需的密度。

>>> csr_arr, _ = rand_sparse_ndarray(shape=(5, 16), stype="csr",
                                     density=0.50, distribution="powerlaw")
>>> indptr = csr_arr.indptr.asnumpy()
>>> indices = csr_arr.indices.asnumpy()
>>> data = csr_arr.data.asnumpy()
>>> row2nnz = len(data[indptr[1]:indptr[2]])
>>> row3nnz = len(data[indptr[2]:indptr[3]])
>>> assert(row3nnz == 2*row2nnz)
>>> row4nnz = len(data[indptr[3]:indptr[4]])
>>> assert(row4nnz == 2*row3nnz)

相關用法


注:本文由純淨天空篩選整理自apache.org大神的英文原創作品 mxnet.test_utils.rand_sparse_ndarray。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。