當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.where用法及代碼示例

返回非零元素的索引,或多路複用 xy

用法

tf.where(
    condition, x=None, y=None, name=None
)

參數

  • condition tf.Tensor dtype bool 或任何數字 dtype。當提供 xy 時,condition 必須具有 dtype bool
  • x 如果提供,一個與 y 類型相同的張量,並且具有可通過 conditiony 廣播的形狀。
  • y 如果提供,一個與 x 類型相同的張量,並且具有可通過 conditionx 廣播的形狀。
  • name 操作的名稱(可選)。

返回

  • 如果提供了 xy:與 xy 具有相同類型的 Tensor 以及從 condition , xy 廣播的形狀。否則,形狀為 [tf.math.count_nonzero(condition), tf.rank(condition)]Tensor

拋出

  • ValueError xy 中的一個恰好是 non-None 時,或者這些形狀並非都是可廣播的。

此操作有兩種模式:

  1. 返回非零元素的索引- 隻有當condition提供的結果是int64張量,其中每一行是非零元素的索引condition.結果的形狀是[tf.math.count_nonzero(condition), tf.rank(condition)].
  2. 複用xy- 當兩個xy提供結果的形狀為x,y, 和condition一起播出。結果取自x其中condition非零或y其中condition為零。

1. 返回非零元素的索引

注意:在這種模式下,condition 可以具有 bool 的 dtype 或任何數字 dtype。

如果未提供xy(均為無):

tf.where 將以形狀為 [n, d] 的二維張量的形式返回非零的 condition 的索引,其中 ncondition 中非零元素的數量(tf.count_nonzero(condition) ),dcondition ( tf.rank(condition) ) 的軸數。

索引以行優先順序輸出。 condition 可以有 dtypetf.bool 或任何數字 dtype

這裏 condition 是一個 1 軸 bool 張量,具有 2 個 True 值。結果的形狀為[2,1]

tf.where([True, False, False, True]).numpy()
array([[0],
       [3]])

這裏 condition 是一個 2 軸整數張量,具有 3 個非零值。結果的形狀為 [3, 2]

tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
array([[0, 0],
       [1, 0],
       [1, 2]])

這裏 condition 是一個 3 軸浮點張量,有 5 個非零值。輸出形狀為 [5, 3]

float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
                [[0,   0], [0,   0], [99,    0]]]
tf.where(float_tensor).numpy()
array([[0, 0, 0],
       [0, 1, 1],
       [0, 2, 0],
       [0, 2, 1],
       [1, 2, 0]])

這些索引與tf.sparse.SparseTensor 用來表示條件張量的索引相同:

sparse = tf.sparse.from_dense(float_tensor)
sparse.indices.numpy()
array([[0, 0, 0],
       [0, 1, 1],
       [0, 2, 0],
       [0, 2, 1],
       [1, 2, 0]])

如果實部或虛部中的任何一個非零,則認為複數非零:

tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
array([[1],
       [2],
       [3]])

2. 複用xy

注意:在這種模式下 condition 必須具有 bool 的數據類型。

如果還提供了 xy(兩者都具有非 None 值),則 condition 張量充當掩碼,選擇是否應從 x 中獲取輸出中的相應元素/行(如果元素在 conditionTrue )或 y (如果它是 False )。

結果的形狀是通過將 condition , xy 的形狀一起廣播形成的。

當所有三個輸入具有相同的大小時,每個都按元素處理。

tf.where([True, False, False, True],
         [1, 2, 3, 4],
         [100, 200, 300, 400]).numpy()
array([  1, 200, 300,   4], dtype=int32)

廣播有兩個主要規則:

  1. 如果張量的軸比其他張量少,則將長度為 1 的軸添加到形狀的左側。
  2. 長度為 1 的軸被拉伸以匹配其他張量的對應軸。

一個長度為 1 的向量被拉伸以匹配其他向量:

tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
array([  1, 100, 100,   4], dtype=int32)

擴展標量以匹配其他參數:

tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
array([[  1, 100], [100,   4]], dtype=int32)
tf.where([[True, False], [False, True]], 1, 100).numpy()
array([[  1, 100], [100,   1]], dtype=int32)

標量 condition 返回完整的 xy 張量,並應用了廣播。

tf.where(True, [1, 2, 3, 4], 100).numpy()
array([1, 2, 3, 4], dtype=int32)
tf.where(False, [1, 2, 3, 4], 100).numpy()
array([100, 100, 100, 100], dtype=int32)

對於廣播的一個重要示例,這裏 condition 的形狀為 [3] , x 的形狀為 [3,3] ,而 y 的形狀為 [3,1] 。廣播首先將 condition 的形狀擴展為 [1,3] 。最終的廣播形狀是 [3,3]condition 將從 xy 中選擇列。由於 y 隻有一列,因此 y 中的所有列都將相同。

tf.where([True, False, True],
         x=[[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]],
         y=[[100],
            [200],
            [300]]
).numpy()
array([[ 1, 100, 3],
       [ 4, 200, 6],
       [ 7, 300, 9]], dtype=int32)

請注意,如果 tf.where 的任一分支的梯度生成 NaN ,則整個 tf.where 的梯度將為 NaN 。這是因為出於性能原因,tf.where 的梯度計算結合了兩個分支。

一種解決方法是使用內部 tf.where 來確保函數沒有漸近線,並通過將危險輸入替換為安全輸入來避免計算梯度為 NaN 的值。

取而代之的是,

x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)

盡管從未使用過 1. / x 值,但當 x = 0 時,它的梯度是 NaN 。相反,我們應該用另一個tf.where來保護它

x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  safe_x = tf.where(tf.equal(x, 0.), 1., x)
  y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)

也可以看看:

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.where。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。