當前位置: 首頁>>編程示例 >>用法及示例精選 >>正文


Python tf.compat.v1.py_func用法及代碼示例

包裝一個 python 函數並將其用作 TensorFlow 操作。

用法

tf.compat.v1.py_func(
    func, inp, Tout, stateful=True, name=None
)

參數

  • func 一個 Python 函數,它接受 ndarray 對象作為參數並返回 ndarray 對象列表(或單個 ndarray )。此函數必須接受與 inp 中的張量一樣多的參數,並且這些參數類型將匹配 inp 中相應的 tf.Tensor 對象。返回的 ndarray 必須與定義的數量和類型相匹配 Tout 。重要提示:func 的輸入和輸出 numpy ndarray 不保證是副本。在某些情況下,它們的底層內存將與相應的 TensorFlow 張量共享。在沒有顯式 (np.) 複製的情況下,在 python 數據結構中就地修改或存儲 func 輸入或返回值可能會產生不確定的後果。
  • inp Tensor 對象的列表。
  • Tout tensorflow 數據類型的列表或元組或單個 tensorflow 數據類型(如果隻有一個),指示 func 返回的內容。
  • stateful (布爾值。)如果為 True,則該函數應被視為有狀態。如果一個函數是無狀態的,當給定相同的輸入時,它將返回相同的輸出並且沒有可觀察到的副作用。諸如公共子表達式消除之類的優化僅在無狀態操作上執行。
  • name 操作的名稱(可選)。

返回

  • Tensor 的列表或 func 計算的單個 Tensor

遷移到 TF2

警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。

此名稱在 TF2 中已棄用並刪除,但 tf.numpy_function 是 near-exact 替換,隻需刪除 stateful 參數(所有 tf.numpy_function 調用都被視為有狀態)。它與即刻執行和 tf.function 兼容。

tf.py_function 是一個接近但不是完全替代品,將 TensorFlow 張量傳遞給包裝函數而不是 NumPy 數組,後者提供梯度並可以利用加速器。

前:

def fn_using_numpy(x):
  x[0] = 0.
  return x
tf.compat.v1.py_func(fn_using_numpy, inp=[tf.constant([1., 2.])],
    Tout=tf.float32, stateful=False)
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

後:

tf.numpy_function(fn_using_numpy, inp=[tf.constant([1., 2.])],
    Tout=tf.float32)
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

給定一個 python 函數 func ,它將 numpy 數組作為其參數並返回 numpy 數組作為其輸出,將此函數包裝為 TensorFlow 圖中的一個操作。以下代碼段構建了一個簡單的 TensorFlow 圖,該圖調用 np.sinh() NumPy 函數作為圖中的操作:

def my_func(x):
  # x will be a numpy array with the contents of the placeholder below
  return np.sinh(x)
input = tf.compat.v1.placeholder(tf.float32)
y = tf.compat.v1.py_func(my_func, [input], tf.float32)

注意:tf.compat.v1.py_func() 操作具有以下已知限製:

  • 函數體(即 func )不會在 GraphDef 中序列化。因此,如果您需要序列化模型並在不同的環境中恢複它,則不應使用此函數。

  • 該操作必須在與調用 tf.compat.v1.py_func() 的 Python 程序相同的地址空間中運行。如果您使用分布式 TensorFlow,則必須在與調用 tf.compat.v1.py_func() 的程序相同的進程中運行 tf.distribute.Server,並且必須將創建的操作固定到該服務器中的設備(例如,使用 with tf.device(): )。

注意:它產生未知形狀和等級的張量,因為形狀推斷不適用於任意 Python 代碼。如果您需要形狀,則需要根據靜態可用信息進行設置。

例如:

import tensorflow as tf
  import numpy as np

  def make_synthetic_data(i):
      return np.cast[np.uint8](i) * np.ones([20,256,256,3],
              dtype=np.float32) / 10.

  def preprocess_fn(i):
      ones = tf.py_function(make_synthetic_data,[i],tf.float32)
      ones.set_shape(tf.TensorShape([None, None, None, None]))
      ones = tf.image.resize(ones, [224,224])
      return ones

  ds = tf.data.Dataset.range(10)
  ds = ds.map(preprocess_fn)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.py_func。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。