본문 바로가기

머신러닝 & 텐서플로 & 파이썬

Tensorflow training을 save하고 restore하는 방법

반응형



training modeling 을 저장했다가 다시 불러서 사용하게 되면,

처음부터 training을 할 필요가 없어져서 매우 유용합니다.


그 방법 역시 매우 간단한데요. 아래 예제가 있습니다.

한번 활용해보세요.


optimizer = tf.train.AdamOptimizer(learning_rate)

train = optimizer.minimize(cost)


saver = tf.train.Saver()

sess = tf.Session()

sess.run(tf.global_variables_initializer())


# training 을 저장할 file 명

mytrain= "./mytain.ckpt"


if os.path.exists(mytrain+".meta"):

# 파일에서 loading

    saver.restore(sess, mytrain)


else :

    for step in range(2001) :

        cost_val, hy_val, _ = sess.run([cost, hypothesis, train], feed_dict={X: trainX, Y: trainY})

        if step % 10 ==0 :

            print(step, "Cost: ", cost_val, "\nPrediction:\n", hy_val)


#file에 저장

    saver.save(sess, mytrain)


test_predict = sess.run(hypothesis, feed_dict= {X: testX})





아래와 같이 graph 까지 복원해서 사용할 수 있는 방법도 있습니다.

saver = tf.train.import_meta_graph(mytrain+".meta")

saver.restore(sess, mytrain)




ex)

save



train_size = int(len(dataY) * 0.7)

test_size =  len(dataY) - train_size


trainX = np.array(dataX[0:train_size])

testX = np.array(dataX[train_size:len(dataX)])


trainY = np.array(dataY[0:train_size])

testY = np.array(dataY[train_size:len(dataY)])


input_len = data_dim*seq_length


X = tf.placeholder(tf.float32, [None, input_len], name='X')

Y = tf.placeholder(tf.float32, [None, 1], name='Y')


W = tf.Variable(tf.random_normal([input_len, 1]), name='weight')

b = tf.Variable(tf.random_normal([1]), name='bias')


#hypothesis = tf.matmul(X, W) + b

hypothesis = tf.add(b, tf.matmul(X, W), name="h")




cost = tf.reduce_mean(tf.square(hypothesis - Y))


optimizer = tf.train.AdamOptimizer(learning_rate)

train = optimizer.minimize(cost)


saver = tf.train.Saver()

sess = tf.Session()

sess.run(tf.global_variables_initializer())



for step in range(2001) :

    cost_val, hy_val, _ = sess.run([cost, hypothesis, train], feed_dict={X: trainX, Y: trainY})

    if step % 10 ==0 :

        print(step, "Cost: ", cost_val, "\nPrediction:\n", hy_val)


saver.save(sess, mytrain)





resotre

if os.path.exists(mytrain+".meta"):

    sess = tf.Session()

    sess.run(tf.global_variables_initializer())

    saver = tf.train.import_meta_graph(mytrain+".meta")


    saver.restore(sess, mytrain)


    graph = tf.get_default_graph()

    X = graph.get_tensor_by_name('X:0')

    W = graph.get_tensor_by_name('weight:0')

    b = graph.get_tensor_by_name('bias:0')


    hypothesis = graph.get_tensor_by_name('h:0')   #tf.matmul(X, W) + b



    test_predict = sess.run(hypothesis, feed_dict={X: testX})




'머신러닝 & 텐서플로 & 파이썬' 카테고리의 다른 글

데이타 크롤링  (0) 2018.04.02
머신러닝 유투브 영상모음  (0) 2018.03.29
Matplotlib 사용하기  (0) 2018.03.20
adam and gradient descent optimizer  (0) 2018.01.29
Softmax  (0) 2018.01.26