返回非零元素的索引,或多路复用 x 和 y 。
用法
tf.where(
condition, x=None, y=None, name=None
)参数
-
conditiontf.Tensordtype bool 或任何数字 dtype。当提供x和y时,condition必须具有 dtypebool。 -
x如果提供,一个与y类型相同的张量,并且具有可通过condition和y广播的形状。 -
y如果提供,一个与x类型相同的张量,并且具有可通过condition和x广播的形状。 -
name操作的名称(可选)。
返回
-
如果提供了
x和y:与x和y具有相同类型的Tensor以及从condition,x和y广播的形状。否则,形状为[tf.math.count_nonzero(condition), tf.rank(condition)]的Tensor。
抛出
-
ValueError当x或y中的一个恰好是 non-None 时,或者这些形状并非都是可广播的。
此操作有两种模式:
- 返回非零元素的索引- 只有当
condition提供的结果是int64张量,其中每一行是非零元素的索引condition.结果的形状是[tf.math.count_nonzero(condition), tf.rank(condition)]. - 复用
x和y- 当两个x和y提供结果的形状为x,y, 和condition一起播出。结果取自x其中condition非零或y其中condition为零。
1. 返回非零元素的索引
注意:在这种模式下,condition 可以具有 bool 的 dtype 或任何数字 dtype。
如果未提供x 和y(均为无):
tf.where 将以形状为 [n, d] 的二维张量的形式返回非零的 condition 的索引,其中 n 是 condition 中非零元素的数量(tf.count_nonzero(condition) ),d 是condition ( tf.rank(condition) ) 的轴数。
索引以行优先顺序输出。 condition 可以有 dtype 或 tf.bool 或任何数字 dtype 。
这里 condition 是一个 1 轴 bool 张量,具有 2 个 True 值。结果的形状为[2,1]
tf.where([True, False, False, True]).numpy()
array([[0],
[3]])
这里 condition 是一个 2 轴整数张量,具有 3 个非零值。结果的形状为 [3, 2] 。
tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
array([[0, 0],
[1, 0],
[1, 2]])
这里 condition 是一个 3 轴浮点张量,有 5 个非零值。输出形状为 [5, 3] 。
float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
[[0, 0], [0, 0], [99, 0]]]
tf.where(float_tensor).numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
这些索引与tf.sparse.SparseTensor 用来表示条件张量的索引相同:
sparse = tf.sparse.from_dense(float_tensor)
sparse.indices.numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
如果实部或虚部中的任何一个非零,则认为复数非零:
tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
array([[1],
[2],
[3]])
2. 复用x和y
注意:在这种模式下 condition 必须具有 bool 的数据类型。
如果还提供了 x 和 y(两者都具有非 None 值),则 condition 张量充当掩码,选择是否应从 x 中获取输出中的相应元素/行(如果元素在 condition 是 True )或 y (如果它是 False )。
结果的形状是通过将 condition , x 和 y 的形状一起广播形成的。
当所有三个输入具有相同的大小时,每个都按元素处理。
tf.where([True, False, False, True],
[1, 2, 3, 4],
[100, 200, 300, 400]).numpy()
array([ 1, 200, 300, 4], dtype=int32)
广播有两个主要规则:
- 如果张量的轴比其他张量少,则将长度为 1 的轴添加到形状的左侧。
- 长度为 1 的轴被拉伸以匹配其他张量的对应轴。
一个长度为 1 的向量被拉伸以匹配其他向量:
tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
array([ 1, 100, 100, 4], dtype=int32)
扩展标量以匹配其他参数:
tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
array([[ 1, 100], [100, 4]], dtype=int32)
tf.where([[True, False], [False, True]], 1, 100).numpy()
array([[ 1, 100], [100, 1]], dtype=int32)
标量 condition 返回完整的 x 或 y 张量,并应用了广播。
tf.where(True, [1, 2, 3, 4], 100).numpy()
array([1, 2, 3, 4], dtype=int32)
tf.where(False, [1, 2, 3, 4], 100).numpy()
array([100, 100, 100, 100], dtype=int32)
对于广播的一个重要示例,这里 condition 的形状为 [3] , x 的形状为 [3,3] ,而 y 的形状为 [3,1] 。广播首先将 condition 的形状扩展为 [1,3] 。最终的广播形状是 [3,3] 。 condition 将从 x 和 y 中选择列。由于 y 只有一列,因此 y 中的所有列都将相同。
tf.where([True, False, True],
x=[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
y=[[100],
[200],
[300]]
).numpy()
array([[ 1, 100, 3],
[ 4, 200, 6],
[ 7, 300, 9]], dtype=int32)
请注意,如果 tf.where 的任一分支的梯度生成 NaN ,则整个 tf.where 的梯度将为 NaN 。这是因为出于性能原因,tf.where 的梯度计算结合了两个分支。
一种解决方法是使用内部 tf.where 来确保函数没有渐近线,并通过将危险输入替换为安全输入来避免计算梯度为 NaN 的值。
取而代之的是,
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)
尽管从未使用过 1. / x 值,但当 x = 0 时,它的梯度是 NaN 。相反,我们应该用另一个tf.where来保护它
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
safe_x = tf.where(tf.equal(x, 0.), 1., x)
y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)
也可以看看:
tf.sparse- 第一种形式的tf.where返回的索引在tf.sparse.SparseTensor对象中很有用。tf.gather_nd、tf.scatter_nd和相关操作 - 给定从tf.where返回的索引列表,可以使用scatter和gather系列操作在这些索引处获取值或插入值。tf.strings.length-tf.string不是condition允许的 dtype。请改用字符串长度。
相关用法
- Python tf.while_loop用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
- Python tf.distribute.OneDeviceStrategy.experimental_distribute_values_from_function用法及代码示例
- Python tf.math.special.fresnel_cos用法及代码示例
- Python tf.keras.applications.inception_resnet_v2.preprocess_input用法及代码示例
- Python tf.compat.v1.layers.conv3d用法及代码示例
- Python tf.Variable.__lt__用法及代码示例
- Python tf.keras.metrics.Mean.merge_state用法及代码示例
- Python tf.keras.layers.InputLayer用法及代码示例
- Python tf.compat.v1.strings.length用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.where。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
