当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.contrib.symbol.index_copy用法及代码示例


用法:

mxnet.contrib.symbol.index_copy(old_tensor=None, index_vector=None, new_tensor=None, name=None, attr=None, out=None, **kwargs)

参数

  • old_tensor(Symbol) - 旧张量
  • index_vector(Symbol) - 索引向量
  • new_tensor(Symbol) - 要复制的新张量
  • name(string, optional.) - 结果符号的名称。

返回

结果符号。

返回类型

Symbol

new_tensor 的元素复制到 old_tensor 中。

此运算符通过按 index 中给出的顺序选择索引来复制元素。输出将是一个新张量,其中包含旧张量的其余元素和新张量的复制元素。例如,如果 index[i] == j ,则将 new_tensor 的第 i 行复制到输出的第 j 行。

index 必须是一个向量,并且它必须与 new_tensor0 th 维度具有相同的大小。此外,old_tensor 的第 0 维度必须是 new_tensor>=0 维度,否则将引发错误。

例子:

x = mx.nd.zeros((5,3))
t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
index = mx.nd.array([0,4,2])

mx.nd.contrib.index_copy(x, index, t)

[[1. 2. 3.]
 [0. 0. 0.]
 [7. 8. 9.]
 [0. 0. 0.]
 [4. 5. 6.]]
<NDArray 5x3 @cpu(0)>

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.contrib.symbol.index_copy。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。