From ef33036001209a502f87db5321c52c516d527dca Mon Sep 17 00:00:00 2001 From: guanjihuan Date: Tue, 5 Sep 2023 21:24:21 +0800 Subject: [PATCH] Update neutral_network_with_tensorflow.py --- .../neutral_network_with_tensorflow.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/language_learning/python/2019.10.11_neutral_network_with_tensorflow/neutral_network_with_tensorflow.py b/language_learning/python/2019.10.11_neutral_network_with_tensorflow/neutral_network_with_tensorflow.py index a8990f6..71b5b7c 100644 --- a/language_learning/python/2019.10.11_neutral_network_with_tensorflow/neutral_network_with_tensorflow.py +++ b/language_learning/python/2019.10.11_neutral_network_with_tensorflow/neutral_network_with_tensorflow.py @@ -1,4 +1,11 @@ -import tensorflow as tf +""" +This code is supported by the website: https://www.guanjihuan.com +The newest version of this code is on the web page: https://www.guanjihuan.com/archives/154 +""" + +# import tensorflow as tf +import tensorflow.compat.v1 as tf # 之所以这么调用,是因为tensorflow版本2.0无法兼容版本1.0 +tf.compat.v1.disable_eager_execution() # 这行代码可以保证 sess.run() 能够正常运行 import numpy as np import matplotlib.pyplot as plt @@ -63,7 +70,7 @@ plt.show() # 保存训练好的神经网络模型tf.train.Saver() saver = tf.train.Saver() -save_path = saver.save(sess, "my_net/save_net.ckpt") # 保存模型 +save_path = saver.save(sess, "./my_net/save_net.ckpt") # 保存模型 print("Save to path: ", save_path) print() sess.close() # 关闭会话 @@ -71,7 +78,7 @@ sess.close() # 关闭会话 # 调用神经网络模型,来预测新的值 with tf.Session() as sess2: - saver.restore(sess2, "my_net/save_net.ckpt") # 提取模型中的所有变量 + saver.restore(sess2, "./my_net/save_net.ckpt") # 提取模型中的所有变量 print(y_data[0, :]) # 输出的原始值 print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]})) # 预测值