當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.debugging.enable_check_numerics用法及代碼示例


以即刻(Eager)/圖(Graph)統一的方式啟用張量數字檢查。

用法

tf.debugging.enable_check_numerics(
    stack_height_limit=30, path_length_limit=50
)

參數

  • stack_height_limit 限製為打印堆棧跟蹤的高度。僅適用於 tf.function s(圖表)中的操作。
  • path_length_limit 限製在打印的堆棧跟蹤中包含的文件路徑。僅適用於 tf.function s(圖表)中的操作。

一旦操作的輸出張量包含無窮大或 NaN,數字檢查機製將導致任何 TensorFlow 即刻執行或圖形執行出錯。

這種方法是冪等的。多次調用它與調用一次具有相同的效果。

該方法隻在調用它的線程上生效。

當 op 的浮點型輸出張量包含任何 Infinity 或 NaN 時,將拋出 tf.errors.InvalidArgumentError,並帶有顯示以下信息的錯誤消息:

  • 生成帶有錯誤數值的張量的操作的類型。
  • 張量的數據類型(dtype)。
  • 張量的形狀(在即刻執行或圖形構建時已知的程度)。
  • 包含圖的名稱(如果可用)。
  • (僅限圖形模式):intra-graph 操作創建的堆棧跟蹤,具有 stack-height 限製和 path-length 限製以提高視覺清晰度。屬於用戶代碼(相對於 tensorflow 的內部代碼)的堆棧幀用文本箭頭 ("->") 突出顯示。
  • (僅即刻模式):有多少違規張量的元素分別是 InfinityNaN

啟用後,可以使用 tf.debugging.disable_check_numerics() 禁用 check-numerics 機製。

示例用法:

  1. 在執行 tf.function 圖期間捕獲無窮大:

    import tensorflow as tf
    
    tf.debugging.enable_check_numerics()
    
    @tf.function
    def square_log_x_plus_1(x):
      v = tf.math.log(x + 1)
      return tf.math.square(v)
    
    x = -1.0
    
    # When the following line runs, a function graph will be compiled
    # from the Python function `square_log_x_plus_1()`. Due to the
    # `enable_check_numerics()` call above, the graph will contain
    # numerics checking ops that will run during the function graph's
    # execution. The function call generates an -infinity when the Log
    # (logarithm) op operates on the output tensor of the Add op.
    # The program errors out at this line, printing an error message.
    y = square_log_x_plus_1(x)
    z = -y
  2. 在即刻執行期間捕獲 NaN:

    import numpy as np
    import tensorflow as tf
    
    tf.debugging.enable_check_numerics()
    
    x = np.array([[0.0, -1.0], [4.0, 3.0]])
    
    # The following line executes the Sqrt op eagerly. Due to the negative
    # element in the input array, a NaN is generated. Due to the
    # `enable_check_numerics()` call above, the program errors immediately
    # at this line, printing an error message.
    y = tf.math.sqrt(x)
    z = tf.matmul(y, y)

注意:如果您的代碼在 TPU 上運行,請務必在調用 tf.debugging.enable_check_numerics() 之前調用 tf.config.set_soft_device_placement(True),因為此 API 在 TPU 上使用自動外部編譯。例如:

tf.config.set_soft_device_placement(True)
tf.debugging.enable_check_numerics()

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
strategy = tf.distribute.TPUStrategy(resolver)
with strategy.scope():
  # ...

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.debugging.enable_check_numerics。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。