当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.test_utils.rand_sparse_ndarray用法及代码示例


用法:

mxnet.test_utils.rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None, data_init=None, rsp_indices=None, modifier_func=None, shuffle_csr_indices=False, ctx=None)

参数

  • shape(list or tuple) -
  • stype(str) - 有效值:“csr” 或 “row_sparse”
  • density(float, optional) - 应该在 0 和 1 之间
  • distribution(str, optional) - 有效值:“uniform” or “powerlaw”
  • dtype(numpy.dtype, optional) - 默认值为无

返回

返回类型

CSRNDArray 或 RowSparseNDArray 类型的结果

生成一个随机稀疏 ndarray。返回 ndarray、value(np) 和 indices(np)

例子

下面是一个以 csr 作为 stype 的幂律分布示例。它使用形状和密度计算 nnz。它用呈 index 增长的元素数量填充 ndarray。如果有足够的unused_nnzs,则第 n+1 行的 nnz 将是第 n 行的两倍。否则,剩余的unused_nnzs 将用于第 n+1 行如果列数太小并且我们已经达到列大小,它将填充所有后续行中的所有后续列,直到达到所需的密度。

>>> csr_arr, _ = rand_sparse_ndarray(shape=(5, 16), stype="csr",
                                     density=0.50, distribution="powerlaw")
>>> indptr = csr_arr.indptr.asnumpy()
>>> indices = csr_arr.indices.asnumpy()
>>> data = csr_arr.data.asnumpy()
>>> row2nnz = len(data[indptr[1]:indptr[2]])
>>> row3nnz = len(data[indptr[2]:indptr[3]])
>>> assert(row3nnz == 2*row2nnz)
>>> row4nnz = len(data[indptr[3]:indptr[4]])
>>> assert(row4nnz == 2*row3nnz)

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.test_utils.rand_sparse_ndarray。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。