將一個 python 函數包裝到一個 TensorFlow 操作中,該操作會即刻地執行它。
用法
tf.py_function(
func, inp, Tout, name=None
)
參數
-
func
一個 Python 函數,接受inp
作為參數,並返回一個值(或值列表),其類型由Tout
說明。 -
inp
func
的輸入參數。其元素為Tensor
或CompositeTensors
的列表(例如tf.RaggedTensor
);或單個Tensor
或CompositeTensor
。 -
Tout
返回的值的類型func
.以下之一。如果
func
返回Tensor
(或可以轉換為張量的值):該值的tf.DType
。如果
func
返回CompositeTensor
:該值的tf.TypeSpec
。如果
func
返回None
:空列表([]
)。如果
func
返回Tensor
和CompositeTensor
值的列表:每個值對應的tf.DType
和tf.TypeSpec
列表。
-
name
操作的名稱(可選)。
返回
-
由
func
:aTensor
,CompositeTensor
或Tensor
和CompositeTensor
的列表計算的值;如果func
返回None
,則為空列表。
此函數允許將 TensorFlow 圖中的計算表示為 Python 函數。特別是,它將 Python 函數 func
包裝在一次可微分的 TensorFlow 操作中,該操作在啟用即刻執行的情況下執行它。因此,tf.py_function
可以使用 Python 構造(if
, while
, for
等)而不是 TensorFlow 控製流構造(tf.cond
、tf.while_loop
)來表達控製流。例如,您可以使用tf.py_function
來實現日誌集線器函數:
def log_huber(x, m):
if tf.abs(x) <= m:
return x**2
else:
return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))
x = tf.compat.v1.placeholder(tf.float32)
m = tf.compat.v1.placeholder(tf.float32)
y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
dy_dx = tf.gradients(y, x)[0]
with tf.compat.v1.Session() as sess:
# The session executes `log_huber` eagerly. Given the feed values below,
# it will take the first branch, so `y` evaluates to 1.0 and
# `dy_dx` evaluates to 2.0.
y, dy_dx = sess.run([y, dy_dx], feed_dict={x:1.0, m:2.0})
您還可以使用tf.py_function
在運行時使用 Python 工具調試模型,即,您可以隔離要調試的代碼部分,將它們包裝在 Python 函數中並根據需要插入 pdb
跟蹤點或打印語句,以及將這些函數包裝在 tf.py_function
中。
有關 Eager Execution 的更多信息,請參閱 Eager 指南。
tf.py_function
在本質上與 tf.compat.v1.py_func
相似,但與後者不同的是,前者允許您在包裝的 Python 函數中使用 TensorFlow 操作。特別是,雖然tf.compat.v1.py_func
僅在 CPU 上運行並包裝以 NumPy 數組作為輸入並返回 NumPy 數組作為輸出的函數,但 tf.py_function
可以放置在 GPU 上並包裝以張量作為輸入的函數,在它們的主體中執行 TensorFlow 操作,並返回張量作為輸出。
與 tf.compat.v1.py_func
一樣,tf.py_function
在序列化和分發方掩碼有以下限製:
函數體(即
func
)不會在GraphDef
中序列化。因此,如果您需要序列化模型並在不同的環境中恢複它,則不應使用此函數。該操作必須在與調用
tf.py_function()
的 Python 程序相同的地址空間中運行。如果您使用分布式 TensorFlow,則必須在與調用tf.py_function()
的程序相同的進程中運行tf.distribute.Server
,並且必須將創建的操作固定到該服務器中的設備(例如,使用with tf.device():
)。
相關用法
- Python tf.print用法及代碼示例
- Python tf.profiler.experimental.start用法及代碼示例
- Python tf.pad用法及代碼示例
- Python tf.parallel_stack用法及代碼示例
- Python tf.profiler.experimental.Trace.set_metadata用法及代碼示例
- Python tf.profiler.experimental.client.trace用法及代碼示例
- Python tf.profiler.experimental.Trace用法及代碼示例
- Python tf.profiler.experimental.client.monitor用法及代碼示例
- Python tf.profiler.experimental.Profile用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代碼示例
- Python tf.summary.scalar用法及代碼示例
- Python tf.linalg.LinearOperatorFullMatrix.matvec用法及代碼示例
- Python tf.linalg.LinearOperatorToeplitz.solve用法及代碼示例
- Python tf.raw_ops.TPUReplicatedInput用法及代碼示例
- Python tf.raw_ops.Bitcast用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代碼示例
- Python tf.compat.v1.Variable.eval用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.py_function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。