當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.IndexedSlices用法及代碼示例


給定索引處的一組張量切片的稀疏表示。

用法

tf.IndexedSlices(
    values, indices, dense_shape=None
)

屬性

  • dense_shape 一維 Tensor 包含相應密集張量的形狀。
  • device 將在其上生成 values 的設備的名稱,或 None
  • dtype 該張量中元素的DType
  • graph Graph 包含值、索引和形狀張量。
  • indices 包含切片索引的一維Tensor
  • name IndexedSlices 的名稱。
  • op 生成values 作為輸出的Operation
  • shape 獲取表示密集張量形狀的tf.TensorShape
  • values 包含切片值的 Tensor

此類是一對 Tensor 對象的簡單包裝器:

  • values :具有形狀 [D0, D1, ..., Dn] 的任何 dtype 的 Tensor
  • indices :一維整數 Tensor 形狀為 [D0]

IndexedSlices 通常用於表示形狀為 [LARGE0, D1, .. , DN] 的較大張量 dense 的子集,其中 LARGE0 >> D0indices 中的值是從較大張量中提取的切片的第一維中的索引。

IndexedSlices slices 表示的稠密張量 dense 具有

dense[slices.indices[i],:,:,:, ...] = slices.values[i,:,:,:, ...]

IndexedSlices 類主要用於定義具有稀疏梯度的操作的梯度(例如 tf.gather )。

v = tf.Variable([[0.,1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8]])
with tf.GradientTape() as tape:
  r = tf.gather(v, [1,3])
index_slices = tape.gradient(r,v)
index_slices
<...IndexedSlices object ...>
index_slices.indices.numpy()
array([1, 3], dtype=int32)
index_slices.values.numpy()
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

將此表示與使用多維索引和標量值的 tf.sparse.SparseTensor 進行對比。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.IndexedSlices。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。