將函數編譯為可調用的 TensorFlow 圖。 (不推薦使用的參數)
用法
tf.function(
func=None, input_signature=None, autograph=True, jit_compile=None,
experimental_implements=None, experimental_autograph_options=None,
experimental_relax_shapes=False, experimental_compile=None,
experimental_follow_type_hints=None
) -> tf.types.experimental.GenericFunction
參數
-
func
要編譯的函數。如果func
為 None,則tf.function
返回一個可以使用單個參數調用的裝飾器 -func
。換句話說,tf.function(input_signature=...)(func)
等價於tf.function(func, input_signature=...)
。前者可以用作裝飾器。 -
input_signature
tf.TensorSpec
對象的可能嵌套序列,指定將提供給此函數的張量的形狀和數據類型。如果None
,則為每個推斷的輸入簽名實例化一個單獨的函數。如果指定了 input_signature,則func
的每個輸入都必須是Tensor
,並且func
不能接受**kwargs
。 -
autograph
是否應申請親筆簽名func
在跟蹤圖表之前。 Data-dependent 控製流需要autograph=True
.有關詳細信息,請參閱tf.function 和 AutoGraph 指南. -
jit_compile
如果True
, 編譯函數使用XLA. XLA 執行編譯器優化,例如融合,並嘗試生成更高效的代碼。這可能會大大提高性能。如果設置為True
, 整個函數需要由 XLA 編譯,或者tf.errors.InvalidArgumentError被拋出。如果None
(默認),在 TPU 上運行時使用 XLA 編譯函數,在其他設備上運行時通過常規函數執行路徑。如果False
, 在沒有 XLA 編譯的情況下執行函數。將此值設置為False
在 TPU 上直接運行多設備函數時(例如,兩個 TPU 內核、一個 TPU 內核及其主機 CPU)。並非所有函數都可編譯,請參閱列表尖角. -
experimental_implements
如果提供,則包含此實現的"known" 函數的名稱。例如"mycompany.my_recurrent_cell"。這作為屬性存儲在推理函數中,然後可以在處理序列化函數時檢測到。看標準化複合操作
詳情。有關使用此屬性的示例,請參閱此例子上麵的代碼會自動檢測並替換實現"embedded_matmul" 的函數,並允許 TFLite 替換它自己的實現。例如,一個 tensorflow 用戶可以使用這個屬性來標記他們的函數也實現了embedded_matmul
(也許更有效!)通過使用此參數指定它:@tf.function(experimental_implements="embedded_matmul")
這可以指定為函數的字符串名稱,也可以指定為與函數名稱關聯的鍵值屬性列表對應的 NameAttrList。該函數的名稱將在 NameAttrList 的 'name' 字段中。要為此函數實現定義正式的 TF 操作,請嘗試實驗複合TF項目。 -
experimental_autograph_options
tf.autograph.experimental.Feature
值的可選元組。 -
experimental_relax_shapes
當為 True 時,tf.function
可能會生成較少的、不太專門針對輸入形狀的圖形。 -
experimental_compile
已棄用 'jit_compile' 的別名。 -
experimental_follow_type_hints
當為 True 時,該函數可以使用來自func
的類型注釋來優化跟蹤性能。例如,帶有tf.Tensor
注釋的參數將自動轉換為張量。
返回
-
如果
func
不是 None,則返回tf.types.experimental.GenericFunction
。如果func
為 None,則返回一個裝飾器,當使用單個func
參數調用該裝飾器時,返回一個tf.types.experimental.GenericFunction
。
拋出
-
ValueError
嘗試使用jit_compile=True
時,但 XLA 支持不可用。
警告:不推薦使用某些參數:(experimental_compile)
。它們將在未來的版本中被刪除。更新說明:experimental_compile 已棄用,請改用jit_compile
tf.function
構造一個 tf.types.experimental.GenericFunction
,它執行由 trace-compiling 在 func
中的 TensorFlow 操作創建的 TensorFlow 圖 (tf.Graph
)。有關該主題的更多信息,請參閱圖表和 tf.function 簡介。
有關性能和已知限製的提示,請參閱使用 tf.function 獲得更好的性能。
示例用法:
@tf.function
def f(x, y):
return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor:... numpy=array([7, 7], ...)>
trace-compilation 允許在特殊條件下執行非 TensorFlow 操作。通常,隻要調用 GenericFunction
,就隻能保證 TensorFlow 操作運行並創建新結果。
特征
func
可以使用data-dependent控製流,包括if
, for
, while
break
, continue
和return
語句:
@tf.function
def f(x):
if tf.reduce_sum(x) > 0:
return x * x
else:
return -x // 2
f(tf.constant(-2))
<tf.Tensor:... numpy=1>
func
的閉包可能包括 tf.Tensor
和 tf.Variable
對象:
@tf.function
def f():
return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor:... numpy=array([7, 7], ...)>
func
也可能使用帶有副作用的操作,例如 tf.print
、 tf.Variable
和其他:
v = tf.Variable(1)
@tf.function
def f(x):
for i in tf.range(x):
v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
重要的:當跟蹤 func
時,任何 Python side-effects(附加到列表、使用 print
打印等)隻會發生一次。要將side-effects 執行到您的tf.function
中,需要將它們編寫為 TF ops:
l = []
@tf.function
def f(x):
for i in x:
l.append(i + 1) # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]
相反,請使用 TensorFlow 集合,例如 tf.TensorArray
:
@tf.function
def f(x):
ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
for i in range(len(x)):
ta = ta.write(i, x[i] + 1)
return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor:..., numpy=array([2, 3, 4], ...)>
tf.function
創建多態可調用對象
在內部,tf.types.experimental.GenericFunction
可能包含多個 tf.types.experimental.ConcreteFunction
,每個專用於具有不同數據類型或形狀的參數,因為 TensorFlow 可以對特定形狀、dtypes 和常量參數值的圖執行更多優化。 tf.function
將任何純 Python 值視為不透明對象(最好將其視為編譯時常量),並為其遇到的每組 Python 參數構建一個單獨的 tf.Graph
。有關更多信息,請參閱 tf.function 指南
執行 GenericFunction
將根據參數類型和值選擇並執行適當的 ConcreteFunction
。
要獲得單獨的 ConcreteFunction
,請使用 GenericFunction.get_concrete_function
方法。可以使用與 func
相同的參數調用它並返回 tf.types.experimental.ConcreteFunction
。 ConcreteFunction
由單個 tf.Graph
支持:
@tf.function
def f(x):
return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
ConcreteFunction
s 可以像 GenericFunction
s 一樣執行,但是它們的輸入受限於它們所專門化的類型。
回溯
ConcreteFunctions
是動態構建(跟蹤)的,因為 GenericFunction
是使用新的 TensorFlow 類型或形狀或使用新的 Python 值作為參數調用的。當GenericFunction
建立新的跟蹤時,就說func
被回溯。回溯是tf.function
的常見性能問題,因為它可能比執行已被追蹤的圖要慢得多。盡量減少代碼中的回溯量是理想的。
警告:將 python 標量或列表作為參數傳遞給tf.function
通常會回溯。為避免這種情況,請盡可能將數字參數作為張量傳遞:
@tf.function
def f(x):
return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2) # Slow - compiles new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2)) # Fast - reuses f1
f1 is f2
True
Python 數值參數僅應在它們采用少量不同值時使用,例如神經網絡中的層數等超參數。
輸入簽名
對於張量參數,GenericFunction
為每組唯一的輸入形狀和數據類型創建一個新的 ConcreteFunction
。下麵的示例創建了兩個單獨的 ConcreteFunction
,每個都專門用於不同的形狀:
@tf.function
def f(x):
return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False
可以選擇向tf.function
提供"input signature" 以控製此過程。輸入簽名使用tf.TensorSpec
對象指定函數的每個張量參數的形狀和類型。可以使用更一般的形狀。這可確保僅創建一個 ConcreteFunction
,並將 GenericFunction
限製為指定的形狀和類型。當張量具有動態形狀時,這是一種限製回溯的有效方法。
@tf.function(
input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True
變量隻能創建一次
tf.function
僅允許在第一次調用時創建新的 tf.Variable
對象:
class MyModule(tf.Module):
def __init__(self):
self.v = None
@tf.function
def __call__(self, x):
if self.v is None:
self.v = tf.Variable(tf.ones_like(x))
return self.v * x
通常,建議在 tf.function
之外創建 tf.Variable
。在簡單的情況下,跨越tf.function
邊界的持久狀態可以使用純函數樣式實現,其中狀態由作為參數傳遞並作為返回值返回的tf.Tensor
表示。
對比以下兩種風格:
state = tf.Variable(1)
@tf.function
def f(x):
state.assign_add(x)
f(tf.constant(2)) # Non-pure functional style
state
<tf.Variable ... numpy=3>
state = tf.constant(1)
@tf.function
def f(state, x):
state += x
return state
state = f(state, tf.constant(2)) # Pure functional style
state
<tf.Tensor:... numpy=3>
Python 操作每次跟蹤隻執行一次
func
可能包含混合了純 Python 操作的 TensorFlow 操作。但是,當函數執行時,隻會運行 TensorFlow 操作。 Python 操作僅在跟蹤時運行一次。如果 TensorFlow 操作依賴於 Pyhton 操作的結果,這些結果將被凍結到圖中。
@tf.function
def f(a, b):
print('this runs at trace time; a is', a, 'and b is', b)
return b
f(1, tf.constant(1))
this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(1, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f(2, tf.constant(1))
this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(2, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>
使用類型注釋來提高性能
'experimental_follow_type_hints` 可以與類型注釋一起使用,通過自動將任何 Python 值強製轉換為 tf.Tensor
來減少回溯(默認情況下不會這樣做,除非您使用輸入簽名)。
@tf.function(experimental_follow_type_hints=True)
def f_with_hints(x:tf.Tensor):
print('Tracing')
return x
@tf.function(experimental_follow_type_hints=False)
def f_no_hints(x:tf.Tensor):
print('Tracing')
return x
f_no_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_no_hints(2)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f_with_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_with_hints(2)
<tf.Tensor:shape=(), dtype=int32, numpy=2>
相關用法
- Python tf.feature_column.crossed_column用法及代碼示例
- Python tf.feature_column.sequence_categorical_column_with_identity用法及代碼示例
- Python tf.feature_column.categorical_column_with_vocabulary_list用法及代碼示例
- Python tf.feature_column.categorical_column_with_hash_bucket用法及代碼示例
- Python tf.feature_column.bucketized_column用法及代碼示例
- Python tf.feature_column.categorical_column_with_identity用法及代碼示例
- Python tf.fingerprint用法及代碼示例
- Python tf.feature_column.sequence_numeric_column用法及代碼示例
- Python tf.feature_column.sequence_categorical_column_with_vocabulary_file用法及代碼示例
- Python tf.feature_column.sequence_categorical_column_with_vocabulary_list用法及代碼示例
- Python tf.feature_column.sequence_categorical_column_with_hash_bucket用法及代碼示例
- Python tf.foldl用法及代碼示例
- Python tf.feature_column.shared_embeddings用法及代碼示例
- Python tf.feature_column.categorical_column_with_vocabulary_file用法及代碼示例
- Python tf.feature_column.indicator_column用法及代碼示例
- Python tf.feature_column.weighted_categorical_column用法及代碼示例
- Python tf.foldr用法及代碼示例
- Python tf.feature_column.numeric_column用法及代碼示例
- Python tf.feature_column.embedding_column用法及代碼示例
- Python tf.fill用法及代碼示例
注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。