当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.debugging.enable_check_numerics用法及代码示例


以即刻(Eager)/图(Graph)统一的方式启用张量数字检查。

用法

tf.debugging.enable_check_numerics(
    stack_height_limit=30, path_length_limit=50
)

参数

  • stack_height_limit 限制为打印堆栈跟踪的高度。仅适用于 tf.function s(图表)中的操作。
  • path_length_limit 限制在打印的堆栈跟踪中包含的文件路径。仅适用于 tf.function s(图表)中的操作。

一旦操作的输出张量包含无穷大或 NaN,数字检查机制将导致任何 TensorFlow 即刻执行或图形执行出错。

这种方法是幂等的。多次调用它与调用一次具有相同的效果。

该方法只在调用它的线程上生效。

当 op 的浮点型输出张量包含任何 Infinity 或 NaN 时,将抛出 tf.errors.InvalidArgumentError,并带有显示以下信息的错误消息:

  • 生成带有错误数值的张量的操作的类型。
  • 张量的数据类型(dtype)。
  • 张量的形状(在即刻执行或图形构建时已知的程度)。
  • 包含图的名称(如果可用)。
  • (仅限图形模式):intra-graph 操作创建的堆栈跟踪,具有 stack-height 限制和 path-length 限制以提高视觉清晰度。属于用户代码(相对于 tensorflow 的内部代码)的堆栈帧用文本箭头 ("->") 突出显示。
  • (仅即刻模式):有多少违规张量的元素分别是 InfinityNaN

启用后,可以使用 tf.debugging.disable_check_numerics() 禁用 check-numerics 机制。

示例用法:

  1. 在执行 tf.function 图期间捕获无穷大:

    import tensorflow as tf
    
    tf.debugging.enable_check_numerics()
    
    @tf.function
    def square_log_x_plus_1(x):
      v = tf.math.log(x + 1)
      return tf.math.square(v)
    
    x = -1.0
    
    # When the following line runs, a function graph will be compiled
    # from the Python function `square_log_x_plus_1()`. Due to the
    # `enable_check_numerics()` call above, the graph will contain
    # numerics checking ops that will run during the function graph's
    # execution. The function call generates an -infinity when the Log
    # (logarithm) op operates on the output tensor of the Add op.
    # The program errors out at this line, printing an error message.
    y = square_log_x_plus_1(x)
    z = -y
  2. 在即刻执行期间捕获 NaN:

    import numpy as np
    import tensorflow as tf
    
    tf.debugging.enable_check_numerics()
    
    x = np.array([[0.0, -1.0], [4.0, 3.0]])
    
    # The following line executes the Sqrt op eagerly. Due to the negative
    # element in the input array, a NaN is generated. Due to the
    # `enable_check_numerics()` call above, the program errors immediately
    # at this line, printing an error message.
    y = tf.math.sqrt(x)
    z = tf.matmul(y, y)

注意:如果您的代码在 TPU 上运行,请务必在调用 tf.debugging.enable_check_numerics() 之前调用 tf.config.set_soft_device_placement(True),因为此 API 在 TPU 上使用自动外部编译。例如:

tf.config.set_soft_device_placement(True)
tf.debugging.enable_check_numerics()

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
strategy = tf.distribute.TPUStrategy(resolver)
with strategy.scope():
  # ...

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.debugging.enable_check_numerics。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。