Update neutral_network_with_tensorflow.py

This commit is contained in:
guanjihuan 2023-09-05 21:24:21 +08:00
parent 8a84c701e3
commit ef33036001

View File

@ -1,4 +1,11 @@
import tensorflow as tf
"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/154
"""
# import tensorflow as tf
import tensorflow.compat.v1 as tf # 之所以这么调用是因为tensorflow版本2.0无法兼容版本1.0
tf.compat.v1.disable_eager_execution() # 这行代码可以保证 sess.run() 能够正常运行
import numpy as np
import matplotlib.pyplot as plt
@ -63,7 +70,7 @@ plt.show()
# 保存训练好的神经网络模型tf.train.Saver()
saver = tf.train.Saver()
save_path = saver.save(sess, "my_net/save_net.ckpt") # 保存模型
save_path = saver.save(sess, "./my_net/save_net.ckpt") # 保存模型
print("Save to path: ", save_path)
print()
sess.close() # 关闭会话
@ -71,7 +78,7 @@ sess.close() # 关闭会话
# 调用神经网络模型,来预测新的值
with tf.Session() as sess2:
saver.restore(sess2, "my_net/save_net.ckpt") # 提取模型中的所有变量
saver.restore(sess2, "./my_net/save_net.ckpt") # 提取模型中的所有变量
print(y_data[0, :]) # 输出的原始值
print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]})) # 预测值