当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.cond用法及代码示例


如果谓词 pred 为真,则返回 true_fn() 否则 false_fn()

用法

tf.cond(
    pred, true_fn=None, false_fn=None, name=None
)

参数

  • pred 一个标量,确定是否返回 true_fnfalse_fn 的结果。
  • true_fn 如果 pred 为真,则要执行的调用。
  • false_fn 如果 pred 为假,则要执行的可调用。
  • name 返回的张量的可选名称前缀。

返回

  • 调用 true_fnfalse_fn 返回的张量。如果可调用对象返回单例列表,则从列表中提取元素。

抛出

  • TypeError 如果 true_fnfalse_fn 不可调用。
  • ValueError 如果 true_fnfalse_fn 不返回相同数量的张量,或者返回不同类型的张量。

true_fnfalse_fn 都返回输出张量的列表。 true_fnfalse_fn 必须具有相同的非零数量和类型的输出。

警告:无论在运行时选择哪个分支,在 true_fnfalse_fn 之外创建的任何张量或操作都将被执行。

尽管这种行为与 TensorFlow 的数据流模型是一致的,但它经常让那些期待更惰性语义的用户感到惊讶。考虑以下简单程序:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda:tf.add(x, z), lambda:tf.square(y))

如果 x < y ,将执行 tf.add 操作,而不会执行 tf.square 操作。由于cond 的至少一个分支需要z,因此始终无条件地执行tf.multiply 操作。

注意cond调用true_fnfalse_fn 恰好一次(在调用内cond, 并且在Session.run())。cond将在创建过程中创建的图形片段缝合在一起true_fnfalse_fn调用一些额外的图节点,以确保根据值执行正确的分支pred.

tf.cond 支持在 tensorflow.python.util.nest 中实现的嵌套结构。 true_fnfalse_fn 都必须返回相同的(可能是嵌套的)列表、元组和/或命名元组的值结构。单例列表和元组是唯一的异常:当 true_fn 和/或 false_fn 返回时,它们被隐式解包为单个值。

注意:"directly" 使用在其外部的 cond 分支内创建的张量是非法的,例如通过在 python 状态中存储对分支张量的引用。如果您需要使用在分支函数中创建的张量,则应将其作为分支函数的输出返回,并改用 tf.cond 的输出。

例子:

x = tf.constant(2)
y = tf.constant(5)
def f1():return tf.multiply(x, 17)
def f2():return tf.add(y, 23)
r = tf.cond(tf.less(x, y), f1, f2)
# r is set to f1().
# Operations in f2 (e.g., tf.add) are not executed.

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.cond。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。