当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.IndexedSlices用法及代码示例


给定索引处的一组张量切片的稀疏表示。

用法

tf.IndexedSlices(
    values, indices, dense_shape=None
)

属性

  • dense_shape 一维 Tensor 包含相应密集张量的形状。
  • device 将在其上生成 values 的设备的名称,或 None
  • dtype 该张量中元素的DType
  • graph Graph 包含值、索引和形状张量。
  • indices 包含切片索引的一维Tensor
  • name IndexedSlices 的名称。
  • op 生成values 作为输出的Operation
  • shape 获取表示密集张量形状的tf.TensorShape
  • values 包含切片值的 Tensor

此类是一对 Tensor 对象的简单包装器:

  • values :具有形状 [D0, D1, ..., Dn] 的任何 dtype 的 Tensor
  • indices :一维整数 Tensor 形状为 [D0]

IndexedSlices 通常用于表示形状为 [LARGE0, D1, .. , DN] 的较大张量 dense 的子集,其中 LARGE0 >> D0indices 中的值是从较大张量中提取的切片的第一维中的索引。

IndexedSlices slices 表示的稠密张量 dense 具有

dense[slices.indices[i],:,:,:, ...] = slices.values[i,:,:,:, ...]

IndexedSlices 类主要用于定义具有稀疏梯度的操作的梯度(例如 tf.gather )。

v = tf.Variable([[0.,1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8]])
with tf.GradientTape() as tape:
  r = tf.gather(v, [1,3])
index_slices = tape.gradient(r,v)
index_slices
<...IndexedSlices object ...>
index_slices.indices.numpy()
array([1, 3], dtype=int32)
index_slices.values.numpy()
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

将此表示与使用多维索引和标量值的 tf.sparse.SparseTensor 进行对比。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.IndexedSlices。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。