当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.compat.v1.py_func用法及代码示例


包装一个 python 函数并将其用作 TensorFlow 操作。

用法

tf.compat.v1.py_func(
    func, inp, Tout, stateful=True, name=None
)

参数

  • func 一个 Python 函数,它接受 ndarray 对象作为参数并返回 ndarray 对象列表(或单个 ndarray )。此函数必须接受与 inp 中的张量一样多的参数,并且这些参数类型将匹配 inp 中相应的 tf.Tensor 对象。返回的 ndarray 必须与定义的数量和类型相匹配 Tout 。重要提示:func 的输入和输出 numpy ndarray 不保证是副本。在某些情况下,它们的底层内存将与相应的 TensorFlow 张量共享。在没有显式 (np.) 复制的情况下,在 python 数据结构中就地修改或存储 func 输入或返回值可能会产生不确定的后果。
  • inp Tensor 对象的列表。
  • Tout tensorflow 数据类型的列表或元组或单个 tensorflow 数据类型(如果只有一个),指示 func 返回的内容。
  • stateful (布尔值。)如果为 True,则该函数应被视为有状态。如果一个函数是无状态的,当给定相同的输入时,它将返回相同的输出并且没有可观察到的副作用。诸如公共子表达式消除之类的优化仅在无状态操作上执行。
  • name 操作的名称(可选)。

返回

  • Tensor 的列表或 func 计算的单个 Tensor

迁移到 TF2

警告:这个 API 是为 TensorFlow v1 设计的。继续阅读有关如何从该 API 迁移到本机 TensorFlow v2 等效项的详细信息。见TensorFlow v1 到 TensorFlow v2 迁移指南有关如何迁移其余代码的说明。

此名称在 TF2 中已弃用并删除,但 tf.numpy_function 是 near-exact 替换,只需删除 stateful 参数(所有 tf.numpy_function 调用都被视为有状态)。它与即刻执行和 tf.function 兼容。

tf.py_function 是一个接近但不是完全替代品,将 TensorFlow 张量传递给包装函数而不是 NumPy 数组,后者提供梯度并可以利用加速器。

前:

def fn_using_numpy(x):
  x[0] = 0.
  return x
tf.compat.v1.py_func(fn_using_numpy, inp=[tf.constant([1., 2.])],
    Tout=tf.float32, stateful=False)
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

后:

tf.numpy_function(fn_using_numpy, inp=[tf.constant([1., 2.])],
    Tout=tf.float32)
<tf.Tensor:shape=(2,), dtype=float32, numpy=array([0., 2.], dtype=float32)>

给定一个 python 函数 func ,它将 numpy 数组作为其参数并返回 numpy 数组作为其输出,将此函数包装为 TensorFlow 图中的一个操作。以下代码段构建了一个简单的 TensorFlow 图,该图调用 np.sinh() NumPy 函数作为图中的操作:

def my_func(x):
  # x will be a numpy array with the contents of the placeholder below
  return np.sinh(x)
input = tf.compat.v1.placeholder(tf.float32)
y = tf.compat.v1.py_func(my_func, [input], tf.float32)

注意:tf.compat.v1.py_func() 操作具有以下已知限制:

  • 函数体(即 func )不会在 GraphDef 中序列化。因此,如果您需要序列化模型并在不同的环境中恢复它,则不应使用此函数。

  • 该操作必须在与调用 tf.compat.v1.py_func() 的 Python 程序相同的地址空间中运行。如果您使用分布式 TensorFlow,则必须在与调用 tf.compat.v1.py_func() 的程序相同的进程中运行 tf.distribute.Server,并且必须将创建的操作固定到该服务器中的设备(例如,使用 with tf.device(): )。

注意:它产生未知形状和等级的张量,因为形状推断不适用于任意 Python 代码。如果您需要形状,则需要根据静态可用信息进行设置。

例如:

import tensorflow as tf
  import numpy as np

  def make_synthetic_data(i):
      return np.cast[np.uint8](i) * np.ones([20,256,256,3],
              dtype=np.float32) / 10.

  def preprocess_fn(i):
      ones = tf.py_function(make_synthetic_data,[i],tf.float32)
      ones.set_shape(tf.TensorShape([None, None, None, None]))
      ones = tf.image.resize(ones, [224,224])
      return ones

  ds = tf.data.Dataset.range(10)
  ds = ds.map(preprocess_fn)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.compat.v1.py_func。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。