當前位置: 首頁>>算法&結構>>正文


Tensorflow:如何保存/恢複模型?匯總整理

在Tensorflow中訓練一個模型之後:

  1. 如何保存訓練得到的模型?

  2. 如何恢複(重新加載)這個保存的模型?

最佳解決辦法

為保存和恢複模型添加更多細節功能,下麵的答案在持續改進中。

對Tensorflow版本0.11以及之後的版本:

保存模型:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

恢複模型(重新加載模型):

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

想了解更多信息可以參考:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

次佳解決辦法

對於TensorFlow版本0.11.0RC1以及之後的版本,可以直接通過調用tf.train.export_meta_graphtf.train.import_meta_graph(根據https://www.tensorflow.org/programmers_guide/meta_graph)保存和恢複模型

保存模式:

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

恢複模式:

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

第三種解決辦法

對於TensorFlow版本< 0.11.0RC1:

保存的檢查點包含模型中的Variable們的值,而不是模型/圖形本身,這意味著恢複檢查點時對於圖形應該一樣。

下麵是一個線性回歸的例子,其中有一個保存變量檢查點的訓練循環和一個評估部分,它將恢複在之前的運行中保存的變量並計算預測結果。當然,也可以恢複變量並繼續進行訓練。

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

以下是關於Variable的文檔docs,其中包括保存和恢複。這裏是關於Saver的文檔docs

第四種辦法

模型有兩部分,第一部分:模型定義,由Supervisor作為模型目錄中的graph.pbtxt保存;第二部分:張量的數值,保存到model.ckpt-1003418等檢查點文件中。

可以使用tf.import_graph_def恢複模型定義,並使用Saver恢複權重。

然而,Saver使用綁定到模型Graph的特殊集合保存變量列表,並且該集合不是使用import_graph_def初始化的,所以不能一起使用這兩個(未來會修複這個問題)。目前,還必須手動構建具有相同節點名稱的圖,並使用Saver將權重加載到其中。

(或者,您可以使用import_graph_def,手動創建變量,並為每個變量使用tf.add_to_collection(tf.GraphKeys.VARIABLES, variable),然後使用Saver)

第五種辦法

也可以采取更簡單的方法:

步驟1 – 初始化所有變量

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

步驟2 – 將模型Saver中的列表

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

步驟3 – 恢複模型(重新加載模型)

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

步驟4 – 檢查變量

W1 = session.run(W1)
print(W1)

當在不同的python實例中運行時,使用

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

第六種辦法

可以通過導入Graph,手動創建變量,然後使用保護程序,從graph_def和檢查點中進行恢複。

實現的代碼如下:

鏈接:https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(這當然是一種hack的方式,並不能保證這樣保存的模型在以後版本的TensorFlow中保持可讀。)

第七種辦法

如果是一個內部保存的模型,那麽隻需為所有變量指定恢複器即可

restorer = tf.train.Saver(tf.all_variables())

並使用它來恢複當前會話中的變量:

restorer.restore(self._sess, model_file)

對於外部模型,需要指定從外部變量名稱到本地變量名稱的映射。可以使用該命令查看模型變量名稱

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

inspect_checkpoint.py腳本可以在Tensorflow源的’./tensorflow/python/tools’文件夾中找到。

要指定映射,可以使用Tensorflow-Worklab,它包含一組類和腳本來訓練和重新訓練不同的模型。還包括一個重新訓練ResNet模型的例子,位於這裏

第八種辦法

在大多數情況下,使用tf.train.Saver從磁盤保存和恢複是最好的選擇:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

還可以保存/恢複graph結構(有關詳細信息,請參閱MetaGraph documentation)。默認情況下,Saver將graph結構保存到.meta文件中。可以調用import_meta_graph()來恢複它。恢複graph結構並返回一個可用於恢複模型狀態的Saver

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

但是,有些情況需要更快的速度。例如,如果要實現早期停止,則希望在訓練期間(如驗證集中測量)每次改進模型時保存檢查點,那麽如果某段時間內沒有進展,則要回滾到最佳模型。如果將模型保存到磁盤上,並且每次都有提升的情況下,這將大大減慢訓練速度。訣竅是將變量狀態保存到內存中,然後稍後恢複它們:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

快速說明:創建變量X時,TensorFlow會自動創建一個賦值操作X/Assign來設置變量的初始值。我們隻需使用這些現有的賦值操作,而不是創建占位符和額外的賦值操作(這隻會使graph變亂)。每個賦值op的第一個輸入是對應該初始化的變量的引用,第二個輸入(assign_op.inputs[1])是初始值。所以為了設置我們想要的任何值(而不是初始值),需要使用feed_dict並替換初始值。TensorFlow可以為任何操作提供一個值,而不僅僅是占位符。

第九種辦法

這是兩個基本情況的簡單解決方案,不同之處在於是否要從文件加載graph或在運行時構建graph。

這個答案適用於Tensorflow 0.12+(包括1.0)。

在代碼中重建graph

保存

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加載

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

從文件加載graph

使用此技術時,請確保所有層/變量都已明確設置唯一的名稱,否則Tensorflow將自己創建獨一無二的名稱,這回導致與存儲在文件中的名稱不同。這在以前的技術中不是問題,因為在加載和保存時,名稱都是”mangled”。

保存

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

加載

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

tensorflow

本文匯總整理自帖子:

本文由《純淨天空》出品。文章地址: https://vimsky.com/zh-tw/article/3614.html,未經允許,請勿轉載。