當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.function用法及代碼示例


將函數編譯為可調用的 TensorFlow 圖。 (不推薦使用的參數)

用法

tf.function(
    func=None, input_signature=None, autograph=True, jit_compile=None,
    experimental_implements=None, experimental_autograph_options=None,
    experimental_relax_shapes=False, experimental_compile=None,
    experimental_follow_type_hints=None
) -> tf.types.experimental.GenericFunction

參數

  • func 要編譯的函數。如果 func 為 None,則 tf.function 返回一個可以使用單個參數調用的裝飾器 - func 。換句話說, tf.function(input_signature=...)(func) 等價於 tf.function(func, input_signature=...) 。前者可以用作裝飾器。
  • input_signature tf.TensorSpec 對象的可能嵌套序列,指定將提供給此函數的張量的形狀和數據類型。如果 None ,則為每個推斷的輸入簽名實例化一個單獨的函數。如果指定了 input_signature,則 func 的每個輸入都必須是 Tensor ,並且 func 不能接受 **kwargs
  • autograph 是否應申請親筆簽名func在跟蹤圖表之前。 Data-dependent 控製流需要autograph=True.有關詳細信息,請參閱tf.function 和 AutoGraph 指南.
  • jit_compile 如果True, 編譯函數使用XLA. XLA 執行編譯器優化,例如融合,並嘗試生成更高效的代碼。這可能會大大提高性能。如果設置為True, 整個函數需要由 XLA 編譯,或者tf.errors.InvalidArgumentError被拋出。如果None(默認),在 TPU 上運行時使用 XLA 編譯函數,在其他設備上運行時通過常規函數執行路徑。如果False, 在沒有 XLA 編譯的情況下執行函數。將此值設置為False在 TPU 上直接運行多設備函數時(例如,兩個 TPU 內核、一個 TPU 內核及其主機 CPU)。並非所有函數都可編譯,請參閱列表尖角.
  • experimental_implements 如果提供,則包含此實現的"known" 函數的名稱。例如"mycompany.my_recurrent_cell"。這作為屬性存儲在推理函數中,然後可以在處理序列化函數時檢測到。看標準化複合操作
    詳情。有關使用此屬性的示例,請參閱此例子上麵的代碼會自動檢測並替換實現"embedded_matmul" 的函數,並允許 TFLite 替換它自己的實現。例如,一個 tensorflow 用戶可以使用這個屬性來標記他們的函數也實現了embedded_matmul(也許更有效!)通過使用此參數指定它:@tf.function(experimental_implements="embedded_matmul")這可以指定為函數的字符串名稱,也可以指定為與函數名稱關聯的鍵值屬性列表對應的 NameAttrList。該函數的名稱將在 NameAttrList 的 'name' 字段中。要為此函數實現定義正式的 TF 操作,請嘗試實驗複合TF項目。
  • experimental_autograph_options tf.autograph.experimental.Feature 值的可選元組。
  • experimental_relax_shapes 當為 True 時,tf.function 可能會生成較少的、不太專門針對輸入形狀的圖形。
  • experimental_compile 已棄用 'jit_compile' 的別名。
  • experimental_follow_type_hints 當為 True 時,該函數可以使用來自 func 的類型注釋來優化跟蹤性能。例如,帶有tf.Tensor 注釋的參數將自動轉換為張量。

返回

拋出

  • ValueError 嘗試使用 jit_compile=True 時,但 XLA 支持不可用。

警告:不推薦使用某些參數:(experimental_compile)。它們將在未來的版本中被刪除。更新說明:experimental_compile 已棄用,請改用jit_compile

tf.function 構造一個 tf.types.experimental.GenericFunction,它執行由 trace-compiling 在 func 中的 TensorFlow 操作創建的 TensorFlow 圖 (tf.Graph)。有關該主題的更多信息,請參閱圖表和 tf.function 簡介。

有關性能和已知限製的提示,請參閱使用 tf.function 獲得更好的性能。

示例用法:

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor:... numpy=array([7, 7], ...)>

trace-compilation 允許在特殊條件下執行非 TensorFlow 操作。通常,隻要調用 GenericFunction,就隻能保證 TensorFlow 操作運行並創建新結果。

特征

func可以使用data-dependent控製流,包括if , for , whilebreak , continuereturn語句:

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor:... numpy=1>

func 的閉包可能包括 tf.Tensortf.Variable 對象:

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor:... numpy=array([7, 7], ...)>

func 也可能使用帶有副作用的操作,例如 tf.printtf.Variable 和其他:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>

重要的:當跟蹤 func 時,任何 Python side-effects(附加到列表、使用 print 打印等)隻會發生一次。要將side-effects 執行到您的tf.function 中,需要將它們編寫為 TF ops:

l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]

相反,請使用 TensorFlow 集合,例如 tf.TensorArray

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor:..., numpy=array([2, 3, 4], ...)>

tf.function 創建多態可調用對象

在內部,tf.types.experimental.GenericFunction 可能包含多個 tf.types.experimental.ConcreteFunction ,每個專用於具有不同數據類型或形狀的參數,因為 TensorFlow 可以對特定形狀、dtypes 和常量參數值的圖執行更多優化。 tf.function 將任何純 Python 值視為不透明對象(最好將其視為編譯時常量),並為其遇到的每組 Python 參數構建一個單獨的 tf.Graph。有關更多信息,請參閱 tf.function 指南

執行 GenericFunction 將根據參數類型和值選擇並執行適當的 ConcreteFunction

要獲得單獨的 ConcreteFunction ,請使用 GenericFunction.get_concrete_function 方法。可以使用與 func 相同的參數調用它並返回 tf.types.experimental.ConcreteFunctionConcreteFunction 由單個 tf.Graph 支持:

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True

ConcreteFunction s 可以像 GenericFunction s 一樣執行,但是它們的輸入受限於它們所專門化的類型。

回溯

ConcreteFunctions 是動態構建(跟蹤)的,因為 GenericFunction 是使用新的 TensorFlow 類型或形狀或使用新的 Python 值作為參數調用的。當GenericFunction 建立新的跟蹤時,就說func 被回溯。回溯是tf.function 的常見性能問題,因為它可能比執行已被追蹤的圖要慢得多。盡量減少代碼中的回溯量是理想的。

警告:將 python 標量或列表作為參數傳遞給tf.function 通常會回溯。為避免這種情況,請盡可能將數字參數作為張量傳遞:

@tf.function
def f(x):
  return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2)  # Slow - compiles new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
f1 is f2
True

Python 數值參數僅應在它們采用少量不同值時使用,例如神經網絡中的層數等超參數。

輸入簽名

對於張量參數,GenericFunction 為每組唯一的輸入形狀和數據類型創建一個新的 ConcreteFunction。下麵的示例創建了兩個單獨的 ConcreteFunction ,每個都專門用於不同的形狀:

@tf.function
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False

可以選擇向tf.function 提供"input signature" 以控製此過程。輸入簽名使用tf.TensorSpec 對象指定函數的每個張量參數的形狀和類型。可以使用更一般的形狀。這可確保僅創建一個 ConcreteFunction,並將 GenericFunction 限製為指定的形狀和類型。當張量具有動態形狀時,這是一種限製回溯的有效方法。

@tf.function(
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True

變量隻能創建一次

tf.function 僅允許在第一次調用時創建新的 tf.Variable 對象:

class MyModule(tf.Module):
  def __init__(self):
    self.v = None

  @tf.function
  def __call__(self, x):
    if self.v is None:
      self.v = tf.Variable(tf.ones_like(x))
    return self.v * x

通常,建議在 tf.function 之外創建 tf.Variable 。在簡單的情況下,跨越tf.function 邊界的持久狀態可以使用純函數樣式實現,其中狀態由作為參數傳遞並作為返回值返回的tf.Tensor 表示。

對比以下兩種風格:

state = tf.Variable(1)
@tf.function
def f(x):
  state.assign_add(x)
f(tf.constant(2))  # Non-pure functional style
state
<tf.Variable ... numpy=3>
state = tf.constant(1)
@tf.function
def f(state, x):
  state += x
  return state
state = f(state, tf.constant(2))  # Pure functional style
state
<tf.Tensor:... numpy=3>

Python 操作每次跟蹤隻執行一次

func 可能包含混合了純 Python 操作的 TensorFlow 操作。但是,當函數執行時,隻會運行 TensorFlow 操作。 Python 操作僅在跟蹤時運行一次。如果 TensorFlow 操作依賴於 Pyhton 操作的結果,這些結果將被凍結到圖中。

@tf.function
def f(a, b):
  print('this runs at trace time; a is', a, 'and b is', b)
  return b
f(1, tf.constant(1))
this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(1, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f(2, tf.constant(1))
this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(2, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>

使用類型注釋來提高性能

'experimental_follow_type_hints` 可以與類型注釋一起使用,通過自動將任何 Python 值強製轉換為 tf.Tensor 來減少回溯(默認情況下不會這樣做,除非您使用輸入簽名)。

@tf.function(experimental_follow_type_hints=True)
def f_with_hints(x:tf.Tensor):
  print('Tracing')
  return x
@tf.function(experimental_follow_type_hints=False)
def f_no_hints(x:tf.Tensor):
  print('Tracing')
  return x
f_no_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_no_hints(2)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f_with_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_with_hints(2)
<tf.Tensor:shape=(), dtype=int32, numpy=2>

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.function。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。