将函数编译为可调用的 TensorFlow 图。 (不推荐使用的参数)
用法
tf.function(
func=None, input_signature=None, autograph=True, jit_compile=None,
experimental_implements=None, experimental_autograph_options=None,
experimental_relax_shapes=False, experimental_compile=None,
experimental_follow_type_hints=None
) -> tf.types.experimental.GenericFunction
参数
-
func
要编译的函数。如果func
为 None,则tf.function
返回一个可以使用单个参数调用的装饰器 -func
。换句话说,tf.function(input_signature=...)(func)
等价于tf.function(func, input_signature=...)
。前者可以用作装饰器。 -
input_signature
tf.TensorSpec
对象的可能嵌套序列,指定将提供给此函数的张量的形状和数据类型。如果None
,则为每个推断的输入签名实例化一个单独的函数。如果指定了 input_signature,则func
的每个输入都必须是Tensor
,并且func
不能接受**kwargs
。 -
autograph
是否应申请亲笔签名func
在跟踪图表之前。 Data-dependent 控制流需要autograph=True
.有关详细信息,请参阅tf.function 和 AutoGraph 指南. -
jit_compile
如果True
, 编译函数使用XLA. XLA 执行编译器优化,例如融合,并尝试生成更高效的代码。这可能会大大提高性能。如果设置为True
, 整个函数需要由 XLA 编译,或者tf.errors.InvalidArgumentError被抛出。如果None
(默认),在 TPU 上运行时使用 XLA 编译函数,在其他设备上运行时通过常规函数执行路径。如果False
, 在没有 XLA 编译的情况下执行函数。将此值设置为False
在 TPU 上直接运行多设备函数时(例如,两个 TPU 内核、一个 TPU 内核及其主机 CPU)。并非所有函数都可编译,请参阅列表尖角. -
experimental_implements
如果提供,则包含此实现的"known" 函数的名称。例如"mycompany.my_recurrent_cell"。这作为属性存储在推理函数中,然后可以在处理序列化函数时检测到。看标准化复合操作
详情。有关使用此属性的示例,请参阅此例子上面的代码会自动检测并替换实现"embedded_matmul" 的函数,并允许 TFLite 替换它自己的实现。例如,一个 tensorflow 用户可以使用这个属性来标记他们的函数也实现了embedded_matmul
(也许更有效!)通过使用此参数指定它:@tf.function(experimental_implements="embedded_matmul")
这可以指定为函数的字符串名称,也可以指定为与函数名称关联的键值属性列表对应的 NameAttrList。该函数的名称将在 NameAttrList 的 'name' 字段中。要为此函数实现定义正式的 TF 操作,请尝试实验复合TF项目。 -
experimental_autograph_options
tf.autograph.experimental.Feature
值的可选元组。 -
experimental_relax_shapes
当为 True 时,tf.function
可能会生成较少的、不太专门针对输入形状的图形。 -
experimental_compile
已弃用 'jit_compile' 的别名。 -
experimental_follow_type_hints
当为 True 时,该函数可以使用来自func
的类型注释来优化跟踪性能。例如,带有tf.Tensor
注释的参数将自动转换为张量。
返回
-
如果
func
不是 None,则返回tf.types.experimental.GenericFunction
。如果func
为 None,则返回一个装饰器,当使用单个func
参数调用该装饰器时,返回一个tf.types.experimental.GenericFunction
。
抛出
-
ValueError
尝试使用jit_compile=True
时,但 XLA 支持不可用。
警告:不推荐使用某些参数:(experimental_compile)
。它们将在未来的版本中被删除。更新说明:experimental_compile 已弃用,请改用jit_compile
tf.function
构造一个 tf.types.experimental.GenericFunction
,它执行由 trace-compiling 在 func
中的 TensorFlow 操作创建的 TensorFlow 图 (tf.Graph
)。有关该主题的更多信息,请参阅图表和 tf.function 简介。
有关性能和已知限制的提示,请参阅使用 tf.function 获得更好的性能。
示例用法:
@tf.function
def f(x, y):
return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor:... numpy=array([7, 7], ...)>
trace-compilation 允许在特殊条件下执行非 TensorFlow 操作。通常,只要调用 GenericFunction
,就只能保证 TensorFlow 操作运行并创建新结果。
特征
func
可以使用data-dependent控制流,包括if
, for
, while
break
, continue
和return
语句:
@tf.function
def f(x):
if tf.reduce_sum(x) > 0:
return x * x
else:
return -x // 2
f(tf.constant(-2))
<tf.Tensor:... numpy=1>
func
的闭包可能包括 tf.Tensor
和 tf.Variable
对象:
@tf.function
def f():
return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor:... numpy=array([7, 7], ...)>
func
也可能使用带有副作用的操作,例如 tf.print
、 tf.Variable
和其他:
v = tf.Variable(1)
@tf.function
def f(x):
for i in tf.range(x):
v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>
重要的:当跟踪 func
时,任何 Python side-effects(附加到列表、使用 print
打印等)只会发生一次。要将side-effects 执行到您的tf.function
中,需要将它们编写为 TF ops:
l = []
@tf.function
def f(x):
for i in x:
l.append(i + 1) # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]
相反,请使用 TensorFlow 集合,例如 tf.TensorArray
:
@tf.function
def f(x):
ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
for i in range(len(x)):
ta = ta.write(i, x[i] + 1)
return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor:..., numpy=array([2, 3, 4], ...)>
tf.function
创建多态可调用对象
在内部,tf.types.experimental.GenericFunction
可能包含多个 tf.types.experimental.ConcreteFunction
,每个专用于具有不同数据类型或形状的参数,因为 TensorFlow 可以对特定形状、dtypes 和常量参数值的图执行更多优化。 tf.function
将任何纯 Python 值视为不透明对象(最好将其视为编译时常量),并为其遇到的每组 Python 参数构建一个单独的 tf.Graph
。有关更多信息,请参阅 tf.function 指南
执行 GenericFunction
将根据参数类型和值选择并执行适当的 ConcreteFunction
。
要获得单独的 ConcreteFunction
,请使用 GenericFunction.get_concrete_function
方法。可以使用与 func
相同的参数调用它并返回 tf.types.experimental.ConcreteFunction
。 ConcreteFunction
由单个 tf.Graph
支持:
@tf.function
def f(x):
return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
ConcreteFunction
s 可以像 GenericFunction
s 一样执行,但是它们的输入受限于它们所专门化的类型。
回溯
ConcreteFunctions
是动态构建(跟踪)的,因为 GenericFunction
是使用新的 TensorFlow 类型或形状或使用新的 Python 值作为参数调用的。当GenericFunction
建立新的跟踪时,就说func
被回溯。回溯是tf.function
的常见性能问题,因为它可能比执行已被追踪的图要慢得多。尽量减少代码中的回溯量是理想的。
警告:将 python 标量或列表作为参数传递给tf.function
通常会回溯。为避免这种情况,请尽可能将数字参数作为张量传递:
@tf.function
def f(x):
return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2) # Slow - compiles new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2)) # Fast - reuses f1
f1 is f2
True
Python 数值参数仅应在它们采用少量不同值时使用,例如神经网络中的层数等超参数。
输入签名
对于张量参数,GenericFunction
为每组唯一的输入形状和数据类型创建一个新的 ConcreteFunction
。下面的示例创建了两个单独的 ConcreteFunction
,每个都专门用于不同的形状:
@tf.function
def f(x):
return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False
可以选择向tf.function
提供"input signature" 以控制此过程。输入签名使用tf.TensorSpec
对象指定函数的每个张量参数的形状和类型。可以使用更一般的形状。这可确保仅创建一个 ConcreteFunction
,并将 GenericFunction
限制为指定的形状和类型。当张量具有动态形状时,这是一种限制回溯的有效方法。
@tf.function(
input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True
变量只能创建一次
tf.function
仅允许在第一次调用时创建新的 tf.Variable
对象:
class MyModule(tf.Module):
def __init__(self):
self.v = None
@tf.function
def __call__(self, x):
if self.v is None:
self.v = tf.Variable(tf.ones_like(x))
return self.v * x
通常,建议在 tf.function
之外创建 tf.Variable
。在简单的情况下,跨越tf.function
边界的持久状态可以使用纯函数样式实现,其中状态由作为参数传递并作为返回值返回的tf.Tensor
表示。
对比以下两种风格:
state = tf.Variable(1)
@tf.function
def f(x):
state.assign_add(x)
f(tf.constant(2)) # Non-pure functional style
state
<tf.Variable ... numpy=3>
state = tf.constant(1)
@tf.function
def f(state, x):
state += x
return state
state = f(state, tf.constant(2)) # Pure functional style
state
<tf.Tensor:... numpy=3>
Python 操作每次跟踪只执行一次
func
可能包含混合了纯 Python 操作的 TensorFlow 操作。但是,当函数执行时,只会运行 TensorFlow 操作。 Python 操作仅在跟踪时运行一次。如果 TensorFlow 操作依赖于 Pyhton 操作的结果,这些结果将被冻结到图中。
@tf.function
def f(a, b):
print('this runs at trace time; a is', a, 'and b is', b)
return b
f(1, tf.constant(1))
this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(1, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f(2, tf.constant(1))
this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(2, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>
使用类型注释来提高性能
'experimental_follow_type_hints` 可以与类型注释一起使用,通过自动将任何 Python 值强制转换为 tf.Tensor
来减少回溯(默认情况下不会这样做,除非您使用输入签名)。
@tf.function(experimental_follow_type_hints=True)
def f_with_hints(x:tf.Tensor):
print('Tracing')
return x
@tf.function(experimental_follow_type_hints=False)
def f_no_hints(x:tf.Tensor):
print('Tracing')
return x
f_no_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_no_hints(2)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f_with_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_with_hints(2)
<tf.Tensor:shape=(), dtype=int32, numpy=2>
相关用法
- Python tf.feature_column.crossed_column用法及代码示例
- Python tf.feature_column.sequence_categorical_column_with_identity用法及代码示例
- Python tf.feature_column.categorical_column_with_vocabulary_list用法及代码示例
- Python tf.feature_column.categorical_column_with_hash_bucket用法及代码示例
- Python tf.feature_column.bucketized_column用法及代码示例
- Python tf.feature_column.categorical_column_with_identity用法及代码示例
- Python tf.fingerprint用法及代码示例
- Python tf.feature_column.sequence_numeric_column用法及代码示例
- Python tf.feature_column.sequence_categorical_column_with_vocabulary_file用法及代码示例
- Python tf.feature_column.sequence_categorical_column_with_vocabulary_list用法及代码示例
- Python tf.feature_column.sequence_categorical_column_with_hash_bucket用法及代码示例
- Python tf.foldl用法及代码示例
- Python tf.feature_column.shared_embeddings用法及代码示例
- Python tf.feature_column.categorical_column_with_vocabulary_file用法及代码示例
- Python tf.feature_column.indicator_column用法及代码示例
- Python tf.feature_column.weighted_categorical_column用法及代码示例
- Python tf.foldr用法及代码示例
- Python tf.feature_column.numeric_column用法及代码示例
- Python tf.feature_column.embedding_column用法及代码示例
- Python tf.fill用法及代码示例
注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。