當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.switch_case用法及代碼示例


創建一個 switch/case 操作,即 integer-indexed 條件。

用法

tf.switch_case(
    branch_index, branch_fns, default=None, name='switch_case'
)

參數

  • branch_index 一個 int 張量,指定應該執行 branch_fns 中的哪一個。
  • branch_fns int 映射到可調用對象的 dict,或(int,可調用)對的 list,或隻是可調用對象的列表(在這種情況下,索引用作鍵)。每個可調用對象都必須返回一個匹配的張量結構。
  • default 返回張量結構的可選可調用對象。
  • name 此操作的名稱(可選)。

返回

  • branch_index 標識的可調用對象返回的張量,或者如果沒有鍵匹配並且提供了 default 則由 default 返回的張量,或者如果沒有提供 default 則由 max-keyed branch_fn 返回的張量。

拋出

  • TypeError 如果 branch_fns 不是列表/字典。
  • TypeError 如果 branch_fns 是一個列表,但不包含 2 元組或可調用對象。
  • TypeError 如果 fns[i] 對於任何 i 都不可調用,或者 default 不可調用。

另見tf.case

當恰好選擇一個分支時,此操作可能比tf.case 更有效。 tf.switch_case 更像是一個 C++ switch/case 語句而不是 tf.case ,它更像是一個 if/elif/elif/else 鏈。

branch_fns 參數或者是從 int 到可調用對象的字典,或者是(int,可調用)對的列表,或者隻是可調用對象的列表(在這種情況下,索引是隱含的鍵)。 branch_index Tensor 用於選擇 branch_fns 中具有匹配 int 鍵的元素,如果沒有匹配則回退到 default,如果沒有提供 default,則返回 max(keys)。 key 必須形成從 0len(branch_fns) - 1 的連續集。

tf.switch_case 支持 tf.nest 中實現的嵌套結構。所有可調用對象必須返回相同的(可能是嵌套的)列表、元組和/或命名元組的值結構。

例子:

偽代碼:

switch (branch_index) {  // c-style switch
  case 0:return 17;
  case 1:return 31;
  default:return -1;
}

或者

branches = {0:lambda:17, 1:lambda:31}
branches.get(branch_index, lambda:-1)()

表達式:

def f1():return tf.constant(17)
def f2():return tf.constant(31)
def f3():return tf.constant(-1)
r = tf.switch_case(branch_index, branch_fns={0:f1, 1:f2}, default=f3)
# Equivalent:tf.switch_case(branch_index, branch_fns={0:f1, 1:f2, 2:f3})

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.switch_case。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。