Update neutral_network_with_tensorflow.py
This commit is contained in:
		| @@ -1,4 +1,11 @@ | ||||
| import tensorflow as tf | ||||
| """ | ||||
| This code is supported by the website: https://www.guanjihuan.com | ||||
| The newest version of this code is on the web page: https://www.guanjihuan.com/archives/154 | ||||
| """ | ||||
|  | ||||
| # import tensorflow as tf | ||||
| import tensorflow.compat.v1 as tf  # 之所以这么调用,是因为tensorflow版本2.0无法兼容版本1.0 | ||||
| tf.compat.v1.disable_eager_execution()  # 这行代码可以保证 sess.run() 能够正常运行 | ||||
| import numpy as np | ||||
| import matplotlib.pyplot as plt | ||||
|  | ||||
| @@ -63,7 +70,7 @@ plt.show() | ||||
|  | ||||
| # 保存训练好的神经网络模型tf.train.Saver() | ||||
| saver = tf.train.Saver() | ||||
| save_path = saver.save(sess, "my_net/save_net.ckpt")  # 保存模型 | ||||
| save_path = saver.save(sess, "./my_net/save_net.ckpt")  # 保存模型 | ||||
| print("Save to path: ", save_path) | ||||
| print() | ||||
| sess.close()  # 关闭会话 | ||||
| @@ -71,7 +78,7 @@ sess.close()  # 关闭会话 | ||||
|  | ||||
| # 调用神经网络模型,来预测新的值 | ||||
| with tf.Session() as sess2: | ||||
|     saver.restore(sess2, "my_net/save_net.ckpt")  # 提取模型中的所有变量 | ||||
|     saver.restore(sess2, "./my_net/save_net.ckpt")  # 提取模型中的所有变量 | ||||
|     print(y_data[0, :])  # 输出的原始值 | ||||
|     print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]}))  # 预测值 | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user