当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.where用法及代码示例


返回非零元素的索引,或多路复用 xy

用法

tf.where(
    condition, x=None, y=None, name=None
)

参数

  • condition tf.Tensor dtype bool 或任何数字 dtype。当提供 xy 时,condition 必须具有 dtype bool
  • x 如果提供,一个与 y 类型相同的张量,并且具有可通过 conditiony 广播的形状。
  • y 如果提供,一个与 x 类型相同的张量,并且具有可通过 conditionx 广播的形状。
  • name 操作的名称(可选)。

返回

  • 如果提供了 xy:与 xy 具有相同类型的 Tensor 以及从 condition , xy 广播的形状。否则,形状为 [tf.math.count_nonzero(condition), tf.rank(condition)]Tensor

抛出

  • ValueError xy 中的一个恰好是 non-None 时,或者这些形状并非都是可广播的。

此操作有两种模式:

  1. 返回非零元素的索引- 只有当condition提供的结果是int64张量,其中每一行是非零元素的索引condition.结果的形状是[tf.math.count_nonzero(condition), tf.rank(condition)].
  2. 复用xy- 当两个xy提供结果的形状为x,y, 和condition一起播出。结果取自x其中condition非零或y其中condition为零。

1. 返回非零元素的索引

注意:在这种模式下,condition 可以具有 bool 的 dtype 或任何数字 dtype。

如果未提供xy(均为无):

tf.where 将以形状为 [n, d] 的二维张量的形式返回非零的 condition 的索引,其中 ncondition 中非零元素的数量(tf.count_nonzero(condition) ),dcondition ( tf.rank(condition) ) 的轴数。

索引以行优先顺序输出。 condition 可以有 dtypetf.bool 或任何数字 dtype

这里 condition 是一个 1 轴 bool 张量,具有 2 个 True 值。结果的形状为[2,1]

tf.where([True, False, False, True]).numpy()
array([[0],
       [3]])

这里 condition 是一个 2 轴整数张量,具有 3 个非零值。结果的形状为 [3, 2]

tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
array([[0, 0],
       [1, 0],
       [1, 2]])

这里 condition 是一个 3 轴浮点张量,有 5 个非零值。输出形状为 [5, 3]

float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
                [[0,   0], [0,   0], [99,    0]]]
tf.where(float_tensor).numpy()
array([[0, 0, 0],
       [0, 1, 1],
       [0, 2, 0],
       [0, 2, 1],
       [1, 2, 0]])

这些索引与tf.sparse.SparseTensor 用来表示条件张量的索引相同:

sparse = tf.sparse.from_dense(float_tensor)
sparse.indices.numpy()
array([[0, 0, 0],
       [0, 1, 1],
       [0, 2, 0],
       [0, 2, 1],
       [1, 2, 0]])

如果实部或虚部中的任何一个非零,则认为复数非零:

tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
array([[1],
       [2],
       [3]])

2. 复用xy

注意:在这种模式下 condition 必须具有 bool 的数据类型。

如果还提供了 xy(两者都具有非 None 值),则 condition 张量充当掩码,选择是否应从 x 中获取输出中的相应元素/行(如果元素在 conditionTrue )或 y (如果它是 False )。

结果的形状是通过将 condition , xy 的形状一起广播形成的。

当所有三个输入具有相同的大小时,每个都按元素处理。

tf.where([True, False, False, True],
         [1, 2, 3, 4],
         [100, 200, 300, 400]).numpy()
array([  1, 200, 300,   4], dtype=int32)

广播有两个主要规则:

  1. 如果张量的轴比其他张量少,则将长度为 1 的轴添加到形状的左侧。
  2. 长度为 1 的轴被拉伸以匹配其他张量的对应轴。

一个长度为 1 的向量被拉伸以匹配其他向量:

tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
array([  1, 100, 100,   4], dtype=int32)

扩展标量以匹配其他参数:

tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
array([[  1, 100], [100,   4]], dtype=int32)
tf.where([[True, False], [False, True]], 1, 100).numpy()
array([[  1, 100], [100,   1]], dtype=int32)

标量 condition 返回完整的 xy 张量,并应用了广播。

tf.where(True, [1, 2, 3, 4], 100).numpy()
array([1, 2, 3, 4], dtype=int32)
tf.where(False, [1, 2, 3, 4], 100).numpy()
array([100, 100, 100, 100], dtype=int32)

对于广播的一个重要示例,这里 condition 的形状为 [3] , x 的形状为 [3,3] ,而 y 的形状为 [3,1] 。广播首先将 condition 的形状扩展为 [1,3] 。最终的广播形状是 [3,3]condition 将从 xy 中选择列。由于 y 只有一列,因此 y 中的所有列都将相同。

tf.where([True, False, True],
         x=[[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]],
         y=[[100],
            [200],
            [300]]
).numpy()
array([[ 1, 100, 3],
       [ 4, 200, 6],
       [ 7, 300, 9]], dtype=int32)

请注意,如果 tf.where 的任一分支的梯度生成 NaN ,则整个 tf.where 的梯度将为 NaN 。这是因为出于性能原因,tf.where 的梯度计算结合了两个分支。

一种解决方法是使用内部 tf.where 来确保函数没有渐近线,并通过将危险输入替换为安全输入来避免计算梯度为 NaN 的值。

取而代之的是,

x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)

尽管从未使用过 1. / x 值,但当 x = 0 时,它的梯度是 NaN 。相反,我们应该用另一个tf.where来保护它

x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  safe_x = tf.where(tf.equal(x, 0.), 1., x)
  y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)

也可以看看:

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.where。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。