将一个 python 函数包装到一个 TensorFlow 操作中,该操作会即刻地执行它。
用法
tf.py_function(
func, inp, Tout, name=None
)参数
-
func一个 Python 函数,接受inp作为参数,并返回一个值(或值列表),其类型由Tout说明。 -
inpfunc的输入参数。其元素为Tensor或CompositeTensors的列表(例如tf.RaggedTensor);或单个Tensor或CompositeTensor。 -
Tout返回的值的类型func.以下之一。如果
func返回Tensor(或可以转换为张量的值):该值的tf.DType。如果
func返回CompositeTensor:该值的tf.TypeSpec。如果
func返回None:空列表([])。如果
func返回Tensor和CompositeTensor值的列表:每个值对应的tf.DType和tf.TypeSpec列表。
-
name操作的名称(可选)。
返回
-
由
func:aTensor,CompositeTensor或Tensor和CompositeTensor的列表计算的值;如果func返回None,则为空列表。
此函数允许将 TensorFlow 图中的计算表示为 Python 函数。特别是,它将 Python 函数 func 包装在一次可微分的 TensorFlow 操作中,该操作在启用即刻执行的情况下执行它。因此,tf.py_function 可以使用 Python 构造(if , while , for 等)而不是 TensorFlow 控制流构造(tf.cond、tf.while_loop)来表达控制流。例如,您可以使用tf.py_function 来实现日志集线器函数:
def log_huber(x, m):
if tf.abs(x) <= m:
return x**2
else:
return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))
x = tf.compat.v1.placeholder(tf.float32)
m = tf.compat.v1.placeholder(tf.float32)
y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
dy_dx = tf.gradients(y, x)[0]
with tf.compat.v1.Session() as sess:
# The session executes `log_huber` eagerly. Given the feed values below,
# it will take the first branch, so `y` evaluates to 1.0 and
# `dy_dx` evaluates to 2.0.
y, dy_dx = sess.run([y, dy_dx], feed_dict={x:1.0, m:2.0})
您还可以使用tf.py_function 在运行时使用 Python 工具调试模型,即,您可以隔离要调试的代码部分,将它们包装在 Python 函数中并根据需要插入 pdb 跟踪点或打印语句,以及将这些函数包装在 tf.py_function 中。
有关 Eager Execution 的更多信息,请参阅 Eager 指南。
tf.py_function 在本质上与 tf.compat.v1.py_func 相似,但与后者不同的是,前者允许您在包装的 Python 函数中使用 TensorFlow 操作。特别是,虽然tf.compat.v1.py_func 仅在 CPU 上运行并包装以 NumPy 数组作为输入并返回 NumPy 数组作为输出的函数,但 tf.py_function 可以放置在 GPU 上并包装以张量作为输入的函数,在它们的主体中执行 TensorFlow 操作,并返回张量作为输出。
与 tf.compat.v1.py_func 一样,tf.py_function 在序列化和分发方掩码有以下限制:
函数体(即
func)不会在GraphDef中序列化。因此,如果您需要序列化模型并在不同的环境中恢复它,则不应使用此函数。该操作必须在与调用
tf.py_function()的 Python 程序相同的地址空间中运行。如果您使用分布式 TensorFlow,则必须在与调用tf.py_function()的程序相同的进程中运行tf.distribute.Server,并且必须将创建的操作固定到该服务器中的设备(例如,使用with tf.device():)。
相关用法
- Python tf.print用法及代码示例
- Python tf.profiler.experimental.start用法及代码示例
- Python tf.pad用法及代码示例
- Python tf.parallel_stack用法及代码示例
- Python tf.profiler.experimental.Trace.set_metadata用法及代码示例
- Python tf.profiler.experimental.client.trace用法及代码示例
- Python tf.profiler.experimental.Trace用法及代码示例
- Python tf.profiler.experimental.client.monitor用法及代码示例
- Python tf.profiler.experimental.Profile用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.py_function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。
