當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.callbacks.TensorBoard用法及代碼示例


為 TensorBoard 啟用可視化。

繼承自:Callback

用法

tf.keras.callbacks.TensorBoard(
    log_dir='logs', histogram_freq=0, write_graph=True,
    write_images=False, write_steps_per_second=False, update_freq='epoch',
    profile_batch=0, embeddings_freq=0, embeddings_metadata=None, **kwargs
)

參數

  • log_dir 保存要被 TensorBoard 解析的日誌文件的目錄路徑。例如log_dir = os.path.join(working_dir, 'logs') 此目錄不應被任何其他回調重用。
  • histogram_freq 計算模型層的激活和權重直方圖的頻率(以時期為單位)。如果設置為 0,則不會計算直方圖。必須為直方圖可視化指定驗證數據(或拆分)。
  • write_graph 是否在 TensorBoard 中可視化圖形。當 write_graph 設置為 True 時,日誌文件可能會變得非常大。
  • write_images 是否編寫模型權重以在 TensorBoard 中可視化為圖像。
  • write_steps_per_second 是否將每秒的訓練步數記錄到 Tensorboard 中。這支持時代和批量頻率記錄。
  • update_freq 'batch''epoch' 或整數。使用 'batch' 時,在每批之後將損失和指標寫入 TensorBoard。這同樣適用於 'epoch' 。如果使用整數,比如說 1000 ,回調將每 1000 個批次將指標和損失寫入 TensorBoard。請注意,過於頻繁地寫入 TensorBoard 會減慢您的訓練速度。
  • profile_batch 分析批次以采樣計算特征。 profile_batch 必須是非負整數或整數元組。一對正整數表示要分析的批次範圍。默認情況下,分析是禁用的。
  • embeddings_freq 嵌入層將被可視化的頻率(以時期為單位)。如果設置為 0,嵌入將不會被可視化。
  • embeddings_metadata 將嵌入層名稱映射到文件的文件名的字典,在該文件中保存嵌入層的元數據。如果要對所有嵌入層使用相同的元數據文件,則可以傳遞單個文件名。

TensorBoard 是 TensorFlow 提供的可視化工具。

此回調記錄 TensorBoard 的事件,包括:

  • 指標匯總圖
  • 訓練圖可視化
  • 激活直方圖
  • 采樣分析

Model.evaluate 中使用時,除了紀元摘要之外,還會有一個摘要記錄評估指標與編寫的Model.optimizer.iterations。指標名稱將以 evaluation 開頭,其中 Model.optimizer.iterations 是可視化 TensorBoard 中的步驟。

如果你已經使用 pip 安裝了 TensorFlow,你應該能夠從命令行啟動 TensorBoard:

tensorboard --logdir=path_to_your_logs

您可以在此處找到有關 TensorBoard 的更多信息。

例子:

基本用法:

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# Then run the tensorboard command to view the visualizations.

子類模型中的自定義 batch-level 匯總:

class MyModel(tf.keras.Model):

  def build(self, _):
    self.dense = tf.keras.layers.Dense(10)

  def call(self, x):
    outputs = self.dense(x)
    tf.summary.histogram('outputs', outputs)
    return outputs

model = MyModel()
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N batches.
# In addition to any `tf.summary` contained in `Model.call`, metrics added in
# `Model.compile` will be logged every N batches.
tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

函數 API 模型中的自定義 batch-level 摘要:

def my_summary(x):
  tf.summary.histogram('x', x)
  return x

inputs = tf.keras.Input(10)
x = tf.keras.layers.Dense(10)(inputs)
outputs = tf.keras.layers.Lambda(my_summary)(x)
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N batches.
# In addition to any `tf.summary` contained in `Model.call`, metrics added in
# `Model.compile` will be logged every N batches.
tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

分析:

# Profile a single batch, e.g. the 5th batch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

# Profile a range of batches, e.g. from 10 to 20.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=(10,20))
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.callbacks.TensorBoard。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。