将一个 python 函数包装到一个 TensorFlow 操作中,该操作会即刻地执行它。
用法
tf.py_function(
func, inp, Tout, name=None
)
参数
-
func
一个 Python 函数,接受inp
作为参数,并返回一个值(或值列表),其类型由Tout
说明。 -
inp
func
的输入参数。其元素为Tensor
或CompositeTensors
的列表(例如tf.RaggedTensor
);或单个Tensor
或CompositeTensor
。 -
Tout
返回的值的类型func
.以下之一。如果
func
返回Tensor
(或可以转换为张量的值):该值的tf.DType
。如果
func
返回CompositeTensor
:该值的tf.TypeSpec
。如果
func
返回None
:空列表([]
)。如果
func
返回Tensor
和CompositeTensor
值的列表:每个值对应的tf.DType
和tf.TypeSpec
列表。
-
name
操作的名称(可选)。
返回
-
由
func
:aTensor
,CompositeTensor
或Tensor
和CompositeTensor
的列表计算的值;如果func
返回None
,则为空列表。
此函数允许将 TensorFlow 图中的计算表示为 Python 函数。特别是,它将 Python 函数 func
包装在一次可微分的 TensorFlow 操作中,该操作在启用即刻执行的情况下执行它。因此,tf.py_function
可以使用 Python 构造(if
, while
, for
等)而不是 TensorFlow 控制流构造(tf.cond
、tf.while_loop
)来表达控制流。例如,您可以使用tf.py_function
来实现日志集线器函数:
def log_huber(x, m):
if tf.abs(x) <= m:
return x**2
else:
return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))
x = tf.compat.v1.placeholder(tf.float32)
m = tf.compat.v1.placeholder(tf.float32)
y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
dy_dx = tf.gradients(y, x)[0]
with tf.compat.v1.Session() as sess:
# The session executes `log_huber` eagerly. Given the feed values below,
# it will take the first branch, so `y` evaluates to 1.0 and
# `dy_dx` evaluates to 2.0.
y, dy_dx = sess.run([y, dy_dx], feed_dict={x:1.0, m:2.0})
您还可以使用tf.py_function
在运行时使用 Python 工具调试模型,即,您可以隔离要调试的代码部分,将它们包装在 Python 函数中并根据需要插入 pdb
跟踪点或打印语句,以及将这些函数包装在 tf.py_function
中。
有关 Eager Execution 的更多信息,请参阅 Eager 指南。
tf.py_function
在本质上与 tf.compat.v1.py_func
相似,但与后者不同的是,前者允许您在包装的 Python 函数中使用 TensorFlow 操作。特别是,虽然tf.compat.v1.py_func
仅在 CPU 上运行并包装以 NumPy 数组作为输入并返回 NumPy 数组作为输出的函数,但 tf.py_function
可以放置在 GPU 上并包装以张量作为输入的函数,在它们的主体中执行 TensorFlow 操作,并返回张量作为输出。
与 tf.compat.v1.py_func
一样,tf.py_function
在序列化和分发方掩码有以下限制:
函数体(即
func
)不会在GraphDef
中序列化。因此,如果您需要序列化模型并在不同的环境中恢复它,则不应使用此函数。该操作必须在与调用
tf.py_function()
的 Python 程序相同的地址空间中运行。如果您使用分布式 TensorFlow,则必须在与调用tf.py_function()
的程序相同的进程中运行tf.distribute.Server
,并且必须将创建的操作固定到该服务器中的设备(例如,使用with tf.device():
)。
相关用法
- Python tf.print用法及代码示例
- Python tf.profiler.experimental.start用法及代码示例
- Python tf.pad用法及代码示例
- Python tf.parallel_stack用法及代码示例
- Python tf.profiler.experimental.Trace.set_metadata用法及代码示例
- Python tf.profiler.experimental.client.trace用法及代码示例
- Python tf.profiler.experimental.Trace用法及代码示例
- Python tf.profiler.experimental.client.monitor用法及代码示例
- Python tf.profiler.experimental.Profile用法及代码示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代码示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代码示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代码示例
- Python tf.summary.scalar用法及代码示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代码示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代码示例
- Python tf.raw_ops.TPUReplicatedInput用法及代码示例
- Python tf.raw_ops.Bitcast用法及代码示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代码示例
- Python tf.compat.v1.Variable.eval用法及代码示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.py_function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。