包裝一個 python 函數並將其用作 TensorFlow 操作。
用法
tf.compat.v1.py_func(
func, inp, Tout, stateful=True, name=None
)
參數
-
func
一個 Python 函數,它接受ndarray
對象作為參數並返回ndarray
對象列表(或單個ndarray
)。此函數必須接受與inp
中的張量一樣多的參數,並且這些參數類型將匹配inp
中相應的tf.Tensor
對象。返回的ndarray
必須與定義的數量和類型相匹配Tout
。重要提示:func
的輸入和輸出 numpyndarray
不保證是副本。在某些情況下,它們的底層內存將與相應的 TensorFlow 張量共享。在沒有顯式 (np.) 複製的情況下,在 python 數據結構中就地修改或存儲func
輸入或返回值可能會產生不確定的後果。 -
inp
Tensor
對象的列表。 -
Tout
tensorflow 數據類型的列表或元組或單個 tensorflow 數據類型(如果隻有一個),指示func
返回的內容。 -
stateful
(布爾值。)如果為 True,則該函數應被視為有狀態。如果一個函數是無狀態的,當給定相同的輸入時,它將返回相同的輸出並且沒有可觀察到的副作用。諸如公共子表達式消除之類的優化僅在無狀態操作上執行。 -
name
操作的名稱(可選)。
返回
-
Tensor
的列表或func
計算的單個Tensor
。
遷移到 TF2
警告:這個 API 是為 TensorFlow v1 設計的。繼續閱讀有關如何從該 API 遷移到本機 TensorFlow v2 等效項的詳細信息。見TensorFlow v1 到 TensorFlow v2 遷移指南有關如何遷移其餘代碼的說明。
此名稱在 TF2 中已棄用並刪除,但 tf.numpy_function
是 near-exact 替換,隻需刪除 stateful
參數(所有 tf.numpy_function
調用都被視為有狀態)。它與即刻執行和 tf.function
兼容。
tf.py_function
是一個接近但不是完全替代品,將 TensorFlow 張量傳遞給包裝函數而不是 NumPy 數組,後者提供梯度並可以利用加速器。
前:
def fn_using_numpy(x):
x[0] = 0.
return x
tf.compat.v1.py_func(fn_using_numpy, inp=[tf.constant([1., 2.])],
Tout=tf.float32, stateful=False)
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>
後:
tf.numpy_function(fn_using_numpy, inp=[tf.constant([1., 2.])],
Tout=tf.float32)
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>
給定一個 python 函數 func
,它將 numpy 數組作為其參數並返回 numpy 數組作為其輸出,將此函數包裝為 TensorFlow 圖中的一個操作。以下代碼段構建了一個簡單的 TensorFlow 圖,該圖調用 np.sinh()
NumPy 函數作為圖中的操作:
def my_func(x):
# x will be a numpy array with the contents of the placeholder below
return np.sinh(x)
input = tf.compat.v1.placeholder(tf.float32)
y = tf.compat.v1.py_func(my_func, [input], tf.float32)
注意:tf.compat.v1.py_func()
操作具有以下已知限製:
函數體(即
func
)不會在GraphDef
中序列化。因此,如果您需要序列化模型並在不同的環境中恢複它,則不應使用此函數。該操作必須在與調用
tf.compat.v1.py_func()
的 Python 程序相同的地址空間中運行。如果您使用分布式 TensorFlow,則必須在與調用tf.compat.v1.py_func()
的程序相同的進程中運行tf.distribute.Server
,並且必須將創建的操作固定到該服務器中的設備(例如,使用with tf.device():
)。
注意:它產生未知形狀和等級的張量,因為形狀推斷不適用於任意 Python 代碼。如果您需要形狀,則需要根據靜態可用信息進行設置。
例如:
import tensorflow as tf
import numpy as np
def make_synthetic_data(i):
return np.cast[np.uint8](i) * np.ones([20,256,256,3],
dtype=np.float32) / 10.
def preprocess_fn(i):
ones = tf.py_function(make_synthetic_data,[i],tf.float32)
ones.set_shape(tf.TensorShape([None, None, None, None]))
ones = tf.image.resize(ones, [224,224])
return ones
ds = tf.data.Dataset.range(10)
ds = ds.map(preprocess_fn)
相關用法
- Python tf.compat.v1.placeholder用法及代碼示例
- Python tf.compat.v1.placeholder_with_default用法及代碼示例
- Python tf.compat.v1.profiler.Profiler用法及代碼示例
- Python tf.compat.v1.parse_example用法及代碼示例
- Python tf.compat.v1.pad用法及代碼示例
- Python tf.compat.v1.profiler.ProfileOptionBuilder用法及代碼示例
- Python tf.compat.v1.distributions.Multinomial.stddev用法及代碼示例
- Python tf.compat.v1.distribute.MirroredStrategy.experimental_distribute_dataset用法及代碼示例
- Python tf.compat.v1.data.TFRecordDataset.interleave用法及代碼示例
- Python tf.compat.v1.distributions.Bernoulli.cross_entropy用法及代碼示例
- Python tf.compat.v1.Variable.eval用法及代碼示例
- Python tf.compat.v1.train.FtrlOptimizer.compute_gradients用法及代碼示例
- Python tf.compat.v1.layers.conv3d用法及代碼示例
- Python tf.compat.v1.strings.length用法及代碼示例
- Python tf.compat.v1.data.Dataset.snapshot用法及代碼示例
- Python tf.compat.v1.data.experimental.SqlDataset.reduce用法及代碼示例
- Python tf.compat.v1.feature_column.categorical_column_with_vocabulary_file用法及代碼示例
- Python tf.compat.v1.data.TextLineDataset.from_tensors用法及代碼示例
- Python tf.compat.v1.variable_scope用法及代碼示例
- Python tf.compat.v1.data.experimental.SqlDataset.as_numpy_iterator用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.py_func。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。