当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.switch_case用法及代码示例


创建一个 switch/case 操作,即 integer-indexed 条件。

用法

tf.switch_case(
    branch_index, branch_fns, default=None, name='switch_case'
)

参数

  • branch_index 一个 int 张量,指定应该执行 branch_fns 中的哪一个。
  • branch_fns int 映射到可调用对象的 dict,或(int,可调用)对的 list,或只是可调用对象的列表(在这种情况下,索引用作键)。每个可调用对象都必须返回一个匹配的张量结构。
  • default 返回张量结构的可选可调用对象。
  • name 此操作的名称(可选)。

返回

  • branch_index 标识的可调用对象返回的张量,或者如果没有键匹配并且提供了 default 则由 default 返回的张量,或者如果没有提供 default 则由 max-keyed branch_fn 返回的张量。

抛出

  • TypeError 如果 branch_fns 不是列表/字典。
  • TypeError 如果 branch_fns 是一个列表,但不包含 2 元组或可调用对象。
  • TypeError 如果 fns[i] 对于任何 i 都不可调用,或者 default 不可调用。

另见tf.case

当恰好选择一个分支时,此操作可能比tf.case 更有效。 tf.switch_case 更像是一个 C++ switch/case 语句而不是 tf.case ,它更像是一个 if/elif/elif/else 链。

branch_fns 参数或者是从 int 到可调用对象的字典,或者是(int,可调用)对的列表,或者只是可调用对象的列表(在这种情况下,索引是隐含的键)。 branch_index Tensor 用于选择 branch_fns 中具有匹配 int 键的元素,如果没有匹配则回退到 default,如果没有提供 default,则返回 max(keys)。 key 必须形成从 0len(branch_fns) - 1 的连续集。

tf.switch_case 支持 tf.nest 中实现的嵌套结构。所有可调用对象必须返回相同的(可能是嵌套的)列表、元组和/或命名元组的值结构。

例子:

伪代码:

switch (branch_index) {  // c-style switch
  case 0:return 17;
  case 1:return 31;
  default:return -1;
}

或者

branches = {0:lambda:17, 1:lambda:31}
branches.get(branch_index, lambda:-1)()

表达式:

def f1():return tf.constant(17)
def f2():return tf.constant(31)
def f3():return tf.constant(-1)
r = tf.switch_case(branch_index, branch_fns={0:f1, 1:f2}, default=f3)
# Equivalent:tf.switch_case(branch_index, branch_fns={0:f1, 1:f2, 2:f3})

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.switch_case。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。