返回非零元素的索引,或多路复用 x
和 y
。
用法
tf.where(
condition, x=None, y=None, name=None
)
参数
-
condition
tf.Tensor
dtype bool 或任何数字 dtype。当提供x
和y
时,condition
必须具有 dtypebool
。 -
x
如果提供,一个与y
类型相同的张量,并且具有可通过condition
和y
广播的形状。 -
y
如果提供,一个与x
类型相同的张量,并且具有可通过condition
和x
广播的形状。 -
name
操作的名称(可选)。
返回
-
如果提供了
x
和y
:与x
和y
具有相同类型的Tensor
以及从condition
,x
和y
广播的形状。否则,形状为[tf.math.count_nonzero(condition), tf.rank(condition)]
的Tensor
。
抛出
-
ValueError
当x
或y
中的一个恰好是 non-None 时,或者这些形状并非都是可广播的。
此操作有两种模式:
- 返回非零元素的索引- 只有当
condition
提供的结果是int64
张量,其中每一行是非零元素的索引condition
.结果的形状是[tf.math.count_nonzero(condition), tf.rank(condition)]
. - 复用
x
和y
- 当两个x
和y
提供结果的形状为x
,y
, 和condition
一起播出。结果取自x
其中condition
非零或y
其中condition
为零。
1. 返回非零元素的索引
注意:在这种模式下,condition
可以具有 bool
的 dtype 或任何数字 dtype。
如果未提供x
和y
(均为无):
tf.where
将以形状为 [n, d]
的二维张量的形式返回非零的 condition
的索引,其中 n
是 condition
中非零元素的数量(tf.count_nonzero(condition)
),d
是condition
( tf.rank(condition)
) 的轴数。
索引以行优先顺序输出。 condition
可以有 dtype
或 tf.bool
或任何数字 dtype
。
这里 condition
是一个 1 轴 bool
张量,具有 2 个 True
值。结果的形状为[2,1]
tf.where([True, False, False, True]).numpy()
array([[0],
[3]])
这里 condition
是一个 2 轴整数张量,具有 3 个非零值。结果的形状为 [3, 2]
。
tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
array([[0, 0],
[1, 0],
[1, 2]])
这里 condition
是一个 3 轴浮点张量,有 5 个非零值。输出形状为 [5, 3]
。
float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
[[0, 0], [0, 0], [99, 0]]]
tf.where(float_tensor).numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
这些索引与tf.sparse.SparseTensor
用来表示条件张量的索引相同:
sparse = tf.sparse.from_dense(float_tensor)
sparse.indices.numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
如果实部或虚部中的任何一个非零,则认为复数非零:
tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
array([[1],
[2],
[3]])
2. 复用x
和y
注意:在这种模式下 condition
必须具有 bool
的数据类型。
如果还提供了 x
和 y
(两者都具有非 None 值),则 condition
张量充当掩码,选择是否应从 x
中获取输出中的相应元素/行(如果元素在 condition
是 True
)或 y
(如果它是 False
)。
结果的形状是通过将 condition
, x
和 y
的形状一起广播形成的。
当所有三个输入具有相同的大小时,每个都按元素处理。
tf.where([True, False, False, True],
[1, 2, 3, 4],
[100, 200, 300, 400]).numpy()
array([ 1, 200, 300, 4], dtype=int32)
广播有两个主要规则:
- 如果张量的轴比其他张量少,则将长度为 1 的轴添加到形状的左侧。
- 长度为 1 的轴被拉伸以匹配其他张量的对应轴。
一个长度为 1 的向量被拉伸以匹配其他向量:
tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
array([ 1, 100, 100, 4], dtype=int32)
扩展标量以匹配其他参数:
tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
array([[ 1, 100], [100, 4]], dtype=int32)
tf.where([[True, False], [False, True]], 1, 100).numpy()
array([[ 1, 100], [100, 1]], dtype=int32)
标量 condition
返回完整的 x
或 y
张量,并应用了广播。
tf.where(True, [1, 2, 3, 4], 100).numpy()
array([1, 2, 3, 4], dtype=int32)
tf.where(False, [1, 2, 3, 4], 100).numpy()
array([100, 100, 100, 100], dtype=int32)
对于广播的一个重要示例,这里 condition
的形状为 [3]
, x
的形状为 [3,3]
,而 y
的形状为 [3,1]
。广播首先将 condition
的形状扩展为 [1,3]
。最终的广播形状是 [3,3]
。 condition
将从 x
和 y
中选择列。由于 y
只有一列,因此 y
中的所有列都将相同。
tf.where([True, False, True],
x=[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
y=[[100],
[200],
[300]]
).numpy()
array([[ 1, 100, 3],
[ 4, 200, 6],
[ 7, 300, 9]], dtype=int32)
请注意,如果 tf.where
的任一分支的梯度生成 NaN
,则整个 tf.where
的梯度将为 NaN
。这是因为出于性能原因,tf.where
的梯度计算结合了两个分支。
一种解决方法是使用内部 tf.where
来确保函数没有渐近线,并通过将危险输入替换为安全输入来避免计算梯度为 NaN
的值。
取而代之的是,
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)
尽管从未使用过 1. / x
值,但当 x = 0
时,它的梯度是 NaN
。相反,我们应该用另一个tf.where
来保护它
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
safe_x = tf.where(tf.equal(x, 0.), 1., x)
y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)
也可以看看:
tf.sparse
- 第一种形式的tf.where
返回的索引在tf.sparse.SparseTensor
对象中很有用。tf.gather_nd
、tf.scatter_nd
和相关操作 - 给定从tf.where
返回的索引列表,可以使用scatter
和gather
系列操作在这些索引处获取值或插入值。tf.strings.length
-tf.string
不是condition
允许的 dtype。请改用字符串长度。
相关用法
- Python tf.while_loop用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.math.special.fresnel_cos用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.compat.v1.layers.conv3d用法及代码示例
- Python tf.Variable.__lt__用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.compat.v1.strings.length用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.where。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。