当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.py_function用法及代码示例


将一个 python 函数包装到一个 TensorFlow 操作中,该操作会即刻地执行它。

用法

tf.py_function(
    func, inp, Tout, name=None
)

参数

  • func 一个 Python 函数,接受 inp 作为参数,并返回一个值(或值列表),其类型由 Tout 说明。
  • inp func 的输入参数。其元素为 TensorCompositeTensors 的列表(例如 tf.RaggedTensor );或单个 TensorCompositeTensor
  • Tout 返回的值的类型func.以下之一。
    • 如果 func 返回 Tensor(或可以转换为张量的值):该值的 tf.DType

    • 如果 func 返回 CompositeTensor :该值的 tf.TypeSpec

    • 如果 func 返回 None :空列表( [] )。

    • 如果 func 返回 TensorCompositeTensor 值的列表:每个值对应的 tf.DTypetf.TypeSpec 列表。

  • name 操作的名称(可选)。

返回

  • func :a Tensor , CompositeTensorTensorCompositeTensor 的列表计算的值;如果 func 返回 None ,则为空列表。

此函数允许将 TensorFlow 图中的计算表示为 Python 函数。特别是,它将 Python 函数 func 包装在一次可微分的 TensorFlow 操作中,该操作在启用即刻执行的情况下执行它。因此,tf.py_function 可以使用 Python 构造(if , while , for 等)而不是 TensorFlow 控制流构造(tf.condtf.while_loop)来表达控制流。例如,您可以使用tf.py_function 来实现日志集线器函数:

def log_huber(x, m):
  if tf.abs(x) <= m:
    return x**2
  else:
    return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))

x = tf.compat.v1.placeholder(tf.float32)
m = tf.compat.v1.placeholder(tf.float32)

y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
dy_dx = tf.gradients(y, x)[0]

with tf.compat.v1.Session() as sess:
  # The session executes `log_huber` eagerly. Given the feed values below,
  # it will take the first branch, so `y` evaluates to 1.0 and
  # `dy_dx` evaluates to 2.0.
  y, dy_dx = sess.run([y, dy_dx], feed_dict={x:1.0, m:2.0})

您还可以使用tf.py_function 在运行时使用 Python 工具调试模型,即,您可以隔离要调试的代码部分,将它们包装在 Python 函数中并根据需要插入 pdb 跟踪点或打印语句,以及将这些函数包装在 tf.py_function 中。

有关 Eager Execution 的更多信息,请参阅 Eager 指南。

tf.py_function 在本质上与 tf.compat.v1.py_func 相似,但与后者不同的是,前者允许您在包装的 Python 函数中使用 TensorFlow 操作。特别是,虽然tf.compat.v1.py_func 仅在 CPU 上运行并包装以 NumPy 数组作为输入并返回 NumPy 数组作为输出的函数,但 tf.py_function 可以放置在 GPU 上并包装以张量作为输入的函数,在它们的主体中执行 TensorFlow 操作,并返回张量作为输出。

tf.compat.v1.py_func 一样,tf.py_function 在序列化和分发方掩码有以下限制:

  • 函数体(即 func )不会在 GraphDef 中序列化。因此,如果您需要序列化模型并在不同的环境中恢复它,则不应使用此函数。

  • 该操作必须在与调用 tf.py_function() 的 Python 程序相同的地址空间中运行。如果您使用分布式 TensorFlow,则必须在与调用 tf.py_function() 的程序相同的进程中运行 tf.distribute.Server,并且必须将创建的操作固定到该服务器中的设备(例如,使用 with tf.device(): )。

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.py_function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。