当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.Tensor.__getitem__用法及代码示例


用法

__getitem__(
    slice_spec, var=None
)

参数

  • tensor 一个 ops.Tensor 对象。
  • slice_spec 张量的参数。获取项目.
  • var 在变量切片赋值的情况下,切片的变量对象(即张量是这个变量的只读视图)。

返回

  • "tensor" 的适当切片,基于 "slice_spec"。

抛出

  • ValueError 如果切片范围为负大小。
  • TypeError 如果切片索引不是 int、切片、省略号、tf.newaxis 或标量 int32/int64 张量。

Tensor.getitem 的重载。

此操作从张量中提取指定区域。该表示法类似于 NumPy,但目前仅支持基本索引。这意味着当前不允许使用非标量张量作为输入。

一些有用的例子:

# Strip leading and trailing 2 elements
foo = tf.constant([1,2,3,4,5,6])
print(foo[2:-2].eval())  # => [3,4]

# Skip every other row and reverse the order of the columns
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[::2,::-1].eval())  # => [[3,2,1], [9,8,7]]

# Use scalar tensors as indices on both dimensions
print(foo[tf.constant(0), tf.constant(2)].eval())  # => 3

# Insert another dimension
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[tf.newaxis,:,:].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[:, tf.newaxis,:].eval()) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
print(foo[:,:, tf.newaxis].eval()) # => [[[1],[2],[3]], [[4],[5],[6]],
[[7],[8],[9]]]

# Ellipses (3 equivalent operations)
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[tf.newaxis,:,:].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[tf.newaxis, ...].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[tf.newaxis].eval())  # => [[[1,2,3], [4,5,6], [7,8,9]]]

# Masks
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[foo > 2].eval())  # => [3, 4, 5, 6, 7, 8, 9]

注意:

  • tf.newaxisNone,就像在 NumPy 中一样。
  • 隐式省略号放置在slice_spec 的末尾
  • 目前不支持 NumPy 高级索引。

API 中的用途:

此方法在 TensorFlow 的 API 中公开,因此库开发人员可以为 Tensor.getitem 注册调度,以允许它处理自定义复合张量和其他自定义对象。

API 符号不打算由用户直接调用,并且确实出现在 TensorFlow 生成的文档中。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.Tensor.__getitem__。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。