當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.where用法及代碼示例


xy 返回元素,具體取決於 condition

用法

tf.compat.v1.where(
    condition, x=None, y=None, name=None
)

參數

  • condition bool 類型的 Tensor
  • x 一個可能與 condition 具有相同形狀的張量。如果 condition 是 rank 1,則 x 可能有更高的 rank,但它的第一個維度必須與 condition 的大小匹配。
  • y x 具有相同形狀和類型的 tensor
  • name 操作名稱(可選)

返回

  • x , y 具有相同類型和形狀的Tensor,如果它們不是None。否則,形狀為 Tensor(num_true, rank(condition))

拋出

  • ValueError xy 中的一個恰好為非無時。

遷移到 TF2

警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。

此 API 與即刻執行和 tf.function 兼容。但是,這仍然是最初為 TF1 設計的遺留 API 端點。要遷移到fully-native TF2,請將其用法替換為tf.where,它直接向後兼容tf.compat.v1.where

但是,tf.compat.v1.wheretf.where 更具限製性,要求 xy 具有相同的形狀,並返回與 x , y 具有相同類型和形狀的 Tensor(如果它們都是非沒有)。

tf.where 將接受不同形狀的 x , y,隻要它們可以相互廣播並與 condition 一起廣播,並將返回一個 Tensor 形狀廣播來自 condition , xy

例如,以下內容適用於 tf.where 但不適用於 tf.compat.v1.where

tf.where([True, False, False, True], [1,2,3,4], [100])
<tf.Tensor:shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
dtype=int32)>
tf.where(True, [1,2,3,4], 100)
<tf.Tensor:shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4],
dtype=int32)>

如果 xy 都為 None,則此操作返回 condition 的真實元素的坐標。坐標以二維張量返回,其中第一維(行)表示真實元素的數量,第二維(列)表示真實元素的坐標。請記住,輸出張量的形狀可能會根據輸入中有多少真實值而有所不同。索引以行優先順序輸出。

如果兩者都不是None,則xy 必須具有相同的形狀。如果 xy 是標量,則 condition 張量必須是標量。如果 xy 是更高階的張量,則 condition 必須是大小與 x 的第一個維度匹配的向量,或者必須具有與 x 相同的形狀。

condition 張量充當掩碼,根據每個元素的值選擇輸出中的相應元素/行是否應取自 x(如果為真)或 y(如果為假)。

如果 condition 是一個向量並且 xy 是更高等級的矩陣,那麽它選擇從 xy 複製哪一行(外部維度)。如果 conditionxy 具有相同的形狀,則它選擇從 xy 複製哪個元素。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.where。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。