Update neutral_network_with_tensorflow.py
This commit is contained in:
parent
8a84c701e3
commit
ef33036001
@ -1,4 +1,11 @@
|
||||
import tensorflow as tf
|
||||
"""
|
||||
This code is supported by the website: https://www.guanjihuan.com
|
||||
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/154
|
||||
"""
|
||||
|
||||
# import tensorflow as tf
|
||||
import tensorflow.compat.v1 as tf # 之所以这么调用,是因为tensorflow版本2.0无法兼容版本1.0
|
||||
tf.compat.v1.disable_eager_execution() # 这行代码可以保证 sess.run() 能够正常运行
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@ -63,7 +70,7 @@ plt.show()
|
||||
|
||||
# 保存训练好的神经网络模型tf.train.Saver()
|
||||
saver = tf.train.Saver()
|
||||
save_path = saver.save(sess, "my_net/save_net.ckpt") # 保存模型
|
||||
save_path = saver.save(sess, "./my_net/save_net.ckpt") # 保存模型
|
||||
print("Save to path: ", save_path)
|
||||
print()
|
||||
sess.close() # 关闭会话
|
||||
@ -71,7 +78,7 @@ sess.close() # 关闭会话
|
||||
|
||||
# 调用神经网络模型,来预测新的值
|
||||
with tf.Session() as sess2:
|
||||
saver.restore(sess2, "my_net/save_net.ckpt") # 提取模型中的所有变量
|
||||
saver.restore(sess2, "./my_net/save_net.ckpt") # 提取模型中的所有变量
|
||||
print(y_data[0, :]) # 输出的原始值
|
||||
print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]})) # 预测值
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user