Update neutral_network_with_tensorflow.py
This commit is contained in:
parent
8a84c701e3
commit
ef33036001
@ -1,4 +1,11 @@
|
|||||||
import tensorflow as tf
|
"""
|
||||||
|
This code is supported by the website: https://www.guanjihuan.com
|
||||||
|
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/154
|
||||||
|
"""
|
||||||
|
|
||||||
|
# import tensorflow as tf
|
||||||
|
import tensorflow.compat.v1 as tf # 之所以这么调用,是因为tensorflow版本2.0无法兼容版本1.0
|
||||||
|
tf.compat.v1.disable_eager_execution() # 这行代码可以保证 sess.run() 能够正常运行
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
@ -63,7 +70,7 @@ plt.show()
|
|||||||
|
|
||||||
# 保存训练好的神经网络模型tf.train.Saver()
|
# 保存训练好的神经网络模型tf.train.Saver()
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
save_path = saver.save(sess, "my_net/save_net.ckpt") # 保存模型
|
save_path = saver.save(sess, "./my_net/save_net.ckpt") # 保存模型
|
||||||
print("Save to path: ", save_path)
|
print("Save to path: ", save_path)
|
||||||
print()
|
print()
|
||||||
sess.close() # 关闭会话
|
sess.close() # 关闭会话
|
||||||
@ -71,7 +78,7 @@ sess.close() # 关闭会话
|
|||||||
|
|
||||||
# 调用神经网络模型,来预测新的值
|
# 调用神经网络模型,来预测新的值
|
||||||
with tf.Session() as sess2:
|
with tf.Session() as sess2:
|
||||||
saver.restore(sess2, "my_net/save_net.ckpt") # 提取模型中的所有变量
|
saver.restore(sess2, "./my_net/save_net.ckpt") # 提取模型中的所有变量
|
||||||
print(y_data[0, :]) # 输出的原始值
|
print(y_data[0, :]) # 输出的原始值
|
||||||
print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]})) # 预测值
|
print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]})) # 预测值
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user