好不容易花了一大堆時間訓練好的結果,如果沒有儲存起來可就浪費了,不但要再重新訓練一次,也沒辦法將這個結果分享給他人,同時也不方便建置在系統裡面。
所以今天就來學習怎麼儲存和重複利用吧。
這次用的範例就是之前的用 TensorFlow 架構CNN網路層 來做手寫數字辨識,這篇文章的內容去做增加,那要增加什麼呢?
首先要知道的事情!
TensorFlow要儲存結果,就是要用到 tf.train.Saver 這個物件
像要多知道它的內容嗎?那就去這邊
這邊我想要多說說實際上使用的方式,那這邊我們分兩個階段來說說:
- 儲存
- 讀取
儲存比較簡單,只要在最後面加上:
# save model saver = tf.train.Saver() save_path = saver.save(sess, "net/save_net.ckpt") print("Save to path: ", save_path)
這樣在同個目錄夾中就會多出一個叫 “net”的資料夾,打開之後就會向下圖:
儲存方便就只要這樣就可以了,再來是該說說讀取。
那我還是先貢上全碼:GitHub
基本上最重要的地方就是,在讀取結果前,一定要先搭建一個相同的模型,因為儲存其實只會儲存訓練出來的參數,你必須先幫它做好容器,他才能將這些參數導入至這些容器裡面。
那建立好相同模型後,再來只要打:
with tf.Session() as sess: saver.restore(sess, "net/save_net.ckpt")
記得一定要先有容器,然後才能建立Session(),否則會程式會報錯的。
讀取完後就能跟之前一樣進行辨識了喔~
#執行100次辨識 with tf.Session() as sess: saver.restore(sess, "net/save_net.ckpt") #print(sess.run(global_step_tensor)) for i in range(100): GiveAnswer(mnist.test.images[3]) print("ans:",list(mnist.test.labels[3]).index(1))
執行結果如最上圖
1 thought on “儲存 tensorFlow 訓練結果與重複使用”