當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.py_function用法及代碼示例


將一個 python 函數包裝到一個 TensorFlow 操作中,該操作會即刻地執行它。

用法

tf.py_function(
    func, inp, Tout, name=None
)

參數

  • func 一個 Python 函數,接受 inp 作為參數,並返回一個值(或值列表),其類型由 Tout 說明。
  • inp func 的輸入參數。其元素為 TensorCompositeTensors 的列表(例如 tf.RaggedTensor );或單個 TensorCompositeTensor
  • Tout 返回的值的類型func.以下之一。
    • 如果 func 返回 Tensor(或可以轉換為張量的值):該值的 tf.DType

    • 如果 func 返回 CompositeTensor :該值的 tf.TypeSpec

    • 如果 func 返回 None :空列表( [] )。

    • 如果 func 返回 TensorCompositeTensor 值的列表:每個值對應的 tf.DTypetf.TypeSpec 列表。

  • name 操作的名稱(可選)。

返回

  • func :a Tensor , CompositeTensorTensorCompositeTensor 的列表計算的值;如果 func 返回 None ,則為空列表。

此函數允許將 TensorFlow 圖中的計算表示為 Python 函數。特別是,它將 Python 函數 func 包裝在一次可微分的 TensorFlow 操作中,該操作在啟用即刻執行的情況下執行它。因此,tf.py_function 可以使用 Python 構造(if , while , for 等)而不是 TensorFlow 控製流構造(tf.condtf.while_loop)來表達控製流。例如,您可以使用tf.py_function 來實現日誌集線器函數:

def log_huber(x, m):
  if tf.abs(x) <= m:
    return x**2
  else:
    return m**2 * (1 - 2 * tf.math.log(m) + tf.math.log(x**2))

x = tf.compat.v1.placeholder(tf.float32)
m = tf.compat.v1.placeholder(tf.float32)

y = tf.py_function(func=log_huber, inp=[x, m], Tout=tf.float32)
dy_dx = tf.gradients(y, x)[0]

with tf.compat.v1.Session() as sess:
  # The session executes `log_huber` eagerly. Given the feed values below,
  # it will take the first branch, so `y` evaluates to 1.0 and
  # `dy_dx` evaluates to 2.0.
  y, dy_dx = sess.run([y, dy_dx], feed_dict={x:1.0, m:2.0})

您還可以使用tf.py_function 在運行時使用 Python 工具調試模型,即,您可以隔離要調試的代碼部分,將它們包裝在 Python 函數中並根據需要插入 pdb 跟蹤點或打印語句,以及將這些函數包裝在 tf.py_function 中。

有關 Eager Execution 的更多信息,請參閱 Eager 指南。

tf.py_function 在本質上與 tf.compat.v1.py_func 相似,但與後者不同的是,前者允許您在包裝的 Python 函數中使用 TensorFlow 操作。特別是,雖然tf.compat.v1.py_func 僅在 CPU 上運行並包裝以 NumPy 數組作為輸入並返回 NumPy 數組作為輸出的函數,但 tf.py_function 可以放置在 GPU 上並包裝以張量作為輸入的函數,在它們的主體中執行 TensorFlow 操作,並返回張量作為輸出。

tf.compat.v1.py_func 一樣,tf.py_function 在序列化和分發方掩碼有以下限製:

  • 函數體(即 func )不會在 GraphDef 中序列化。因此,如果您需要序列化模型並在不同的環境中恢複它,則不應使用此函數。

  • 該操作必須在與調用 tf.py_function() 的 Python 程序相同的地址空間中運行。如果您使用分布式 TensorFlow,則必須在與調用 tf.py_function() 的程序相同的進程中運行 tf.distribute.Server,並且必須將創建的操作固定到該服務器中的設備(例如,使用 with tf.device(): )。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.py_function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。