This commit is contained in:
guanjihuan 2022-09-20 16:53:06 +08:00
parent 191620388e
commit 8fc9fb183e
6 changed files with 821 additions and 821 deletions

View File

@ -1,77 +1,77 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def add_layer(inputs, in_size, out_size, activation_function=None): # 定义一层的所有神经元 def add_layer(inputs, in_size, out_size, activation_function=None): # 定义一层的所有神经元
Weights = tf.Variable(tf.random_normal([in_size, out_size])) # 定义Weights为tf变量并给予初值 Weights = tf.Variable(tf.random_normal([in_size, out_size])) # 定义Weights为tf变量并给予初值
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) # 定义biases为tf变量并给予初值 biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) # 定义biases为tf变量并给予初值
Wx_plus_b = tf.matmul(inputs, Weights) + biases # 得分 Wx_plus_b = tf.matmul(inputs, Weights) + biases # 得分
if activation_function is None: # 没有激活函数 if activation_function is None: # 没有激活函数
outputs = Wx_plus_b outputs = Wx_plus_b
else: else:
outputs = activation_function(Wx_plus_b) # 使用激活函数 outputs = activation_function(Wx_plus_b) # 使用激活函数
return outputs # 返回该层每个神经元的输出值维度为out_size return outputs # 返回该层每个神经元的输出值维度为out_size
# 产生训练的数据 # 产生训练的数据
x_data = np.linspace(-1, 1, 300, dtype=np.float32)[:, np.newaxis] # 产生数据,作为神经网络的输入数据。注:[:, np.newaxis]是用来增加一个轴,变成一个矩阵。 x_data = np.linspace(-1, 1, 300, dtype=np.float32)[:, np.newaxis] # 产生数据,作为神经网络的输入数据。注:[:, np.newaxis]是用来增加一个轴,变成一个矩阵。
noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32) # 产生噪声 noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32) # 产生噪声
y_data = np.square(x_data) - 0.5 + noise # x_data加上噪声作为神经网络的输出数据。 y_data = np.square(x_data) - 0.5 + noise # x_data加上噪声作为神经网络的输出数据。
print(x_data.shape) # 查看数据维度 print(x_data.shape) # 查看数据维度
print(noise.shape) # 查看数据维度 print(noise.shape) # 查看数据维度
print(y_data.shape) # 查看数据维度 print(y_data.shape) # 查看数据维度
print() # 打印输出空一行 print() # 打印输出空一行
# 神经网络模型的建立 # 神经网络模型的建立
xs = tf.placeholder(tf.float32, [None, 1]) # 定义占位符为神经网络训练的输入数据。这里的None代表无论输入有多少数据都可以 xs = tf.placeholder(tf.float32, [None, 1]) # 定义占位符为神经网络训练的输入数据。这里的None代表无论输入有多少数据都可以
ys = tf.placeholder(tf.float32, [None, 1]) # 定义占位符,为神经网络训练的输出数据。 ys = tf.placeholder(tf.float32, [None, 1]) # 定义占位符,为神经网络训练的输出数据。
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu) # 增加一个隐藏层 l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu) # 增加一个隐藏层
prediction = add_layer(l1, 10, 1, activation_function=None) # 输出层 prediction = add_layer(l1, 10, 1, activation_function=None) # 输出层
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1])) # 损失函数 loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1])) # 损失函数
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 梯度下降 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 梯度下降
init = tf.global_variables_initializer() # 变量初始化 init = tf.global_variables_initializer() # 变量初始化
# 画出原始的输入输出数据点图 # 画出原始的输入输出数据点图
fig = plt.figure() fig = plt.figure()
ax = fig.add_subplot(1, 1, 1) ax = fig.add_subplot(1, 1, 1)
ax.scatter(x_data, y_data) ax.scatter(x_data, y_data)
plt.ion() # 开启交互模式 plt.ion() # 开启交互模式
plt.show() # 显示图像 plt.show() # 显示图像
# 训练神经网络模型 # 训练神经网络模型
sess = tf.Session() # 启动一个会话 sess = tf.Session() # 启动一个会话
sess.run(init) # 初始化变量 sess.run(init) # 初始化变量
for i in range(1000): # 训练1000次 for i in range(1000): # 训练1000次
sess.run(train_step, feed_dict={xs: x_data, ys: y_data}) # 喂数据梯度下降循环1000次。 sess.run(train_step, feed_dict={xs: x_data, ys: y_data}) # 喂数据梯度下降循环1000次。
if i % 50 == 0: # 每训练50次画一下图 if i % 50 == 0: # 每训练50次画一下图
try: # to visualize the result and improvement try: # to visualize the result and improvement
ax.lines.remove(lines[0]) ax.lines.remove(lines[0])
except Exception: except Exception:
pass pass
prediction_value = sess.run(prediction, feed_dict={xs: x_data}) # 神经网络预测的值 prediction_value = sess.run(prediction, feed_dict={xs: x_data}) # 神经网络预测的值
print('loss=', sess.run(loss, feed_dict={xs: x_data, ys: y_data})) # 打印输出,查看损失函数下降情况 print('loss=', sess.run(loss, feed_dict={xs: x_data, ys: y_data})) # 打印输出,查看损失函数下降情况
print('prediction=', sess.run(prediction, feed_dict={xs: [x_data[0, :]]})) # # 打印输出神经网络预测的值 print('prediction=', sess.run(prediction, feed_dict={xs: [x_data[0, :]]})) # # 打印输出神经网络预测的值
print() # 打印空一行 print() # 打印空一行
lines = ax.plot(x_data, prediction_value, 'r-', lw=5) # 画出预测的值,用线连起来 lines = ax.plot(x_data, prediction_value, 'r-', lw=5) # 画出预测的值,用线连起来
plt.pause(.1) # 暂停0.1,防止画图过快看不清。 plt.pause(.1) # 暂停0.1,防止画图过快看不清。
plt.ioff() # 关闭交互模式,再画一次图。作用是不让图自动关掉。 plt.ioff() # 关闭交互模式,再画一次图。作用是不让图自动关掉。
lines = ax.plot(x_data, prediction_value, 'r-', lw=5) lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
plt.show() plt.show()
# 保存训练好的神经网络模型tf.train.Saver() # 保存训练好的神经网络模型tf.train.Saver()
saver = tf.train.Saver() saver = tf.train.Saver()
save_path = saver.save(sess, "my_net/save_net.ckpt") # 保存模型 save_path = saver.save(sess, "my_net/save_net.ckpt") # 保存模型
print("Save to path: ", save_path) print("Save to path: ", save_path)
print() print()
sess.close() # 关闭会话 sess.close() # 关闭会话
# 调用神经网络模型,来预测新的值 # 调用神经网络模型,来预测新的值
with tf.Session() as sess2: with tf.Session() as sess2:
saver.restore(sess2, "my_net/save_net.ckpt") # 提取模型中的所有变量 saver.restore(sess2, "my_net/save_net.ckpt") # 提取模型中的所有变量
print(y_data[0, :]) # 输出的原始值 print(y_data[0, :]) # 输出的原始值
print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]})) # 预测值 print(sess2.run(prediction, feed_dict={xs: [x_data[0, :]]})) # 预测值

View File

@ -1,39 +1,39 @@
import tensorflow as tf # 导入tensorflow import tensorflow as tf # 导入tensorflow
greeting = tf.constant('Hello Google Tensorflow!') # 定义一个常量 greeting = tf.constant('Hello Google Tensorflow!') # 定义一个常量
# 第一种方式 # 第一种方式
sess = tf.Session() # 启动一个会话 sess = tf.Session() # 启动一个会话
result = sess.run(greeting) # 使用会话执行greeting计算模块 result = sess.run(greeting) # 使用会话执行greeting计算模块
print(result) # 打印显示 print(result) # 打印显示
sess.close() # 关闭会话 sess.close() # 关闭会话
# 第二种方式 # 第二种方式
with tf.Session() as sess: # 启动一个会话 with tf.Session() as sess: # 启动一个会话
print(sess.run(greeting)) # 打印显示 print(sess.run(greeting)) # 打印显示
# 例子1 # 例子1
matrix1 = tf.constant([[1., 3.]]) # 定义常数矩阵1 tf.constant() matrix1 = tf.constant([[1., 3.]]) # 定义常数矩阵1 tf.constant()
matrix2 = tf.constant([[2.], [2.]]) # 定义常数矩阵2 tf.constant() matrix2 = tf.constant([[2.], [2.]]) # 定义常数矩阵2 tf.constant()
product = tf.matmul(matrix1, matrix2) # 矩阵乘积 tf.matmul() product = tf.matmul(matrix1, matrix2) # 矩阵乘积 tf.matmul()
linear = tf.add(product, tf.constant(2.)) # 矩阵乘积后再加上一个常数 tf.add() linear = tf.add(product, tf.constant(2.)) # 矩阵乘积后再加上一个常数 tf.add()
with tf.Session() as sess: # 启动一个会话 tf.Session() with tf.Session() as sess: # 启动一个会话 tf.Session()
print(sess.run(matrix1)) # 执行语句并打印显示 tf.Session().run print(sess.run(matrix1)) # 执行语句并打印显示 tf.Session().run
print(sess.run(linear)) # 执行语句并打印显示 tf.Session().run print(sess.run(linear)) # 执行语句并打印显示 tf.Session().run
print(linear) # 直接打印是不能看到计算结果的因为还未执行只是一个张量。这里打印显示的结果是Tensor("Add:0", shape=(1, 1), dtype=float32) print(linear) # 直接打印是不能看到计算结果的因为还未执行只是一个张量。这里打印显示的结果是Tensor("Add:0", shape=(1, 1), dtype=float32)
# 例子2变量tf.Variable() # 例子2变量tf.Variable()
state = tf.Variable(3, name='counter') # 变量tf.Variable state = tf.Variable(3, name='counter') # 变量tf.Variable
init = tf.global_variables_initializer() # 如果定义了变量,后面一定要有这个语句,用来初始化变量。 init = tf.global_variables_initializer() # 如果定义了变量,后面一定要有这个语句,用来初始化变量。
with tf.Session() as sess: with tf.Session() as sess:
sess.run(init) # 变量一定要初始化变量 sess.run(init) # 变量一定要初始化变量
print(sess.run(state)) # 执行语句并打印显示 print(sess.run(state)) # 执行语句并打印显示
# 例子3占位符tf.placeholder()用来临时占坑需要用feed_dict来传入数值。 # 例子3占位符tf.placeholder()用来临时占坑需要用feed_dict来传入数值。
x1 = tf.placeholder(tf.float32) x1 = tf.placeholder(tf.float32)
x2 = tf.placeholder(tf.float32) x2 = tf.placeholder(tf.float32)
y = x1 + x2 y = x1 + x2
with tf.Session() as sess: with tf.Session() as sess:
print(sess.run(y, feed_dict={x1: 7, x2: 2})) print(sess.run(y, feed_dict={x1: 7, x2: 2}))

View File

@ -1,359 +1,359 @@
""" """
This code is supported by the website: https://www.guanjihuan.com This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/703 The newest version of this code is on the web page: https://www.guanjihuan.com/archives/703
""" """
import pygame import pygame
import random import random
import math import math
import numpy as np import numpy as np
# 参数 # 参数
screen_width = 1500 # 屏幕宽度 screen_width = 1500 # 屏幕宽度
screen_height = 900 # 屏幕高度 screen_height = 900 # 屏幕高度
map_width = screen_width*4 # 地图的大小 map_width = screen_width*4 # 地图的大小
map_height = screen_height*4 # 地图的大小 map_height = screen_height*4 # 地图的大小
number_enemy = map_width*map_height/500000 # 敌人的数量 number_enemy = map_width*map_height/500000 # 敌人的数量
number_dots = map_width * map_height / 50 # 点点的数量 number_dots = map_width * map_height / 50 # 点点的数量
max_show_size = 100 # 球显示的最大半径(屏幕有限,球再增大时,改变的地图比例尺寸) max_show_size = 100 # 球显示的最大半径(屏幕有限,球再增大时,改变的地图比例尺寸)
my_value = 1000 # 我的初始值 my_value = 1000 # 我的初始值
enemy_value_low = 500 # 敌人的初始值(最低) enemy_value_low = 500 # 敌人的初始值(最低)
enemy_value_high = 1500 # 敌人的初始值(最高) enemy_value_high = 1500 # 敌人的初始值(最高)
dot_value = 30 # 点点的值(地上的豆豆/食物值) dot_value = 30 # 点点的值(地上的豆豆/食物值)
my_speed = 10 # 我的球运动的速度 my_speed = 10 # 我的球运动的速度
speed_up = 20 # 按下鼠标时加速 speed_up = 20 # 按下鼠标时加速
speed_enemy = 10 # 敌人球正常运动速度 speed_enemy = 10 # 敌人球正常运动速度
speed_enemy_anomaly = 20 # 敌人突然加速时的速度(速度异常时的速度) speed_enemy_anomaly = 20 # 敌人突然加速时的速度(速度异常时的速度)
anomaly_pro = 0.5 # 敌人加速的概率 anomaly_pro = 0.5 # 敌人加速的概率
change_pro = 0.05 # 敌人移动路径变化的概率也就是1/change_pro左右会变化一次 change_pro = 0.05 # 敌人移动路径变化的概率也就是1/change_pro左右会变化一次
eat_percent = 0.9 # 吃掉敌人的球按多少比例并入自己的体积1对应的是100% eat_percent = 0.9 # 吃掉敌人的球按多少比例并入自己的体积1对应的是100%
loss = 0.001 # 按比例减小体重此外越重的减少越多10万体积损失值为loss的一倍 loss = 0.001 # 按比例减小体重此外越重的减少越多10万体积损失值为loss的一倍
enemy_bigger_pro = 0.0005 # 敌人的值增加了我的球的值的enemy_bigger_rate倍的几率 enemy_bigger_pro = 0.0005 # 敌人的值增加了我的球的值的enemy_bigger_rate倍的几率
enemy_bigger_rate = 0.1 # 增加我的球的体积的enemy_bigger_rate倍 enemy_bigger_rate = 0.1 # 增加我的球的体积的enemy_bigger_rate倍
class Color(object): # 定义颜色的类 class Color(object): # 定义颜色的类
@classmethod # 加了这个可以不需要把实例化,能直接调用类的方法 @classmethod # 加了这个可以不需要把实例化,能直接调用类的方法
def random_color(cls): # cls, 即class表示可以通过类名直接调用 def random_color(cls): # cls, 即class表示可以通过类名直接调用
red = random.randint(0, 255) red = random.randint(0, 255)
green = random.randint(0, 255) green = random.randint(0, 255)
blue = random.randint(0, 255) blue = random.randint(0, 255)
return red, green, blue return red, green, blue
class Ball(object): # 定义球 class Ball(object): # 定义球
def __init__(self, x, y, sx, sy, color, value): # 初始化 def __init__(self, x, y, sx, sy, color, value): # 初始化
self.x = x # 球的地图位置参数 self.x = x # 球的地图位置参数
self.y = y self.y = y
self.sx = sx # 速度参数 self.sx = sx # 速度参数
self.sy = sy self.sy = sy
self.color = color # 颜色 self.color = color # 颜色
self.value = value # 球的值,也就是球的大小(不是显示的大小) self.value = value # 球的值,也就是球的大小(不是显示的大小)
self.is_alive = True # 球默认是存活状态 self.is_alive = True # 球默认是存活状态
class My_Ball(Ball): # 定义我的球继承了Ball类的方法 class My_Ball(Ball): # 定义我的球继承了Ball类的方法
def __init__(self, x, y, sx, sy, color, value): def __init__(self, x, y, sx, sy, color, value):
# 注意如果重写了__init__() 时实例化子类就不会调用父类已经定义的__init__() # 注意如果重写了__init__() 时实例化子类就不会调用父类已经定义的__init__()
# 如果子类不重写__init__()方法实例化子类后会自动调用父类的__init__()的方法 # 如果子类不重写__init__()方法实例化子类后会自动调用父类的__init__()的方法
# 如果子类重写__init__()方法又需要调用父类的方法则要使用super关键词。 # 如果子类重写__init__()方法又需要调用父类的方法则要使用super关键词。
super().__init__(x, y, sx, sy, color, value) # 调用父类Ball的初始化方法__init__() super().__init__(x, y, sx, sy, color, value) # 调用父类Ball的初始化方法__init__()
self.radius = int(self.value**0.5) # 我的球的半径不考虑系数pi self.radius = int(self.value**0.5) # 我的球的半径不考虑系数pi
if self.radius >= max_show_size: # 如果半径比规定的最大半径还大,则显示最大半径 if self.radius >= max_show_size: # 如果半径比规定的最大半径还大,则显示最大半径
self.show_radius = max_show_size # 我的球显示的半径 self.show_radius = max_show_size # 我的球显示的半径
else: else:
self.show_radius = self.radius # 如果半径没有超过规定最大的半径,则显示原来实际大小的半径 self.show_radius = self.radius # 如果半径没有超过规定最大的半径,则显示原来实际大小的半径
self.position_x = int(screen_width/2) # 把我的球固定在屏幕中间position_x是屏幕显示的位置 self.position_x = int(screen_width/2) # 把我的球固定在屏幕中间position_x是屏幕显示的位置
self.position_y = int(screen_height/2) # 把我的球固定在屏幕中间position_y是屏幕显示的位置 self.position_y = int(screen_height/2) # 把我的球固定在屏幕中间position_y是屏幕显示的位置
def draw(self, window): # 把我的球画出来 def draw(self, window): # 把我的球画出来
self.radius = int(self.value ** 0.5) # 这里重复上面的,因为除了初始化之后,还要更新 self.radius = int(self.value ** 0.5) # 这里重复上面的,因为除了初始化之后,还要更新
if self.radius >= max_show_size: if self.radius >= max_show_size:
self.show_radius = max_show_size self.show_radius = max_show_size
else: else:
self.show_radius = self.radius self.show_radius = self.radius
self.position_x = int(screen_width / 2) self.position_x = int(screen_width / 2)
self.position_y = int(screen_height / 2) self.position_y = int(screen_height / 2)
pygame.draw.circle(window, self.color, (self.position_x , self.position_y), self.show_radius) pygame.draw.circle(window, self.color, (self.position_x , self.position_y), self.show_radius)
def eat_ball(self, other): # 吃别的球(包括小点点和敌人) def eat_ball(self, other): # 吃别的球(包括小点点和敌人)
if self != other and self.is_alive and other.is_alive: # 如果other不是自身自身和对方也都是存活状态则执行下面动作 if self != other and self.is_alive and other.is_alive: # 如果other不是自身自身和对方也都是存活状态则执行下面动作
distance = ((self.position_x - other.position_x) ** 2 + (self.position_y - other.position_y) ** 2) ** 0.5 # 两个球之间的距离 distance = ((self.position_x - other.position_x) ** 2 + (self.position_y - other.position_y) ** 2) ** 0.5 # 两个球之间的距离
if distance < self.show_radius and (self.show_radius > other.show_radius or (self.show_radius == other.show_radius and self.value > other.value)): # 如果自身半径比别人大,而且两者距离小于自身半径,那么可以吃掉。 if distance < self.show_radius and (self.show_radius > other.show_radius or (self.show_radius == other.show_radius and self.value > other.value)): # 如果自身半径比别人大,而且两者距离小于自身半径,那么可以吃掉。
other.is_alive = False # 吃球(敌方已死) other.is_alive = False # 吃球(敌方已死)
self.value += other.value*eat_percent # 自己的值增大(体量增大) self.value += other.value*eat_percent # 自己的值增大(体量增大)
self.radius = int(self.value ** 0.5) # 计算出半径 self.radius = int(self.value ** 0.5) # 计算出半径
if self.radius >= max_show_size: # 我的球的显示半径 if self.radius >= max_show_size: # 我的球的显示半径
self.show_radius = max_show_size self.show_radius = max_show_size
else: else:
self.show_radius = self.radius self.show_radius = self.radius
def move(self): # 移动规则 def move(self): # 移动规则
self.x += self.sx # 地图位置加上速度 self.x += self.sx # 地图位置加上速度
self.y += self.sy self.y += self.sy
# 横向出界 # 横向出界
if self.x < 0: # 离开了地图左边 if self.x < 0: # 离开了地图左边
self.x = 0 self.x = 0
if self.x > map_width: # 离开了地图右边 if self.x > map_width: # 离开了地图右边
self.x = map_width self.x = map_width
# 纵向出界 # 纵向出界
if self.y <= 0: # 离开了地图下边 if self.y <= 0: # 离开了地图下边
self.y = 0 self.y = 0
if self.y >= map_height: # 离开了地图上边 if self.y >= map_height: # 离开了地图上边
self.y = map_height self.y = map_height
class Enemy_Ball(Ball): # 定义敌人的球继承了Ball类的方法 class Enemy_Ball(Ball): # 定义敌人的球继承了Ball类的方法
def __init__(self, x, y, sx, sy, color, value, host_ball): # 初始化带上host_ball也就是我的球 def __init__(self, x, y, sx, sy, color, value, host_ball): # 初始化带上host_ball也就是我的球
super().__init__(x, y, sx, sy, color, value) super().__init__(x, y, sx, sy, color, value)
self.host_ball = host_ball self.host_ball = host_ball
self.radius = int(self.value**0.5) self.radius = int(self.value**0.5)
if self.host_ball.radius >= max_show_size: # 如果我的球比规定的最大尺寸还大,则敌人的球显示的比例要减小 if self.host_ball.radius >= max_show_size: # 如果我的球比规定的最大尺寸还大,则敌人的球显示的比例要减小
self.show_radius = max(10, int(self.radius/(self.host_ball.radius/max_show_size))) # 敌人的球也不能太小最小半径为10 self.show_radius = max(10, int(self.radius/(self.host_ball.radius/max_show_size))) # 敌人的球也不能太小最小半径为10
self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int( self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int(
screen_width / 2) # 计算出敌人的球和我的球的相对位置,并且按比例减小 screen_width / 2) # 计算出敌人的球和我的球的相对位置,并且按比例减小
self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int( self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int(
screen_height / 2) # 计算出敌人的球和我的球的相对位置,并且按比例减小 screen_height / 2) # 计算出敌人的球和我的球的相对位置,并且按比例减小
else: else:
self.show_radius = self.radius # 正常显示 self.show_radius = self.radius # 正常显示
self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2) # 敌人和我的球的相对位置 self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2) # 敌人和我的球的相对位置
self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2) # 敌人和我的球的相对位置 self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2) # 敌人和我的球的相对位置
# 画出球 # 画出球
def draw(self, window): def draw(self, window):
self.radius = int(self.value ** 0.5) self.radius = int(self.value ** 0.5)
if self.host_ball.radius >= max_show_size: # 这边把初始化的内容再写一遍,因为敌人的球初始化之后还要根据我的球而动态改变 if self.host_ball.radius >= max_show_size: # 这边把初始化的内容再写一遍,因为敌人的球初始化之后还要根据我的球而动态改变
self.show_radius = max(10, int(self.radius/(self.host_ball.radius/max_show_size))) self.show_radius = max(10, int(self.radius/(self.host_ball.radius/max_show_size)))
self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int( self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int(
screen_width / 2) screen_width / 2)
self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int( self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int(
screen_height / 2) screen_height / 2)
else: else:
self.show_radius = self.radius self.show_radius = self.radius
self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2) self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2)
self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2) self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2)
pygame.draw.circle(window, self.color, (self.position_x, self.position_y), self.show_radius) pygame.draw.circle(window, self.color, (self.position_x, self.position_y), self.show_radius)
def eat_ball(self, other): def eat_ball(self, other):
if self != other and self.is_alive and other.is_alive: if self != other and self.is_alive and other.is_alive:
distance = ((self.position_x - other.position_x) ** 2 + (self.position_y - other.position_y) ** 2) ** 0.5 distance = ((self.position_x - other.position_x) ** 2 + (self.position_y - other.position_y) ** 2) ** 0.5
if distance < self.show_radius and (self.show_radius > other.show_radius or (self.show_radius == other.show_radius and self.value > other.value)): if distance < self.show_radius and (self.show_radius > other.show_radius or (self.show_radius == other.show_radius and self.value > other.value)):
other.is_alive = False # 吃球 other.is_alive = False # 吃球
self.value += other.value*eat_percent self.value += other.value*eat_percent
self.radius = int(self.value ** 0.5) self.radius = int(self.value ** 0.5)
def move(self): # 移动规则 def move(self): # 移动规则
self.x += self.sx # 地图位置加上速度 self.x += self.sx # 地图位置加上速度
self.y += self.sy self.y += self.sy
# 横向出界 # 横向出界
if self.x < 0: # 离开了地图左边 if self.x < 0: # 离开了地图左边
self.sx = -self.sx self.sx = -self.sx
self.x = 0 self.x = 0
if self.x > map_width: # 离开了地图右边 if self.x > map_width: # 离开了地图右边
self.sx = -self.sx self.sx = -self.sx
self.x = map_width self.x = map_width
# 纵向出界 # 纵向出界
if self.y <= 0: # 离开了地图下边 if self.y <= 0: # 离开了地图下边
self.sy = -self.sy self.sy = -self.sy
self.y = 0 self.y = 0
if self.y >= map_height: # 离开了地图上边 if self.y >= map_height: # 离开了地图上边
self.sy = -self.sy self.sy = -self.sy
self.y = map_height self.y = map_height
class Dot_Ball(Ball): # 定义地上的小点点供自己的球和敌人的球吃继承了Ball类的方法 class Dot_Ball(Ball): # 定义地上的小点点供自己的球和敌人的球吃继承了Ball类的方法
def __init__(self, x, y, sx, sy, color, value, host_ball): def __init__(self, x, y, sx, sy, color, value, host_ball):
super().__init__(x, y, sx, sy, color, value) super().__init__(x, y, sx, sy, color, value)
self.host_ball = host_ball self.host_ball = host_ball
self.radius = 8 # 初始小点点大小 self.radius = 8 # 初始小点点大小
if self.host_ball.radius >= max_show_size: if self.host_ball.radius >= max_show_size:
self.show_radius = max(3, int(self.radius/(self.host_ball.radius/max_show_size))) # 小点点显示也不能太小最小显示半径为3 self.show_radius = max(3, int(self.radius/(self.host_ball.radius/max_show_size))) # 小点点显示也不能太小最小显示半径为3
self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int( self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int(
screen_width / 2) screen_width / 2)
self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int( self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int(
screen_height / 2) screen_height / 2)
else: else:
self.show_radius = self.radius self.show_radius = self.radius
self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2) self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2)
self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2) self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2)
# 画出球 # 画出球
def draw(self, window): def draw(self, window):
if self.host_ball.radius >= max_show_size: # 这边把初始化的内容再写一遍,因为小点点初始化之后还要根据我的球而动态改变 if self.host_ball.radius >= max_show_size: # 这边把初始化的内容再写一遍,因为小点点初始化之后还要根据我的球而动态改变
self.show_radius = max(3, int(self.radius/(self.host_ball.radius/max_show_size))) self.show_radius = max(3, int(self.radius/(self.host_ball.radius/max_show_size)))
self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int( self.position_x = int((self.x - self.host_ball.x) / (self.host_ball.radius / max_show_size)) + int(
screen_width / 2) screen_width / 2)
self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int( self.position_y = int((self.y - self.host_ball.y) / (self.host_ball.radius / max_show_size)) + int(
screen_height / 2) screen_height / 2)
else: else:
self.show_radius = self.radius self.show_radius = self.radius
self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2) self.position_x = (self.x - self.host_ball.x) + int(screen_width / 2)
self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2) self.position_y = (self.y - self.host_ball.y) + int(screen_height / 2)
pygame.draw.circle(window, self.color, (self.position_x, self.position_y) , self.show_radius) pygame.draw.circle(window, self.color, (self.position_x, self.position_y) , self.show_radius)
def creat_my_ball(): # 产生我的球 def creat_my_ball(): # 产生我的球
x = random.randint(0, map_width) # 我的球在地图中的位置,随机生成 x = random.randint(0, map_width) # 我的球在地图中的位置,随机生成
y = random.randint(0, map_height) y = random.randint(0, map_height)
value = my_value # 我的球的初始值 value = my_value # 我的球的初始值
color = 255, 255, 255 # 我的球的颜色 color = 255, 255, 255 # 我的球的颜色
sx = 0 # 速度默认为0 sx = 0 # 速度默认为0
sy = 0 sy = 0
host_ball = My_Ball(x, y, sx, sy, color, value) # 调用My_Ball类 host_ball = My_Ball(x, y, sx, sy, color, value) # 调用My_Ball类
return host_ball # 返回我的球 return host_ball # 返回我的球
def auto_creat_ball(balls, host_ball): # 自动产生敌人的球 def auto_creat_ball(balls, host_ball): # 自动产生敌人的球
if len(balls) <= number_enemy: # 控制敌人的数量,如果个数够了,就不再生成 if len(balls) <= number_enemy: # 控制敌人的数量,如果个数够了,就不再生成
x = random.randint(0, map_width) # 敌人球在地图中的位置,随机生成 x = random.randint(0, map_width) # 敌人球在地图中的位置,随机生成
y = random.randint(0, map_height) y = random.randint(0, map_height)
value = random.randint(enemy_value_low, enemy_value_high) # 敌人的球初始值 value = random.randint(enemy_value_low, enemy_value_high) # 敌人的球初始值
sx = random.randint(-speed_enemy, speed_enemy) # 敌人的球移动速度 sx = random.randint(-speed_enemy, speed_enemy) # 敌人的球移动速度
i2 = random.randint(0, 1) # y的移动方向 i2 = random.randint(0, 1) # y的移动方向
if i2 == 0: if i2 == 0:
sy = int((speed_enemy**2 - sx**2) ** 0.5) sy = int((speed_enemy**2 - sx**2) ** 0.5)
else: else:
sy = -int((speed_enemy ** 2 - sx ** 2) ** 0.5) sy = -int((speed_enemy ** 2 - sx ** 2) ** 0.5)
color = Color.random_color() # 敌人的颜色随机生成 color = Color.random_color() # 敌人的颜色随机生成
enemy = Enemy_Ball(x, y, sx, sy, color, value, host_ball) enemy = Enemy_Ball(x, y, sx, sy, color, value, host_ball)
balls.append(enemy) balls.append(enemy)
def auto_creat_dots(dots, host_ball): # 自动生成点点 def auto_creat_dots(dots, host_ball): # 自动生成点点
if len(dots) <= number_dots: # 控制点点的数量 if len(dots) <= number_dots: # 控制点点的数量
x = random.randint(0, map_width) # 随机生成点点的位置 x = random.randint(0, map_width) # 随机生成点点的位置
y = random.randint(0, map_height) y = random.randint(0, map_height)
value = dot_value # 点点的值 value = dot_value # 点点的值
sx = 0 # 点点速度为0 sx = 0 # 点点速度为0
sy = 0 sy = 0
color = Color.random_color() # 颜色 color = Color.random_color() # 颜色
dot = Dot_Ball(x, y, sx, sy, color, value, host_ball) dot = Dot_Ball(x, y, sx, sy, color, value, host_ball)
dots.append(dot) dots.append(dot)
def control_my_ball(host_ball): # 控制我的球 def control_my_ball(host_ball): # 控制我的球
host_ball.move() host_ball.move()
host_ball.value = host_ball.value*(1-loss*host_ball.value/100000) host_ball.value = host_ball.value*(1-loss*host_ball.value/100000)
for event in pygame.event.get(): # 监控事件(鼠标移动) for event in pygame.event.get(): # 监控事件(鼠标移动)
# print(event) # print(event)
if event.type == pygame.MOUSEBUTTONDOWN: if event.type == pygame.MOUSEBUTTONDOWN:
pos = event.pos pos = event.pos
speed = speed_up speed = speed_up
elif event.type == pygame.MOUSEMOTION: elif event.type == pygame.MOUSEMOTION:
pos = event.pos pos = event.pos
if event.buttons[0] == 1: if event.buttons[0] == 1:
speed = speed_up speed = speed_up
if event.buttons[0] == 0: if event.buttons[0] == 0:
speed = my_speed speed = my_speed
elif event.type == pygame.MOUSEBUTTONUP: elif event.type == pygame.MOUSEBUTTONUP:
pos = event.pos pos = event.pos
speed = my_speed speed = my_speed
else: else:
pos = [screen_width/2, screen_height/2] pos = [screen_width/2, screen_height/2]
speed = my_speed speed = my_speed
if abs(pos[0] - screen_width/2) < 30 and abs(pos[1] - screen_height/2) < 30: if abs(pos[0] - screen_width/2) < 30 and abs(pos[1] - screen_height/2) < 30:
host_ball.sx = 0 host_ball.sx = 0
host_ball.sy = 0 host_ball.sy = 0
elif pos[0] > screen_width/2 and pos[1] >= screen_height/2: elif pos[0] > screen_width/2 and pos[1] >= screen_height/2:
angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2))) angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2)))
host_ball.sx = int(speed * math.cos(angle)) host_ball.sx = int(speed * math.cos(angle))
host_ball.sy = int(speed * math.sin(angle)) host_ball.sy = int(speed * math.sin(angle))
elif pos[0] > screen_width/2 and pos[1] < screen_height/2: elif pos[0] > screen_width/2 and pos[1] < screen_height/2:
angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2))) angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2)))
host_ball.sx = int(speed * math.cos(angle)) host_ball.sx = int(speed * math.cos(angle))
host_ball.sy = -int(speed * math.sin(angle)) host_ball.sy = -int(speed * math.sin(angle))
elif pos[0] < screen_width/2 and pos[1] >= screen_height/2: elif pos[0] < screen_width/2 and pos[1] >= screen_height/2:
angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2))) angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2)))
host_ball.sx = -int(speed * math.cos(angle)) host_ball.sx = -int(speed * math.cos(angle))
host_ball.sy = int(speed * math.sin(angle)) host_ball.sy = int(speed * math.sin(angle))
elif pos[0] < screen_width/2 and pos[1] < screen_height/2: elif pos[0] < screen_width/2 and pos[1] < screen_height/2:
angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2))) angle = abs(math.atan((pos[1] - screen_height/2) / (pos[0] - screen_width/2)))
host_ball.sx = -int(speed * math.cos(angle)) host_ball.sx = -int(speed * math.cos(angle))
host_ball.sy = -int(speed * math.sin(angle)) host_ball.sy = -int(speed * math.sin(angle))
elif pos[0] == screen_width/2: elif pos[0] == screen_width/2:
host_ball.sx = 0 host_ball.sx = 0
if pos[1] >= 0: if pos[1] >= 0:
host_ball.sy = speed host_ball.sy = speed
else: else:
host.ball.sy = -speed host.ball.sy = -speed
def enemy_move(balls, host_ball): # 敌人移动 def enemy_move(balls, host_ball): # 敌人移动
for enemy in balls: for enemy in balls:
enemy.move() # 移动 enemy.move() # 移动
enemy.value = enemy.value*(1-loss*enemy.value/100000) enemy.value = enemy.value*(1-loss*enemy.value/100000)
if random.randint(1, int(1/enemy_bigger_pro)) == 1: if random.randint(1, int(1/enemy_bigger_pro)) == 1:
enemy.value += host_ball.value*enemy_bigger_rate enemy.value += host_ball.value*enemy_bigger_rate
if random.randint(1, int(1/anomaly_pro)) == 1: if random.randint(1, int(1/anomaly_pro)) == 1:
speed_enemy0 = speed_enemy_anomaly # 敌人异常速度 speed_enemy0 = speed_enemy_anomaly # 敌人异常速度
else: else:
speed_enemy0 = speed_enemy # 敌人正常速度 speed_enemy0 = speed_enemy # 敌人正常速度
i = random.randint(1, int(1/change_pro)) # 一定的概率改变轨迹 i = random.randint(1, int(1/change_pro)) # 一定的概率改变轨迹
if i == 1: if i == 1:
enemy.sx = random.randint(-speed_enemy0, speed_enemy0) enemy.sx = random.randint(-speed_enemy0, speed_enemy0)
i2 = random.randint(0, 1) i2 = random.randint(0, 1)
if i2 == 0: if i2 == 0:
enemy.sy = int((speed_enemy0 ** 2 - enemy.sx ** 2) ** 0.5) enemy.sy = int((speed_enemy0 ** 2 - enemy.sx ** 2) ** 0.5)
else: else:
enemy.sy = -int((speed_enemy0 ** 2 - enemy.sx ** 2) ** 0.5) enemy.sy = -int((speed_enemy0 ** 2 - enemy.sx ** 2) ** 0.5)
def eat_each_other(host_ball, balls, dots): # 吃球 def eat_each_other(host_ball, balls, dots): # 吃球
for enemy in balls: for enemy in balls:
for enemy2 in balls: for enemy2 in balls:
enemy.eat_ball(enemy2) # 敌人互吃 enemy.eat_ball(enemy2) # 敌人互吃
for food in dots: for food in dots:
enemy.eat_ball(food) # 敌人吃点点 enemy.eat_ball(food) # 敌人吃点点
for enemy in balls: for enemy in balls:
host_ball.eat_ball(enemy) # 我吃敌人 host_ball.eat_ball(enemy) # 我吃敌人
enemy.eat_ball(host_ball) # 敌人吃我 enemy.eat_ball(host_ball) # 敌人吃我
for food in dots: for food in dots:
host_ball.eat_ball(food) # 我吃点点 host_ball.eat_ball(food) # 我吃点点
def paint(host_ball, balls, dots, screen): def paint(host_ball, balls, dots, screen):
screen.fill((0, 0, 0)) # 刷漆 screen.fill((0, 0, 0)) # 刷漆
if host_ball.is_alive: if host_ball.is_alive:
host_ball.draw(screen) host_ball.draw(screen)
for enemy in balls: # 遍历容器 for enemy in balls: # 遍历容器
if enemy.is_alive: if enemy.is_alive:
enemy.draw(screen) enemy.draw(screen)
else: else:
balls.remove(enemy) balls.remove(enemy)
for food in dots: # 遍历容器 for food in dots: # 遍历容器
if food.is_alive: if food.is_alive:
food.draw(screen) food.draw(screen)
else: else:
dots.remove(food) dots.remove(food)
def main(): def main():
pygame.init() # 初始化 pygame.init() # 初始化
screen = pygame.display.set_mode((screen_width, screen_height)) # 设置屏幕 screen = pygame.display.set_mode((screen_width, screen_height)) # 设置屏幕
pygame.display.set_caption("球球大作战") # 设置屏幕标题 pygame.display.set_caption("球球大作战") # 设置屏幕标题
balls = [] # 定义一容器 存放所有的敌方球 balls = [] # 定义一容器 存放所有的敌方球
dots = [] # 定义一容器 存放所有的点点 dots = [] # 定义一容器 存放所有的点点
is_running = True # 默认运行状态 is_running = True # 默认运行状态
host_ball = creat_my_ball() # 产生我的球 host_ball = creat_my_ball() # 产生我的球
i00 = 0 # 一个参数 i00 = 0 # 一个参数
while is_running: while is_running:
for event in pygame.event.get(): for event in pygame.event.get():
if event.type == pygame.QUIT: if event.type == pygame.QUIT:
is_running = False is_running = False
auto_creat_dots(dots, host_ball) # 自动生成点点 auto_creat_dots(dots, host_ball) # 自动生成点点
auto_creat_ball(balls, host_ball) # 自动生成敌人 auto_creat_ball(balls, host_ball) # 自动生成敌人
paint(host_ball, balls, dots, screen) # 把所有的都画出来 调用draw方法 paint(host_ball, balls, dots, screen) # 把所有的都画出来 调用draw方法
pygame.display.flip() # 渲染 pygame.display.flip() # 渲染
pygame.time.delay(30) # 设置动画的时间延迟 pygame.time.delay(30) # 设置动画的时间延迟
control_my_ball(host_ball) # 移动我的球 control_my_ball(host_ball) # 移动我的球
enemy_move(balls, host_ball) # 敌人的球随机运动 enemy_move(balls, host_ball) # 敌人的球随机运动
eat_each_other(host_ball, balls, dots) # 吃球 调用eat_ball方法 eat_each_other(host_ball, balls, dots) # 吃球 调用eat_ball方法
i00 += 1 i00 += 1
if np.mod(i00, 50) == 0: if np.mod(i00, 50) == 0:
print(host_ball.value) print(host_ball.value)
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,106 +1,106 @@
""" """
This code is supported by the website: https://www.guanjihuan.com This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/706 The newest version of this code is on the web page: https://www.guanjihuan.com/archives/706
""" """
import numpy as np import numpy as np
import time import time
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import tushare as ts import tushare as ts
def main(): def main():
start_clock = time.perf_counter() start_clock = time.perf_counter()
pro = ts.pro_api('到官网上注册寻找Token填在这里!') pro = ts.pro_api('到官网上注册寻找Token填在这里!')
print('\n我的策略:见好就收,遇低抄底。\n' print('\n我的策略:见好就收,遇低抄底。\n'
' 【卖出】买入后涨了5%就卖出\n' ' 【卖出】买入后涨了5%就卖出\n'
' 【买入】卖出后跌了5%就买入\n' ' 【买入】卖出后跌了5%就买入\n'
'注:第一天必须买进,最后一天前必须卖出(为了与不操作的做对比)\n') '注:第一天必须买进,最后一天前必须卖出(为了与不操作的做对比)\n')
number = 1 number = 1
for i in range(number): for i in range(number):
data = pro.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date') # 所有股票列表 data = pro.stock_basic(exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date') # 所有股票列表
# print(data.columns) # 查看该数据的表头 # print(data.columns) # 查看该数据的表头
# print(data) # 3688多行的股票数据 # print(data) # 3688多行的股票数据
i = 1 # 查看第二行数据“万科A”股 i = 1 # 查看第二行数据“万科A”股
ts_code = data.values[i, 0] # 股票代码 ts_code = data.values[i, 0] # 股票代码
stock = data.values[i, 2] # 股票名称 stock = data.values[i, 2] # 股票名称
industry = data.values[i, 4] # 属于哪个行业 industry = data.values[i, 4] # 属于哪个行业
start_date = '20110101' # 开始时间 start_date = '20110101' # 开始时间
end_date = '20191027' # 结束时间 end_date = '20191027' # 结束时间
df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date) # 查看该股票的日线数据 df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date) # 查看该股票的日线数据
# print(df.columns) # 查看该数据的表头 # print(df.columns) # 查看该数据的表头
# print(df) # 查看该股票的日线数据 # print(df) # 查看该股票的日线数据
close = np.array(list(reversed(df.values[:, 5]))) # 提取出收盘价,并按时间顺序排列,从过去到现在 close = np.array(list(reversed(df.values[:, 5]))) # 提取出收盘价,并按时间顺序排列,从过去到现在
pct_chg = np.array(list(reversed(df.values[:, 8]))) # 提取出涨跌幅,并按时间顺序排列,从过去到现在 pct_chg = np.array(list(reversed(df.values[:, 8]))) # 提取出涨跌幅,并按时间顺序排列,从过去到现在
# print(df.columns[5], '=', close, '\n') # 查看收盘价 # print(df.columns[5], '=', close, '\n') # 查看收盘价
# print(df.columns[8], '=', pct_chg, '\n') # 查看涨跌幅 # print(df.columns[8], '=', pct_chg, '\n') # 查看涨跌幅
profit, profit_no_operation, times, invest_money, buy_time_all, sell_time_all = back_test(close.shape[0], close, pct_chg) profit, profit_no_operation, times, invest_money, buy_time_all, sell_time_all = back_test(close.shape[0], close, pct_chg)
# 调用回测函数,返回了“利润,未操作的利润, 按该策略操作了几次, 总投资金额, 按该策略买的时间, 按该策略卖的时间”的值 # 调用回测函数,返回了“利润,未操作的利润, 按该策略操作了几次, 总投资金额, 按该策略买的时间, 按该策略卖的时间”的值
print('\n------股票:', stock, ts_code, industry, '[买入市值=%7.2f' % invest_money, ']------') print('\n------股票:', stock, ts_code, industry, '[买入市值=%7.2f' % invest_money, ']------')
print('回测时间段:', start_date, '-', end_date) print('回测时间段:', start_date, '-', end_date)
print('操作后利润= %6.2f' % profit, ' 买入(卖出)次数=', times, ' ') print('操作后利润= %6.2f' % profit, ' 买入(卖出)次数=', times, ' ')
print('不操作利润= %6.2f' % profit_no_operation, '(第一天买入,最后一天卖出,中间未操作)') print('不操作利润= %6.2f' % profit_no_operation, '(第一天买入,最后一天卖出,中间未操作)')
end_clock = time.perf_counter() end_clock = time.perf_counter()
print('CPU执行时间=', end_clock - start_clock, 's') print('CPU执行时间=', end_clock - start_clock, 's')
plt.figure(1) plt.figure(1)
plt.title('Stock Code: '+ts_code+' (red point: buy, green point: sell)') plt.title('Stock Code: '+ts_code+' (red point: buy, green point: sell)')
plt.grid() plt.grid()
plt.plot(range(close.shape[0]), close, '-') plt.plot(range(close.shape[0]), close, '-')
for i in buy_time_all: for i in buy_time_all:
plt.plot(i, close[int(i)], 'or', markersize=13) # 红色是买进的点 plt.plot(i, close[int(i)], 'or', markersize=13) # 红色是买进的点
for i in sell_time_all: for i in sell_time_all:
plt.plot(i, close[int(i)], 'dg', markersize=13) # 绿色是卖出的点 plt.plot(i, close[int(i)], 'dg', markersize=13) # 绿色是卖出的点
plt.show() plt.show()
def back_test(days, close, pct_chg, money_in=10000): # 定义该策略的回测效果(按旧数据检查该策略是否有效) def back_test(days, close, pct_chg, money_in=10000): # 定义该策略的回测效果(按旧数据检查该策略是否有效)
money_in_amount = int(money_in/close[0]) # 投资金额换算成股票股数 money_in_amount = int(money_in/close[0]) # 投资金额换算成股票股数
invest_money = close[0]*money_in_amount # 实际买了股票的金额 invest_money = close[0]*money_in_amount # 实际买了股票的金额
profit_no_operation = (close[close.shape[0]-1]-close[0])*money_in_amount # 不操作的利润 profit_no_operation = (close[close.shape[0]-1]-close[0])*money_in_amount # 不操作的利润
position = -1 # 买入还是卖出的状态,默认卖出 position = -1 # 买入还是卖出的状态,默认卖出
total_profit = 0 total_profit = 0
times = 0 times = 0
current_buy_pct = -999 current_buy_pct = -999
current_sell_pct = 999 current_sell_pct = 999
buy_time_all = np.array([]) buy_time_all = np.array([])
sell_time_all = np.array([]) sell_time_all = np.array([])
for i in range(days): # 总天数 for i in range(days): # 总天数
if i == 0: # 第一天,满仓买买买!为了和不操作的对比,第一天就要买入。 if i == 0: # 第一天,满仓买买买!为了和不操作的对比,第一天就要买入。
buy_time = i # 买入时间 buy_time = i # 买入时间
buy_time_all = np.append(buy_time_all, [buy_time], axis=0) # 买入时间存档 buy_time_all = np.append(buy_time_all, [buy_time], axis=0) # 买入时间存档
position = 1 # 标记为买入状态 position = 1 # 标记为买入状态
print('------------------第', buy_time, '天买进-------------') print('------------------第', buy_time, '天买进-------------')
else: else:
profit = 0 profit = 0
if position == 1: # 买入状态 if position == 1: # 买入状态
current_buy_pct = (close[i]-close[buy_time])/close[buy_time]*100 # 买入后的涨跌情况 current_buy_pct = (close[i]-close[buy_time])/close[buy_time]*100 # 买入后的涨跌情况
# print('当前买进后的涨跌情况:第', i, '天=', current_buy_pct) # print('当前买进后的涨跌情况:第', i, '天=', current_buy_pct)
if position == 0: # 卖出状态 if position == 0: # 卖出状态
current_sell_pct = (close[i]-close[sell_time])/close[sell_time]*100 # 卖出后的涨跌情况 current_sell_pct = (close[i]-close[sell_time])/close[sell_time]*100 # 卖出后的涨跌情况
if current_sell_pct < -5 and position == 0: # 卖出状态且卖出后跌了有3%,这时候买入 if current_sell_pct < -5 and position == 0: # 卖出状态且卖出后跌了有3%,这时候买入
buy_time = i # 买入时间 buy_time = i # 买入时间
buy_time_all = np.append(buy_time_all, [buy_time], axis=0) # 买入时间存档 buy_time_all = np.append(buy_time_all, [buy_time], axis=0) # 买入时间存档
print('------------------第', buy_time, '天买进-------------') print('------------------第', buy_time, '天买进-------------')
position = 1 # 标记为买入状态 position = 1 # 标记为买入状态
continue continue
if current_buy_pct > 5 and position == 1: # 买入状态且买入后涨了有3%,这时候卖出 if current_buy_pct > 5 and position == 1: # 买入状态且买入后涨了有3%,这时候卖出
sell_time = i # 卖出时间 sell_time = i # 卖出时间
sell_time_all = np.append(sell_time_all, [sell_time], axis=0) # 卖出时间存档 sell_time_all = np.append(sell_time_all, [sell_time], axis=0) # 卖出时间存档
print('----------第', sell_time, '天卖出,持有天数:', sell_time-buy_time, '--------------\n') print('----------第', sell_time, '天卖出,持有天数:', sell_time-buy_time, '--------------\n')
position = 0 # 标记为卖出状态 position = 0 # 标记为卖出状态
profit = close[sell_time]-close[buy_time] # 赚取利润 profit = close[sell_time]-close[buy_time] # 赚取利润
times = times + 1 # 买入卖出次数加1 times = times + 1 # 买入卖出次数加1
total_profit = total_profit + profit*money_in_amount # 计算总利润 total_profit = total_profit + profit*money_in_amount # 计算总利润
if position == 1: # 最后一天如果是买入状态,则卖出 if position == 1: # 最后一天如果是买入状态,则卖出
profit = close[i]-close[buy_time] # 赚取利润 profit = close[i]-close[buy_time] # 赚取利润
total_profit = total_profit + profit # 计算总利润 total_profit = total_profit + profit # 计算总利润
times = times + 1 # 买入卖出次数加1 times = times + 1 # 买入卖出次数加1
print('--------------第', i, '天(最后一天)卖出,持有天数:', sell_time-buy_time, '--------------\n') print('--------------第', i, '天(最后一天)卖出,持有天数:', sell_time-buy_time, '--------------\n')
sell_time_all = np.append(sell_time_all, [i], axis=0) # 卖出时间存档 sell_time_all = np.append(sell_time_all, [i], axis=0) # 卖出时间存档
return total_profit, profit_no_operation, times, invest_money, buy_time_all, sell_time_all return total_profit, profit_no_operation, times, invest_money, buy_time_all, sell_time_all
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,164 +1,164 @@
! This code is supported by the website: https://www.guanjihuan.com ! This code is supported by the website: https://www.guanjihuan.com
! The newest version of this code is on the web page: https://www.guanjihuan.com/archives/762 ! The newest version of this code is on the web page: https://www.guanjihuan.com/archives/762
module global ! module是用来封装程序模块的module里use调用即可 module global ! module是用来封装程序模块的module里use调用即可
implicit none implicit none
double precision sqrt3,Pi double precision sqrt3,Pi
parameter(sqrt3=1.7320508075688773d0,Pi=3.14159265358979324d0) ! parameter代表不能改的常数 parameter(sqrt3=1.7320508075688773d0,Pi=3.14159265358979324d0) ! parameter代表不能改的常数
end module global end module global
program main !program开始,end program结束Fortran里不区分大小写! program main !program开始,end program结束Fortran里不区分大小写!
use global use global
use f95_precision ! use f95_precision !
use blas95 ! gemm() use blas95 ! gemm()
use lapack95 !GETRF,GETRI和求本征矢和本征矢的GEEV等 use lapack95 !GETRF,GETRI和求本征矢和本征矢的GEEV等
implicit none ! implicit是用来设置默认类型implicit none是关闭默认类型功能 implicit none ! implicit是用来设置默认类型implicit none是关闭默认类型功能
integer i,j,info,index1(2) ! integer i,j,info,index1(2) !
double precision a(2,2),b(2,2),c(2,2),& ! && double precision a(2,2),b(2,2),c(2,2),& ! &&
x1, x2, result_1, result_2, fun1 ! x1, x2, result_1, result_2, fun1 !
complex*16 dd(2,2), eigenvalues(2) ! complex*16 dd(2,2), eigenvalues(2) !
complex*16, allocatable:: eigenvectors(:,:) ! ! :: complex*16, allocatable:: eigenvectors(:,:) ! ! ::
character(len=15) hello, number ! ,len是规定长度 character(len=15) hello, number ! ,len是规定长度
allocate(eigenvectors(2,2)) ! allocate(eigenvectors(2,2)) !
write(*,*) '----输出----' write(*,*) '----输出----'
hello='hello world' hello='hello world'
write(*,*) hello ! ** write(*,*) hello ! **
write(number,'(f7.3)') pi ! write可以把数字类型转成字符类型'(f7.3)'*'(i3)' write(number,'(f7.3)') pi ! write可以把数字类型转成字符类型'(f7.3)'*'(i3)'
write(*,*) '数字转成字符串后再输出:', number write(*,*) '数字转成字符串后再输出:', number
write(*,"(a,18x)",advance="no") hello ! advance='no'advance的时候'(a)'a15或者其他'(10x)' write(*,"(a,18x)",advance="no") hello ! advance='no'advance的时候'(a)'a15或者其他'(10x)'
write(*,*) number,'这是不换行输出测试' write(*,*) number,'这是不换行输出测试'
write(*,"('一些固定文字也可以写在这里面', a, a,//)") hello, number !"()" write(*,"('一些固定文字也可以写在这里面', a, a,//)") hello, number !"()"
!'(a)'a15或者其他'(/)' !'(a)'a15或者其他'(/)'
write(*,*) '----写入文件----' write(*,*) '----写入文件----'
open(unit=10,file='learn-fortran-test.txt') ! open open(unit=10,file='learn-fortran-test.txt') ! open
write(10,*) hello, number write(10,*) hello, number
close(10) ! close close(10) ! close
write(*,*) '' write(*,*) ''
write(*,*) '----矩阵乘积----' write(*,*) '----矩阵乘积----'
a(1,1)=2;a(1,2)=5;a(2,1)=3;a(2,2)=2 ! a(1,1)=2;a(1,2)=5;a(2,1)=3;a(2,2)=2 !
b(1,1)=3;b(2,2)=3 b(1,1)=3;b(2,2)=3
write(*,*) '矩阵直接默认输出,是按列的顺序一个个输出' write(*,*) '矩阵直接默认输出,是按列的顺序一个个输出'
write(*,*) 'a=' write(*,*) 'a='
write(*,*) a write(*,*) a
write(*,*) '矩阵格式化输出' write(*,*) '矩阵格式化输出'
write(*,*) 'a=' write(*,*) 'a='
do i=1,2 do i=1,2
do j=1,2 do j=1,2
write(*,'(f10.4)',advance='no') a(i,j) ! write(*,'(f10.4)',advance='no') a(i,j) !
enddo enddo
write(*,*) '' write(*,*) ''
enddo enddo
write(*,*) 'b=' write(*,*) 'b='
do i=1,2 do i=1,2
do j=1,2 do j=1,2
write(*,'(f10.4)',advance='no') b(i,j) ! write(*,'(f10.4)',advance='no') b(i,j) !
enddo enddo
write(*,*) '' write(*,*) ''
enddo enddo
call gemm(a,b,c) ! call gemm() call gemm(a,b,c) ! call gemm()
write(*,*) '矩阵乘积c=a*b=' write(*,*) '矩阵乘积c=a*b='
do i=1,2 do i=1,2
do j=1,2 do j=1,2
write(*,'(f10.4)',advance='no') c(i,j) ! write(*,'(f10.4)',advance='no') c(i,j) !
enddo enddo
write(*,*) '' write(*,*) ''
enddo enddo
write(*,*) '' write(*,*) ''
write(*,*) '----矩阵求逆----' write(*,*) '----矩阵求逆----'
call getrf(a,index1,info); call getri(a,index1,info) !getrf和getri要配合起来使用求逆 call getrf(a,index1,info); call getri(a,index1,info) !getrf和getri要配合起来使用求逆
! info是需定义为整型If info = 0, the execution is successful. ! info是需定义为整型If info = 0, the execution is successful.
! index1是在getrf产生getri里输入index1也是需要定义为整型 ! index1是在getrf产生getri里输入index1也是需要定义为整型
! a不再是原来的矩阵了 ! a不再是原来的矩阵了
do i=1,2 do i=1,2
do j=1,2 do j=1,2
write(*,'(f10.4)',advance='no') a(i,j) ! write(*,'(f10.4)',advance='no') a(i,j) !
enddo enddo
write(*,*) '' write(*,*) ''
enddo enddo
write(*,*) '----复数矩阵----' write(*,*) '----复数矩阵----'
dd(1,1)=(1.d0, 0.d0) dd(1,1)=(1.d0, 0.d0)
dd(1,2)=(7.d0, 0.d0) dd(1,2)=(7.d0, 0.d0)
dd(2,1)=(3.d0, 0.d0) dd(2,1)=(3.d0, 0.d0)
dd(2,2)=(2.d0, 0.d0) dd(2,2)=(2.d0, 0.d0)
do i=1,2 do i=1,2
do j=1,2 do j=1,2
write(*,"(f10.4, '+1i*',f7.4)",advance='no') dd(i,j) ! write(*,"(f10.4, '+1i*',f7.4)",advance='no') dd(i,j) !
enddo enddo
write(*,*) '' write(*,*) ''
enddo enddo
write(*,*) '' write(*,*) ''
write(*,*) '----矩阵本征矢和本征值----' write(*,*) '----矩阵本征矢和本征值----'
call geev(A=dd, W=eigenvalues, VR=eigenvectors, INFO=info) call geev(A=dd, W=eigenvalues, VR=eigenvectors, INFO=info)
! A矩阵最好用上复数W是本征值一维数组VR是本征矢二维数组INFO是整数 ! A矩阵最好用上复数W是本征值一维数组VR是本征矢二维数组INFO是整数
! dd的值会发生改变! ! dd的值会发生改变!
write(*,*) 'eigenvectors:' write(*,*) 'eigenvectors:'
do i=1,2 do i=1,2
do j=1,2 do j=1,2
write(*,"(f10.4, '+1i*',f7.4)",advance='no') eigenvectors(i,j) ! write(*,"(f10.4, '+1i*',f7.4)",advance='no') eigenvectors(i,j) !
enddo enddo
write(*,*) '' write(*,*) ''
enddo enddo
write(*,*) 'eigenvalues:' write(*,*) 'eigenvalues:'
do i=1,2 do i=1,2
write(*,"(f10.4, '+1i*',f7.4)",advance='no') eigenvalues(i) write(*,"(f10.4, '+1i*',f7.4)",advance='no') eigenvalues(i)
enddo enddo
write(*,*) '' write(*,*) ''
deallocate(eigenvectors) ! deallocate(eigenvectors) !
write(*,*) '' ! write(*,*) '' !
write(*,*) '----循环加判断----' write(*,*) '----循环加判断----'
do i=1,5 ! do到enddo do i=1,5 ! do到enddo
if (mod(i,2)==0) then ! if()then if (mod(i,2)==0) then ! if()then
write(*,*) '我是偶数', i write(*,*) '我是偶数', i
else if (i==3) then else if (i==3) then
write(*,*) '我是第3个数字也是奇数' write(*,*) '我是第3个数字也是奇数'
else else
write(*,*) '我是奇数', i write(*,*) '我是奇数', i
endif endif
enddo enddo
write(*,*) '' write(*,*) ''
call sub1(2.d0, 3.d0, result_1, result_2) ! 2.d02.0d022.0 call sub1(2.d0, 3.d0, result_1, result_2) ! 2.d02.0d022.0
write(*,*) '调用子程序,求和:',result_1 write(*,*) '调用子程序,求和:',result_1
write(*,*) '调用子程序,乘积:',result_2 write(*,*) '调用子程序,乘积:',result_2
write(*,*) '使用函数,返回减法结果:', fun1(2.d0, 3.d0) write(*,*) '使用函数,返回减法结果:', fun1(2.d0, 3.d0)
write(*,*) '' write(*,*) ''
end program end program
subroutine sub1(x1,x2,y1,y2) !call调用 subroutine sub1(x1,x2,y1,y2) !call调用
double precision,intent(in):: x1, x2 ! :: double precision,intent(in):: x1, x2 ! ::
double precision,intent(out):: y1, y2 double precision,intent(out):: y1, y2
! intent(in) intent(out) intent(inout) ! intent(in) intent(out) intent(inout)
! intent()intent(in) ! intent()intent(in)
y1=x1+x2 y1=x1+x2
y2=x1*x2 y2=x1*x2
end subroutine end subroutine
function fun1(x1,x2) ! subroutine function fun1(x1,x2) ! subroutine
double precision x1,x2,fun1 ! ( double precision x1,x2,fun1 ! (
fun1=x1-x2 ! fun1=x1-x2 !
return ! return也可以不写if配合用 return ! return也可以不写if配合用
end function ! end end function ! end

View File

@ -1,78 +1,78 @@
! This code is supported by the website: https://www.guanjihuan.com ! This code is supported by the website: https://www.guanjihuan.com
! The newest version of this code is on the web page: https://www.guanjihuan.com/archives/764 ! The newest version of this code is on the web page: https://www.guanjihuan.com/archives/764
program hello_open_mp program hello_open_mp
use omp_lib ! include 'omp_lib.h' use omp_lib ! include 'omp_lib.h'
integer mcpu,tid,total,N,i,j,loop integer mcpu,tid,total,N,i,j,loop
double precision starttime, endtime, time,result_0 double precision starttime, endtime, time,result_0
double precision, allocatable:: T(:) double precision, allocatable:: T(:)
N=5 ! do并行 N=5 ! do并行
loop=1000000000 !loop值 loop=1000000000 !loop值
allocate(T(N)) allocate(T(N))
!call OMP_SET_NUM_THREADS(2) !线 !call OMP_SET_NUM_THREADS(2) !线
total=OMP_GET_NUM_PROCS() ! total=OMP_GET_NUM_PROCS() !
print '(a,i2)', '计算机处理器数量:' , total !write(*,'(a,i2)') print '(a,i2)', '计算机处理器数量:' , total !write(*,'(a,i2)')
print '(a)', '-----在并行之前-----' print '(a)', '-----在并行之前-----'
tid=OMP_GET_THREAD_NUM() !线线 tid=OMP_GET_THREAD_NUM() !线线
mcpu=OMP_GET_NUM_THREADS() !线 mcpu=OMP_GET_NUM_THREADS() !线
print '(a,i2,a,i2)', '当前线程号:',tid,';总的线程数:', mcpu print '(a,i2,a,i2)', '当前线程号:',tid,';总的线程数:', mcpu
print * ! print * !
print'(a)','-----第一部分程序开始并行-----' print'(a)','-----第一部分程序开始并行-----'
!$OMP PARALLEL DEFAULT(PRIVATE) ! DEFAULT(PRIVATE) !$OMP PARALLEL DEFAULT(PRIVATE) ! DEFAULT(PRIVATE)
tid=OMP_GET_THREAD_NUM() !线线 tid=OMP_GET_THREAD_NUM() !线线
mcpu=OMP_GET_NUM_THREADS() !线 mcpu=OMP_GET_NUM_THREADS() !线
print '(a,i2,a,i2)', '当前线程号:',tid,';总的线程数:', mcpu print '(a,i2,a,i2)', '当前线程号:',tid,';总的线程数:', mcpu
!$OMP END PARALLEL !$OMP END PARALLEL
print * ! print * !
print'(a)','-----第二部分程序开始并行-----' print'(a)','-----第二部分程序开始并行-----'
starttime=OMP_GET_WTIME() ! starttime=OMP_GET_WTIME() !
!$OMP PARALLEL DO DEFAULT(PRIVATE) SHARED(T,N,loop) ! !$OMP PARALLEL DO DEFAULT(PRIVATE) SHARED(T,N,loop) !
do i=1,N !do循环体 do i=1,N !do循环体
result_0=0 result_0=0
tid=OMP_GET_THREAD_NUM() !线线 tid=OMP_GET_THREAD_NUM() !线线
mcpu=OMP_GET_NUM_THREADS() !线 mcpu=OMP_GET_NUM_THREADS() !线
do j=1,loop !~ do j=1,loop !~
result_0 = result_0+1 !~ result_0 = result_0+1 !~
enddo !~ enddo !~
T(i) = result_0-loop+i !线 T(i) = result_0-loop+i !线
!i代表各个循环的参数 !i代表各个循环的参数
print '(a,i2, a, f10.4,a,i2,a,i2 )', 'T(',i,')=', T(i) , ' 来源于线程号',tid,';总的线程数:', mcpu print '(a,i2, a, f10.4,a,i2,a,i2 )', 'T(',i,')=', T(i) , ' 来源于线程号',tid,';总的线程数:', mcpu
enddo enddo
!$OMP END PARALLEL DO ! !$OMP END PARALLEL DO !
endtime=OMP_GET_WTIME() ! endtime=OMP_GET_WTIME() !
time=endtime-starttime ! time=endtime-starttime !
print '(a, f13.5)' , '第二部分程序按并行计算所用的时间:', time print '(a, f13.5)' , '第二部分程序按并行计算所用的时间:', time
print * ! print * !
print'(a)','-----第二部分程序按串行的计算-----' print'(a)','-----第二部分程序按串行的计算-----'
starttime=OMP_GET_WTIME() ! starttime=OMP_GET_WTIME() !
do i=1,N do i=1,N
result_0=0 result_0=0
tid=OMP_GET_THREAD_NUM() !线线 tid=OMP_GET_THREAD_NUM() !线线
mcpu=OMP_GET_NUM_THREADS() !线 mcpu=OMP_GET_NUM_THREADS() !线
do j=1,loop do j=1,loop
result_0 = result_0+1 result_0 = result_0+1
enddo enddo
T(i) = result_0-loop+i T(i) = result_0-loop+i
print '(a,i2, a, f10.4,a,i2,a,i2 )', 'T(' ,i,')=', T(i) , ' 来源于线程号',tid,';总的线程数:', mcpu print '(a,i2, a, f10.4,a,i2,a,i2 )', 'T(' ,i,')=', T(i) , ' 来源于线程号',tid,';总的线程数:', mcpu
enddo enddo
endtime=OMP_GET_WTIME() ! endtime=OMP_GET_WTIME() !
time=endtime-starttime ! time=endtime-starttime !
print '(a, f13.5)' , '第二部分程序按串行计算所用的时间:', time print '(a, f13.5)' , '第二部分程序按串行计算所用的时间:', time
print * ! print * !
tid=OMP_GET_THREAD_NUM() !线线 tid=OMP_GET_THREAD_NUM() !线线
mcpu=OMP_GET_NUM_THREADS() !线 mcpu=OMP_GET_NUM_THREADS() !线
print '(a,i5,a,i5)', '当前线程号:',tid,';总的线程数:', mcpu print '(a,i5,a,i5)', '当前线程号:',tid,';总的线程数:', mcpu
print * ! print * !
end program hello_open_mp ! end, end program end program hello_open_mp ! end, end program