当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python mxnet.symbol.take用法及代码示例


用法:

mxnet.symbol.take(a=None, indices=None, axis=_Null, mode=_Null, name=None, attr=None, out=None, **kwargs)

参数

  • a(Symbol) - 输入数组。
  • indices(Symbol) - 要提取的值的索引。
  • axis(int, optional, default='0') - 要取的输入数组的轴。对于秩为 r 的输入张量,它可以在 [-r, r-1] 的范围内
  • mode({'clip', 'raise', 'wrap'},optional, default='clip') - 指定越界索引如何处理。默认为“clip”。 “clip” 表示剪辑到范围。因此,如果提到的所有索引都太大,则它们将替换为指向轴上最后一个元素的索引。 “wrap” 表示环绕。 “raise” 表示当索引超出范围时引发错误。
  • name(string, optional.) - 结果符号的名称。

返回

结果符号。

返回类型

Symbol

沿给定轴从输入数组中获取元素。

此函数使用提供的索引沿特定轴对输入数组进行切片。

给定等级 r >= 1 的数据张量和等级 q 的索引张量,收集由索引索引的数据轴维度的条目(默认情况下 outer-most 1 作为轴 = 0),并将它们连接到等级的输出张量中q + (r - 1)。

例子:

x = [4.  5.  6.]

// Trivial case, take the second element along the first axis.

take(x, [1]) = [ 5. ]

// The other trivial case, axis=-1, take the third element along the first axis

take(x, [3], axis=-1, mode='clip') = [ 6. ]

x = [[ 1.,  2.],
     [ 3.,  4.],
     [ 5.,  6.]]

// In this case we will get rows 0 and 1, then 1 and 2. Along axis 0

take(x, [[0,1],[1,2]]) = [[[ 1.,  2.],
                           [ 3.,  4.]],

                          [[ 3.,  4.],
                           [ 5.,  6.]]]

// In this case we will get rows 0 and 1, then 1 and 2 (calculated by wrapping around).
// Along axis 1

take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1.  2.]
                                                     [ 2.  1.]]

                                                    [[ 3.  4.]
                                                     [ 4.  3.]]

                                                    [[ 5.  6.]
                                                     [ 6.  5.]]]

take 输出的存储类型取决于输入存储类型:

  • take(default, default) = default

  • take(csr, default, axis=0) = csr

相关用法


注:本文由纯净天空筛选整理自apache.org大神的英文原创作品 mxnet.symbol.take。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。