当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.callbacks.TensorBoard用法及代码示例


为 TensorBoard 启用可视化。

继承自:Callback

用法

tf.keras.callbacks.TensorBoard(
    log_dir='logs', histogram_freq=0, write_graph=True,
    write_images=False, write_steps_per_second=False, update_freq='epoch',
    profile_batch=0, embeddings_freq=0, embeddings_metadata=None, **kwargs
)

参数

  • log_dir 保存要被 TensorBoard 解析的日志文件的目录路径。例如log_dir = os.path.join(working_dir, 'logs') 此目录不应被任何其他回调重用。
  • histogram_freq 计算模型层的激活和权重直方图的频率(以时期为单位)。如果设置为 0,则不会计算直方图。必须为直方图可视化指定验证数据(或拆分)。
  • write_graph 是否在 TensorBoard 中可视化图形。当 write_graph 设置为 True 时,日志文件可能会变得非常大。
  • write_images 是否编写模型权重以在 TensorBoard 中可视化为图像。
  • write_steps_per_second 是否将每秒的训练步数记录到 Tensorboard 中。这支持时代和批量频率记录。
  • update_freq 'batch''epoch' 或整数。使用 'batch' 时,在每批之后将损失和指标写入 TensorBoard。这同样适用于 'epoch' 。如果使用整数,比如说 1000 ,回调将每 1000 个批次将指标和损失写入 TensorBoard。请注意,过于频繁地写入 TensorBoard 会减慢您的训练速度。
  • profile_batch 分析批次以采样计算特征。 profile_batch 必须是非负整数或整数元组。一对正整数表示要分析的批次范围。默认情况下,分析是禁用的。
  • embeddings_freq 嵌入层将被可视化的频率(以时期为单位)。如果设置为 0,嵌入将不会被可视化。
  • embeddings_metadata 将嵌入层名称映射到文件的文件名的字典,在该文件中保存嵌入层的元数据。如果要对所有嵌入层使用相同的元数据文件,则可以传递单个文件名。

TensorBoard 是 TensorFlow 提供的可视化工具。

此回调记录 TensorBoard 的事件,包括:

  • 指标汇总图
  • 训练图可视化
  • 激活直方图
  • 采样分析

Model.evaluate 中使用时,除了纪元摘要之外,还会有一个摘要记录评估指标与编写的Model.optimizer.iterations。指标名称将以 evaluation 开头,其中 Model.optimizer.iterations 是可视化 TensorBoard 中的步骤。

如果你已经使用 pip 安装了 TensorFlow,你应该能够从命令行启动 TensorBoard:

tensorboard --logdir=path_to_your_logs

您可以在此处找到有关 TensorBoard 的更多信息。

例子:

基本用法:

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# Then run the tensorboard command to view the visualizations.

子类模型中的自定义 batch-level 汇总:

class MyModel(tf.keras.Model):

  def build(self, _):
    self.dense = tf.keras.layers.Dense(10)

  def call(self, x):
    outputs = self.dense(x)
    tf.summary.histogram('outputs', outputs)
    return outputs

model = MyModel()
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N batches.
# In addition to any `tf.summary` contained in `Model.call`, metrics added in
# `Model.compile` will be logged every N batches.
tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

函数 API 模型中的自定义 batch-level 摘要:

def my_summary(x):
  tf.summary.histogram('x', x)
  return x

inputs = tf.keras.Input(10)
x = tf.keras.layers.Dense(10)(inputs)
outputs = tf.keras.layers.Lambda(my_summary)(x)
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', 'mse')

# Make sure to set `update_freq=N` to log a batch-level summary every N batches.
# In addition to any `tf.summary` contained in `Model.call`, metrics added in
# `Model.compile` will be logged every N batches.
tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
model.fit(x_train, y_train, callbacks=[tb_callback])

分析:

# Profile a single batch, e.g. the 5th batch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

# Profile a range of batches, e.g. from 10 to 20.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=(10,20))
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.callbacks.TensorBoard。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。