當前位置: 首頁>>代碼示例 >>用法及示例精選 >>正文


Python tf.compat.v1.train.import_meta_graph用法及代碼示例


重新創建保存在 MetaGraphDef 原型中的圖形。

用法

tf.compat.v1.train.import_meta_graph(
    meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs
)

參數

  • meta_graph_or_file MetaGraphDef 協議緩衝區或包含 MetaGraphDef 的文件名(包括路徑)。
  • clear_devices 導入期間是否清除OperationTensor 的設備字段。
  • import_scope 可選 string 。要添加的名稱範圍。僅在從協議緩衝區初始化時使用。
  • **kwargs 可選的鍵控參數。

返回

  • 一個由saver_defMetaGraphDef或無。

    如果 MetaGraphDef 中不存在變量(即沒有要恢複的變量),則返回 None 值。

拋出

  • RuntimeError 如果在啟用即刻執行的情況下調用。

此函數將 MetaGraphDef 協議緩衝區作為輸入。如果參數是包含 MetaGraphDef 協議緩衝區的文件,它會根據文件內容構造協議緩衝區。然後,該函數將 graph_def 字段中的所有節點添加到當前圖形,重新創建所有集合,並返回從 saver_def 字段構造的保護程序。

export_meta_graph() 結合使用,此函數可用於

  • 將圖形與其他 Python 對象(例如 QueueRunner , Variable)序列化為 MetaGraphDef

  • 從保存的圖表和檢查點重新開始訓練。

  • 從保存的圖形和檢查點運行推理。

...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.compat.v1.add_to_collection('train_op', train_op)
sess = tf.compat.v1.Session()
for step in range(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # Saves checkpoint, which by default also exports a meta_graph
        # named 'my-model-global_step.meta'.
        saver.save(sess, 'my-model', global_step=step)

稍後我們可以從保存的 meta_graph 繼續訓練,而無需從頭開始構建模型。

with tf.Session() as sess:
  new_saver =
  tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() returns a list. In this example we only want
  # the first one.
  train_op = tf.get_collection('train_op')[0]
  for step in range(1000000):
    sess.run(train_op)

注意:從保存的 meta_graph 重新開始訓練僅在設備分配未更改的情況下有效。

例子:

還可以存儲變量、占位符和獨立操作,如下例所示。

# Saving contents and operations.
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.math.multiply(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result)
saver.save(sess, "./model_ex1")

稍後可以恢複此模型並加載內容。

# Restoring variables and running operations.
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0":12.0, "v2:0":3.3})
print(result)

eager模式兼容性

不支持導出/導入元圖。啟用即刻執行時不存在圖表。

相關用法


注:本文由純淨天空篩選整理自tensorflow.org大神的英文原創作品 tf.compat.v1.train.import_meta_graph。非經特殊聲明,原始代碼版權歸原作者所有,本譯文未經允許或授權,請勿轉載或複製。