Update neutral_network_with_tensorflow.py
This commit is contained in:
		| @@ -1,4 +1,11 @@ | |||||||
| import tensorflow as tf | """ | ||||||
|  | This code is supported by the website: https://www.guanjihuan.com | ||||||
|  | The newest version of this code is on the web page: https://www.guanjihuan.com/archives/154 | ||||||
|  | """ | ||||||
|  |  | ||||||
|  | # import tensorflow as tf | ||||||
|  | import tensorflow.compat.v1 as tf  # 之所以这么调用,是因为tensorflow版本2.0无法兼容版本1.0 | ||||||
|  | tf.compat.v1.disable_eager_execution()  # 这行代码可以保证 sess.run() 能够正常运行 | ||||||
| import numpy as np | import numpy as np | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
|  |  | ||||||
| @@ -63,7 +70,7 @@ plt.show() | |||||||
|  |  | ||||||
| # 保存训练好的神经网络模型tf.train.Saver() | # 保存训练好的神经网络模型tf.train.Saver() | ||||||
| saver = tf.train.Saver() | saver = tf.train.Saver() | ||||||
| save_path = saver.save(sess, "my_net/save_net.ckpt")  # 保存模型 | save_path = saver.save(sess, "./my_net/save_net.ckpt")  # 保存模型 | ||||||
| print("Save to path: ", save_path) | print("Save to path: ", save_path) | ||||||
| print() | print() | ||||||
| sess.close()  # 关闭会话 | sess.close()  # 关闭会话 | ||||||
| @@ -71,7 +78,7 @@ sess.close()  # 关闭会话 | |||||||
|  |  | ||||||
| # 调用神经网络模型,来预测新的值 | # 调用神经网络模型,来预测新的值 | ||||||
| with tf.Session() as sess2: | with tf.Session() as sess2: | ||||||
|     saver.restore(sess2, "my_net/save_net.ckpt")  # 提取模型中的所有变量 |     saver.restore(sess2, "./my_net/save_net.ckpt")  # 提取模型中的所有变量 | ||||||
|     print(y_data[0, :])  # 输出的原始值 |     print(y_data[0, :])  # 输出的原始值 | ||||||
|     print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]}))  # 预测值 |     print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]}))  # 预测值 | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user