当前位置: 首页>>代码示例 >>用法及示例精选 >>正文


Python tf.keras.utils.plot_model用法及代码示例


将 Keras 模型转换为点格式并保存到文件中。

用法

tf.keras.utils.plot_model(
    model, to_file='model.png', show_shapes=False, show_dtype=False,
    show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96,
    layer_range=None, show_layer_activations=False
)

参数

  • model 一个 Keras 模型实例
  • to_file 绘图图像的文件名。
  • show_shapes 是否显示形状信息。
  • show_dtype 是否显示图层数据类型。
  • show_layer_names 是否显示图层名称。
  • rankdir rankdir 参数传递给 PyDot,一个指定绘图格式的字符串:'TB' 创建一个垂直绘图; 'LR' 创建水平图。
  • expand_nested 是否将嵌套模型扩展为集群。
  • dpi 每英寸点数。
  • layer_range list 的输入,包含两个 str 项,即起始层名称和结束层名称(均包括在内),指示将为其生成绘图的层范围。它还接受正则表达式模式而不是确切的名称。在这种情况下,开始谓词将是它与 layer_range[0] 匹配的第一个元素,而结束谓词将是它与 layer_range[1] 匹配的最后一个元素。默认情况下 None 考虑模型的所有层。请注意,您必须传递范围,以便生成的子图必须是完整的。
  • show_layer_activations 显示层激活(仅适用于具有activation 属性的层)。

抛出

  • ValueError 如果在构建模型之前调用plot_model

返回

  • 如果安装了 Jupyter,则为 Jupyter notebook Image 对象。这将启用 in-line 在笔记本中显示模型图。

例子:

input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(input)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[input], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

相关用法


注:本文由纯净天空筛选整理自tensorflow.org大神的英文原创作品 tf.keras.utils.plot_model。非经特殊声明,原始代码版权归原作者所有,本译文未经允许或授权,请勿转载或复制。