返回非零元素的索引,或多路複用 x
和 y
。
用法
tf.where(
condition, x=None, y=None, name=None
)
參數
-
condition
tf.Tensor
dtype bool 或任何數字 dtype。當提供x
和y
時,condition
必須具有 dtypebool
。 -
x
如果提供,一個與y
類型相同的張量,並且具有可通過condition
和y
廣播的形狀。 -
y
如果提供,一個與x
類型相同的張量,並且具有可通過condition
和x
廣播的形狀。 -
name
操作的名稱(可選)。
返回
-
如果提供了
x
和y
:與x
和y
具有相同類型的Tensor
以及從condition
,x
和y
廣播的形狀。否則,形狀為[tf.math.count_nonzero(condition), tf.rank(condition)]
的Tensor
。
拋出
-
ValueError
當x
或y
中的一個恰好是 non-None 時,或者這些形狀並非都是可廣播的。
此操作有兩種模式:
- 返回非零元素的索引- 隻有當
condition
提供的結果是int64
張量,其中每一行是非零元素的索引condition
.結果的形狀是[tf.math.count_nonzero(condition), tf.rank(condition)]
. - 複用
x
和y
- 當兩個x
和y
提供結果的形狀為x
,y
, 和condition
一起播出。結果取自x
其中condition
非零或y
其中condition
為零。
1. 返回非零元素的索引
注意:在這種模式下,condition
可以具有 bool
的 dtype 或任何數字 dtype。
如果未提供x
和y
(均為無):
tf.where
將以形狀為 [n, d]
的二維張量的形式返回非零的 condition
的索引,其中 n
是 condition
中非零元素的數量(tf.count_nonzero(condition)
),d
是condition
( tf.rank(condition)
) 的軸數。
索引以行優先順序輸出。 condition
可以有 dtype
或 tf.bool
或任何數字 dtype
。
這裏 condition
是一個 1 軸 bool
張量,具有 2 個 True
值。結果的形狀為[2,1]
tf.where([True, False, False, True]).numpy()
array([[0],
[3]])
這裏 condition
是一個 2 軸整數張量,具有 3 個非零值。結果的形狀為 [3, 2]
。
tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
array([[0, 0],
[1, 0],
[1, 2]])
這裏 condition
是一個 3 軸浮點張量,有 5 個非零值。輸出形狀為 [5, 3]
。
float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
[[0, 0], [0, 0], [99, 0]]]
tf.where(float_tensor).numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
這些索引與tf.sparse.SparseTensor
用來表示條件張量的索引相同:
sparse = tf.sparse.from_dense(float_tensor)
sparse.indices.numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
如果實部或虛部中的任何一個非零,則認為複數非零:
tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
array([[1],
[2],
[3]])
2. 複用x
和y
注意:在這種模式下 condition
必須具有 bool
的數據類型。
如果還提供了 x
和 y
(兩者都具有非 None 值),則 condition
張量充當掩碼,選擇是否應從 x
中獲取輸出中的相應元素/行(如果元素在 condition
是 True
)或 y
(如果它是 False
)。
結果的形狀是通過將 condition
, x
和 y
的形狀一起廣播形成的。
當所有三個輸入具有相同的大小時,每個都按元素處理。
tf.where([True, False, False, True],
[1, 2, 3, 4],
[100, 200, 300, 400]).numpy()
array([ 1, 200, 300, 4], dtype=int32)
廣播有兩個主要規則:
- 如果張量的軸比其他張量少,則將長度為 1 的軸添加到形狀的左側。
- 長度為 1 的軸被拉伸以匹配其他張量的對應軸。
一個長度為 1 的向量被拉伸以匹配其他向量:
tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
array([ 1, 100, 100, 4], dtype=int32)
擴展標量以匹配其他參數:
tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
array([[ 1, 100], [100, 4]], dtype=int32)
tf.where([[True, False], [False, True]], 1, 100).numpy()
array([[ 1, 100], [100, 1]], dtype=int32)
標量 condition
返回完整的 x
或 y
張量,並應用了廣播。
tf.where(True, [1, 2, 3, 4], 100).numpy()
array([1, 2, 3, 4], dtype=int32)
tf.where(False, [1, 2, 3, 4], 100).numpy()
array([100, 100, 100, 100], dtype=int32)
對於廣播的一個重要示例,這裏 condition
的形狀為 [3]
, x
的形狀為 [3,3]
,而 y
的形狀為 [3,1]
。廣播首先將 condition
的形狀擴展為 [1,3]
。最終的廣播形狀是 [3,3]
。 condition
將從 x
和 y
中選擇列。由於 y
隻有一列,因此 y
中的所有列都將相同。
tf.where([True, False, True],
x=[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
y=[[100],
[200],
[300]]
).numpy()
array([[ 1, 100, 3],
[ 4, 200, 6],
[ 7, 300, 9]], dtype=int32)
請注意,如果 tf.where
的任一分支的梯度生成 NaN
,則整個 tf.where
的梯度將為 NaN
。這是因為出於性能原因,tf.where
的梯度計算結合了兩個分支。
一種解決方法是使用內部 tf.where
來確保函數沒有漸近線,並通過將危險輸入替換為安全輸入來避免計算梯度為 NaN
的值。
取而代之的是,
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)
盡管從未使用過 1. / x
值,但當 x = 0
時,它的梯度是 NaN
。相反,我們應該用另一個tf.where
來保護它
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
safe_x = tf.where(tf.equal(x, 0.), 1., x)
y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)
也可以看看:
tf.sparse
- 第一種形式的tf.where
返回的索引在tf.sparse.SparseTensor
對象中很有用。tf.gather_nd
、tf.scatter_nd
和相關操作 - 給定從tf.where
返回的索引列表,可以使用scatter
和gather
係列操作在這些索引處獲取值或插入值。tf.strings.length
-tf.string
不是condition
允許的 dtype。請改用字符串長度。
相關用法
- Python tf.while_loop用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代碼示例
- Python tf.summary.scalar用法及代碼示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代碼示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代碼示例
- Python tf.raw_ops.TPUReplicatedInput用法及代碼示例
- Python tf.raw_ops.Bitcast用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代碼示例
- Python tf.compat.v1.Variable.eval用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代碼示例
- Python tf.math.special.fresnel_cos用法及代碼示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代碼示例
- Python tf.compat.v1.layers.conv3d用法及代碼示例
- Python tf.Variable.__lt__用法及代碼示例
- Python tf.keras.metrics.Mean.merge_state用法及代碼示例
- Python tf.keras.layers.InputLayer用法及代碼示例
- Python tf.compat.v1.strings.length用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.where。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。