当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.function用法及代码示例


将函数编译为可调用的 TensorFlow 图。 (不推荐使用的参数)

用法

tf.function(
    func=None, input_signature=None, autograph=True, jit_compile=None,
    experimental_implements=None, experimental_autograph_options=None,
    experimental_relax_shapes=False, experimental_compile=None,
    experimental_follow_type_hints=None
) -> tf.types.experimental.GenericFunction

参数

  • func 要编译的函数。如果 func 为 None,则 tf.function 返回一个可以使用单个参数调用的装饰器 - func 。换句话说, tf.function(input_signature=...)(func) 等价于 tf.function(func, input_signature=...) 。前者可以用作装饰器。
  • input_signature tf.TensorSpec 对象的可能嵌套序列,指定将提供给此函数的张量的形状和数据类型。如果 None ,则为每个推断的输入签名实例化一个单独的函数。如果指定了 input_signature,则 func 的每个输入都必须是 Tensor ,并且 func 不能接受 **kwargs
  • autograph 是否应申请亲笔签名func在跟踪图表之前。 Data-dependent 控制流需要autograph=True.有关详细信息,请参阅tf.function 和 AutoGraph 指南.
  • jit_compile 如果True, 编译函数使用XLA. XLA 执行编译器优化,例如融合,并尝试生成更高效的代码。这可能会大大提高性能。如果设置为True, 整个函数需要由 XLA 编译,或者tf.errors.InvalidArgumentError被抛出。如果None(默认),在 TPU 上运行时使用 XLA 编译函数,在其他设备上运行时通过常规函数执行路径。如果False, 在没有 XLA 编译的情况下执行函数。将此值设置为False在 TPU 上直接运行多设备函数时(例如,两个 TPU 内核、一个 TPU 内核及其主机 CPU)。并非所有函数都可编译,请参阅列表尖角.
  • experimental_implements 如果提供,则包含此实现的"known" 函数的名称。例如"mycompany.my_recurrent_cell"。这作为属性存储在推理函数中,然后可以在处理序列化函数时检测到。看标准化复合操作
    详情。有关使用此属性的示例,请参阅此例子上面的代码会自动检测并替换实现"embedded_matmul" 的函数,并允许 TFLite 替换它自己的实现。例如,一个 tensorflow 用户可以使用这个属性来标记他们的函数也实现了embedded_matmul(也许更有效!)通过使用此参数指定它:@tf.function(experimental_implements="embedded_matmul")这可以指定为函数的字符串名称,也可以指定为与函数名称关联的键值属性列表对应的 NameAttrList。该函数的名称将在 NameAttrList 的 'name' 字段中。要为此函数实现定义正式的 TF 操作,请尝试实验复合TF项目。
  • experimental_autograph_options tf.autograph.experimental.Feature 值的可选元组。
  • experimental_relax_shapes 当为 True 时,tf.function 可能会生成较少的、不太专门针对输入形状的图形。
  • experimental_compile 已弃用 'jit_compile' 的别名。
  • experimental_follow_type_hints 当为 True 时,该函数可以使用来自 func 的类型注释来优化跟踪性能。例如,带有tf.Tensor 注释的参数将自动转换为张量。

返回

抛出

  • ValueError 尝试使用 jit_compile=True 时,但 XLA 支持不可用。

警告:不推荐使用某些参数:(experimental_compile)。它们将在未来的版本中被删除。更新说明:experimental_compile 已弃用,请改用jit_compile

tf.function 构造一个 tf.types.experimental.GenericFunction,它执行由 trace-compiling 在 func 中的 TensorFlow 操作创建的 TensorFlow 图 (tf.Graph)。有关该主题的更多信息,请参阅图表和 tf.function 简介。

有关性能和已知限制的提示,请参阅使用 tf.function 获得更好的性能。

示例用法:

@tf.function
def f(x, y):
  return x ** 2 + y
x = tf.constant([2, 3])
y = tf.constant([3, -2])
f(x, y)
<tf.Tensor:... numpy=array([7, 7], ...)>

trace-compilation 允许在特殊条件下执行非 TensorFlow 操作。通常,只要调用 GenericFunction,就只能保证 TensorFlow 操作运行并创建新结果。

特征

func可以使用data-dependent控制流,包括if , for , whilebreak , continuereturn语句:

@tf.function
def f(x):
  if tf.reduce_sum(x) > 0:
    return x * x
  else:
    return -x // 2
f(tf.constant(-2))
<tf.Tensor:... numpy=1>

func 的闭包可能包括 tf.Tensortf.Variable 对象:

@tf.function
def f():
  return x ** 2 + y
x = tf.constant([-2, -3])
y = tf.Variable([3, -2])
f()
<tf.Tensor:... numpy=array([7, 7], ...)>

func 也可能使用带有副作用的操作,例如 tf.printtf.Variable 和其他:

v = tf.Variable(1)
@tf.function
def f(x):
  for i in tf.range(x):
    v.assign_add(i)
f(3)
v
<tf.Variable ... numpy=4>

重要的:当跟踪 func 时,任何 Python side-effects(附加到列表、使用 print 打印等)只会发生一次。要将side-effects 执行到您的tf.function 中,需要将它们编写为 TF ops:

l = []
@tf.function
def f(x):
  for i in x:
    l.append(i + 1)    # Caution! Will only happen once when tracing
f(tf.constant([1, 2, 3]))
l
[<tf.Tensor ...>]

相反,请使用 TensorFlow 集合,例如 tf.TensorArray

@tf.function
def f(x):
  ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
  for i in range(len(x)):
    ta = ta.write(i, x[i] + 1)
  return ta.stack()
f(tf.constant([1, 2, 3]))
<tf.Tensor:..., numpy=array([2, 3, 4], ...)>

tf.function 创建多态可调用对象

在内部,tf.types.experimental.GenericFunction 可能包含多个 tf.types.experimental.ConcreteFunction ,每个专用于具有不同数据类型或形状的参数,因为 TensorFlow 可以对特定形状、dtypes 和常量参数值的图执行更多优化。 tf.function 将任何纯 Python 值视为不透明对象(最好将其视为编译时常量),并为其遇到的每组 Python 参数构建一个单独的 tf.Graph。有关更多信息,请参阅 tf.function 指南

执行 GenericFunction 将根据参数类型和值选择并执行适当的 ConcreteFunction

要获得单独的 ConcreteFunction ,请使用 GenericFunction.get_concrete_function 方法。可以使用与 func 相同的参数调用它并返回 tf.types.experimental.ConcreteFunctionConcreteFunction 由单个 tf.Graph 支持:

@tf.function
def f(x):
  return x + 1
isinstance(f.get_concrete_function(1).graph, tf.Graph)
True

ConcreteFunction s 可以像 GenericFunction s 一样执行,但是它们的输入受限于它们所专门化的类型。

回溯

ConcreteFunctions 是动态构建(跟踪)的,因为 GenericFunction 是使用新的 TensorFlow 类型或形状或使用新的 Python 值作为参数调用的。当GenericFunction 建立新的跟踪时,就说func 被回溯。回溯是tf.function 的常见性能问题,因为它可能比执行已被追踪的图要慢得多。尽量减少代码中的回溯量是理想的。

警告:将 python 标量或列表作为参数传递给tf.function 通常会回溯。为避免这种情况,请尽可能将数字参数作为张量传递:

@tf.function
def f(x):
  return tf.abs(x)
f1 = f.get_concrete_function(1)
f2 = f.get_concrete_function(2)  # Slow - compiles new graph
f1 is f2
False
f1 = f.get_concrete_function(tf.constant(1))
f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
f1 is f2
True

Python 数值参数仅应在它们采用少量不同值时使用,例如神经网络中的层数等超参数。

输入签名

对于张量参数,GenericFunction 为每组唯一的输入形状和数据类型创建一个新的 ConcreteFunction。下面的示例创建了两个单独的 ConcreteFunction ,每个都专门用于不同的形状:

@tf.function
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False

可以选择向tf.function 提供"input signature" 以控制此过程。输入签名使用tf.TensorSpec 对象指定函数的每个张量参数的形状和类型。可以使用更一般的形状。这可确保仅创建一个 ConcreteFunction,并将 GenericFunction 限制为指定的形状和类型。当张量具有动态形状时,这是一种限制回溯的有效方法。

@tf.function(
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def f(x):
  return x + 1
vector = tf.constant([1.0, 1.0])
matrix = tf.constant([[3.0]])
f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True

变量只能创建一次

tf.function 仅允许在第一次调用时创建新的 tf.Variable 对象:

class MyModule(tf.Module):
  def __init__(self):
    self.v = None

  @tf.function
  def __call__(self, x):
    if self.v is None:
      self.v = tf.Variable(tf.ones_like(x))
    return self.v * x

通常,建议在 tf.function 之外创建 tf.Variable 。在简单的情况下,跨越tf.function 边界的持久状态可以使用纯函数样式实现,其中状态由作为参数传递并作为返回值返回的tf.Tensor 表示。

对比以下两种风格:

state = tf.Variable(1)
@tf.function
def f(x):
  state.assign_add(x)
f(tf.constant(2))  # Non-pure functional style
state
<tf.Variable ... numpy=3>
state = tf.constant(1)
@tf.function
def f(state, x):
  state += x
  return state
state = f(state, tf.constant(2))  # Pure functional style
state
<tf.Tensor:... numpy=3>

Python 操作每次跟踪只执行一次

func 可能包含混合了纯 Python 操作的 TensorFlow 操作。但是,当函数执行时,只会运行 TensorFlow 操作。 Python 操作仅在跟踪时运行一次。如果 TensorFlow 操作依赖于 Pyhton 操作的结果,这些结果将被冻结到图中。

@tf.function
def f(a, b):
  print('this runs at trace time; a is', a, 'and b is', b)
  return b
f(1, tf.constant(1))
this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(1, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f(2, tf.constant(1))
this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32)
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f(2, tf.constant(2))
<tf.Tensor:shape=(), dtype=int32, numpy=2>

使用类型注释来提高性能

'experimental_follow_type_hints` 可以与类型注释一起使用,通过自动将任何 Python 值强制转换为 tf.Tensor 来减少回溯(默认情况下不会这样做,除非您使用输入签名)。

@tf.function(experimental_follow_type_hints=True)
def f_with_hints(x:tf.Tensor):
  print('Tracing')
  return x
@tf.function(experimental_follow_type_hints=False)
def f_no_hints(x:tf.Tensor):
  print('Tracing')
  return x
f_no_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_no_hints(2)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=2>
f_with_hints(1)
Tracing
<tf.Tensor:shape=(), dtype=int32, numpy=1>
f_with_hints(2)
<tf.Tensor:shape=(), dtype=int32, numpy=2>

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.function。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。