儲存 tensorFlow 訓練結果與重複使用

好不容易花了一大堆時間訓練好的結果,如果沒有儲存起來可就浪費了,不但要再重新訓練一次,也沒辦法將這個結果分享給他人,同時也不方便建置在系統裡面。

所以今天就來學習怎麼儲存和重複利用吧。

未命名

這次用的範例就是之前的用 TensorFlow 架構CNN網路層 來做手寫數字辨識,這篇文章的內容去做增加,那要增加什麼呢?

首先要知道的事情!

TensorFlow要儲存結果,就是要用到 tf.train.Saver 這個物件

像要多知道它的內容嗎?那就去這邊

這邊我想要多說說實際上使用的方式,那這邊我們分兩個階段來說說:

  1. 儲存
  2. 讀取

儲存比較簡單,只要在最後面加上:

# save model
saver = tf.train.Saver()
save_path = saver.save(sess, "net/save_net.ckpt")
print("Save to path: ", save_path)

這樣在同個目錄夾中就會多出一個叫 “net”的資料夾,打開之後就會向下圖:

未命名.png

儲存方便就只要這樣就可以了,再來是該說說讀取。

那我還是先貢上全碼:GitHub

基本上最重要的地方就是,在讀取結果前,一定要先搭建一個相同的模型,因為儲存其實只會儲存訓練出來的參數,你必須先幫它做好容器,他才能將這些參數導入至這些容器裡面。

那建立好相同模型後,再來只要打:

with tf.Session() as sess:
    saver.restore(sess, "net/save_net.ckpt")

記得一定要先有容器,然後才能建立Session(),否則會程式會報錯的。

讀取完後就能跟之前一樣進行辨識了喔~

#執行100次辨識
with tf.Session() as sess:
    saver.restore(sess, "net/save_net.ckpt")
    #print(sess.run(global_step_tensor))
    for i in range(100):
        GiveAnswer(mnist.test.images[3])
        print("ans:",list(mnist.test.labels[3]).index(1))

執行結果如最上圖

One Comment

Add a Comment

發佈留言必須填寫的電子郵件地址不會公開。