當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.keras.utils.plot_model用法及代碼示例


將 Keras 模型轉換為點格式並保存到文件中。

用法

tf.keras.utils.plot_model(
    model, to_file='model.png', show_shapes=False, show_dtype=False,
    show_layer_names=True, rankdir='TB', expand_nested=False, dpi=96,
    layer_range=None, show_layer_activations=False
)

參數

  • model 一個 Keras 模型實例
  • to_file 繪圖圖像的文件名。
  • show_shapes 是否顯示形狀信息。
  • show_dtype 是否顯示圖層數據類型。
  • show_layer_names 是否顯示圖層名稱。
  • rankdir rankdir 參數傳遞給 PyDot,一個指定繪圖格式的字符串:'TB' 創建一個垂直繪圖; 'LR' 創建水平圖。
  • expand_nested 是否將嵌套模型擴展為集群。
  • dpi 每英寸點數。
  • layer_range list 的輸入,包含兩個 str 項,即起始層名稱和結束層名稱(均包括在內),指示將為其生成繪圖的層範圍。它還接受正則表達式模式而不是確切的名稱。在這種情況下,開始謂詞將是它與 layer_range[0] 匹配的第一個元素,而結束謂詞將是它與 layer_range[1] 匹配的最後一個元素。默認情況下 None 考慮模型的所有層。請注意,您必須傳遞範圍,以便生成的子圖必須是完整的。
  • show_layer_activations 顯示層激活(僅適用於具有activation 屬性的層)。

拋出

  • ValueError 如果在構建模型之前調用plot_model

返回

  • 如果安裝了 Jupyter,則為 Jupyter notebook Image 對象。這將啟用 in-line 在筆記本中顯示模型圖。

例子:

input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(input)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[input], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.keras.utils.plot_model。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。