py.guanjihuan.com/PyPI/src/guan/figure_plotting.py
2024-07-12 01:39:42 +08:00

452 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Module: figure_plotting
# Import pyplot and start a new figure with a single axes (plt, fig, ax)
def import_plt_and_start_fig_ax(adjust_bottom=0.2, adjust_left=0.2, labelsize=20, fontfamily='Times New Roman'):
    """Import ``matplotlib.pyplot`` and create a figure with one gridded axes.

    Parameters
    ----------
    adjust_bottom, adjust_left : float
        Margins forwarded to ``plt.subplots_adjust(bottom=..., left=...)``.
    labelsize : int or float
        Font size of the tick labels.
    fontfamily : str
        Only the exact value ``'Times New Roman'`` has an effect: the tick
        labels are switched to that font. Any other value keeps the default.

    Returns
    -------
    tuple
        ``(plt, fig, ax)``: the pyplot module, the new figure and its axes.
    """
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    plt.subplots_adjust(bottom=adjust_bottom, left=adjust_left)
    ax.grid()
    ax.tick_params(labelsize=labelsize)
    if fontfamily == 'Times New Roman':
        # Done purely for the side effect, so a plain loop rather than a
        # throw-away list comprehension.
        # NOTE(review): only the tick labels existing at this moment are
        # changed; labels regenerated later (e.g. after new data rescales the
        # axes) may not keep this font -- confirm.
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontname('Times New Roman')
    return plt, fig, ax
# Plot one curve on an existing (plt, fig, ax) without starting a new figure
def plot_without_starting_fig(plt, fig, ax, x_array, y_array, xlabel='x', ylabel='y', title='', fontsize=20, style='', y_min=None, y_max=None, linewidth=None, markersize=None, color=None, fontfamily='Times New Roman'):
    """Draw ``y_array`` against ``x_array`` on ``ax`` and set title/labels.

    ``plt`` and ``fig`` are accepted for symmetry with
    ``import_plt_and_start_fig_ax`` but are not used here.

    Parameters
    ----------
    x_array, y_array : array-like
        Data passed to ``ax.plot``.
    xlabel, ylabel, title : str
        Axis labels and title, drawn with ``fontsize``.
    style : str
        Matplotlib format string (e.g. ``'-o'``).
    y_min, y_max : float, optional
        If either is given the y-range is fixed. A missing bound falls back
        to ``min(y_array)`` / ``max(y_array)``.
    linewidth, markersize : float, optional
        Forwarded to ``ax.plot``.
    color : optional
        Line colour; ``None`` lets matplotlib choose. Tested with ``is None``
        so array-like colours (e.g. a numpy RGB array) do not trigger an
        ambiguous-truth-value error.
    fontfamily : str
        ``'Times New Roman'`` applies that font to the title and labels;
        any other value uses the default font.
    """
    if color is None:
        ax.plot(x_array, y_array, style, linewidth=linewidth, markersize=markersize)
    else:
        ax.plot(x_array, y_array, style, linewidth=linewidth, markersize=markersize, color=color)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    if y_min is not None or y_max is not None:
        if y_min is None:
            y_min = min(y_array)
        if y_max is None:
            y_max = max(y_array)
        ax.set_ylim(y_min, y_max)
# Plot one curve in a new figure
def plot(x_array, y_array, xlabel='x', ylabel='y', title='', fontsize=20, labelsize=20, show=1, save=0, filename='a', file_format='.jpg', dpi=300, style='', y_min=None, y_max=None, linewidth=None, markersize=None, adjust_bottom=0.2, adjust_left=0.2, fontfamily='Times New Roman'):
    """Plot ``y_array`` against ``x_array`` in a new figure, then save/show it.

    Parameters
    ----------
    x_array, y_array : array-like
        Data passed to ``ax.plot``.
    xlabel, ylabel, title : str
        Axis labels and title, drawn with ``fontsize``.
    fontsize, labelsize : int or float
        Font size of the labels/title and of the tick labels.
    show, save : int
        ``1`` to call ``plt.show()`` / to save the figure; other values skip it.
    filename, file_format, dpi
        The figure is saved to ``filename + file_format`` at ``dpi``.
    style : str
        Matplotlib format string.
    y_min, y_max : float, optional
        If either is given the y-range is fixed. A missing bound falls back
        to ``min(y_array)`` / ``max(y_array)``.
    linewidth, markersize : float, optional
        Forwarded to ``ax.plot``.
    adjust_bottom, adjust_left : float
        Figure margins.
    fontfamily : str
        ``'Times New Roman'`` applies that font; other values use the default.

    Notes
    -----
    The figure is saved before it is shown. ``plt.close('all')`` runs at the
    end, which closes every open figure, not only this one.
    """
    import guan
    plt, fig, ax = guan.import_plt_and_start_fig_ax(adjust_bottom=adjust_bottom, adjust_left=adjust_left, labelsize=labelsize, fontfamily=fontfamily)
    ax.plot(x_array, y_array, style, linewidth=linewidth, markersize=markersize)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    if y_min is not None or y_max is not None:
        if y_min is None:
            y_min = min(y_array)
        if y_max is None:
            y_max = max(y_array)
        ax.set_ylim(y_min, y_max)
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Plot two data sets that share one x array
def plot_two_array(x_array, y1_array, y2_array, xlabel='x', ylabel='y', title='', fontsize=20, labelsize=20, show=1, save=0, filename='a', file_format='.jpg', dpi=300, style_1='', style_2='', y_min=None, y_max=None, linewidth_1=None, linewidth_2=None, markersize_1=None, markersize_2=None, adjust_bottom=0.2, adjust_left=0.2, fontfamily='Times New Roman'):
    """Plot ``y1_array`` and ``y2_array`` against one ``x_array`` in a new figure.

    ``style_k``, ``linewidth_k`` and ``markersize_k`` configure curve ``k``.
    If ``y_min`` or ``y_max`` is given the y-range is fixed; a missing bound
    falls back to the min/max over both curves. ``save == 1`` writes
    ``filename + file_format`` at ``dpi`` (before showing), ``show == 1``
    calls ``plt.show()``, and ``plt.close('all')`` runs at the end (closing
    every open figure). ``fontfamily='Times New Roman'`` applies that font;
    other values use the default.
    """
    import guan
    plt, fig, ax = guan.import_plt_and_start_fig_ax(adjust_bottom=adjust_bottom, adjust_left=adjust_left, labelsize=labelsize, fontfamily=fontfamily)
    ax.plot(x_array, y1_array, style_1, linewidth=linewidth_1, markersize=markersize_1)
    ax.plot(x_array, y2_array, style_2, linewidth=linewidth_2, markersize=markersize_2)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    if y_min is not None or y_max is not None:
        if y_min is None:
            y_min = min(min(y1_array), min(y2_array))
        if y_max is None:
            y_max = max(max(y1_array), max(y2_array))
        ax.set_ylim(y_min, y_max)
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Plot two data sets, each with its own x array
def plot_two_array_with_two_horizontal_array(x1_array, x2_array, y1_array, y2_array, xlabel='x', ylabel='y', title='', fontsize=20, labelsize=20, show=1, save=0, filename='a', file_format='.jpg', dpi=300, style_1='', style_2='', y_min=None, y_max=None, linewidth_1=None, linewidth_2=None, markersize_1=None, markersize_2=None, adjust_bottom=0.2, adjust_left=0.2, fontfamily='Times New Roman'):
    """Plot ``y1_array`` vs ``x1_array`` and ``y2_array`` vs ``x2_array`` in a new figure.

    ``style_k``, ``linewidth_k`` and ``markersize_k`` configure curve ``k``.
    If ``y_min`` or ``y_max`` is given the y-range is fixed; a missing bound
    falls back to the min/max over both y arrays. ``save == 1`` writes
    ``filename + file_format`` at ``dpi`` (before showing), ``show == 1``
    calls ``plt.show()``, and ``plt.close('all')`` runs at the end (closing
    every open figure). ``fontfamily='Times New Roman'`` applies that font;
    other values use the default.
    """
    import guan
    plt, fig, ax = guan.import_plt_and_start_fig_ax(adjust_bottom=adjust_bottom, adjust_left=adjust_left, labelsize=labelsize, fontfamily=fontfamily)
    ax.plot(x1_array, y1_array, style_1, linewidth=linewidth_1, markersize=markersize_1)
    ax.plot(x2_array, y2_array, style_2, linewidth=linewidth_2, markersize=markersize_2)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    if y_min is not None or y_max is not None:
        if y_min is None:
            y_min = min(min(y1_array), min(y2_array))
        if y_max is None:
            y_max = max(max(y1_array), max(y2_array))
        ax.set_ylim(y_min, y_max)
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Plot three data sets that share one x array
def plot_three_array(x_array, y1_array, y2_array, y3_array, xlabel='x', ylabel='y', title='', fontsize=20, labelsize=20, show=1, save=0, filename='a', file_format='.jpg', dpi=300, style_1='', style_2='', style_3='', y_min=None, y_max=None, linewidth_1=None, linewidth_2=None, linewidth_3=None, markersize_1=None, markersize_2=None, markersize_3=None, adjust_bottom=0.2, adjust_left=0.2, fontfamily='Times New Roman'):
    """Plot ``y1_array``, ``y2_array`` and ``y3_array`` against one ``x_array``.

    ``style_k``, ``linewidth_k`` and ``markersize_k`` configure curve ``k``.
    If ``y_min`` or ``y_max`` is given the y-range is fixed; a missing bound
    falls back to the min/max over all three curves. ``save == 1`` writes
    ``filename + file_format`` at ``dpi`` (before showing), ``show == 1``
    calls ``plt.show()``, and ``plt.close('all')`` runs at the end (closing
    every open figure). ``fontfamily='Times New Roman'`` applies that font;
    other values use the default.
    """
    import guan
    plt, fig, ax = guan.import_plt_and_start_fig_ax(adjust_bottom=adjust_bottom, adjust_left=adjust_left, labelsize=labelsize, fontfamily=fontfamily)
    ax.plot(x_array, y1_array, style_1, linewidth=linewidth_1, markersize=markersize_1)
    ax.plot(x_array, y2_array, style_2, linewidth=linewidth_2, markersize=markersize_2)
    ax.plot(x_array, y3_array, style_3, linewidth=linewidth_3, markersize=markersize_3)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    if y_min is not None or y_max is not None:
        if y_min is None:
            y_min = min(min(y1_array), min(y2_array), min(y3_array))
        if y_max is None:
            y_max = max(max(y1_array), max(y2_array), max(y3_array))
        ax.set_ylim(y_min, y_max)
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Plot three data sets, each with its own x array
def plot_three_array_with_three_horizontal_array(x1_array, x2_array, x3_array, y1_array, y2_array, y3_array, xlabel='x', ylabel='y', title='', fontsize=20, labelsize=20, show=1, save=0, filename='a', file_format='.jpg', dpi=300, style_1='', style_2='', style_3='', y_min=None, y_max=None, linewidth_1=None, linewidth_2=None, linewidth_3=None, markersize_1=None, markersize_2=None, markersize_3=None, adjust_bottom=0.2, adjust_left=0.2, fontfamily='Times New Roman'):
    """Plot ``y_k_array`` against ``x_k_array`` for k = 1, 2, 3 in a new figure.

    ``style_k``, ``linewidth_k`` and ``markersize_k`` configure curve ``k``.
    If ``y_min`` or ``y_max`` is given the y-range is fixed; a missing bound
    falls back to the min/max over all three y arrays. ``save == 1`` writes
    ``filename + file_format`` at ``dpi`` (before showing), ``show == 1``
    calls ``plt.show()``, and ``plt.close('all')`` runs at the end (closing
    every open figure). ``fontfamily='Times New Roman'`` applies that font;
    other values use the default.
    """
    import guan
    plt, fig, ax = guan.import_plt_and_start_fig_ax(adjust_bottom=adjust_bottom, adjust_left=adjust_left, labelsize=labelsize, fontfamily=fontfamily)
    ax.plot(x1_array, y1_array, style_1, linewidth=linewidth_1, markersize=markersize_1)
    ax.plot(x2_array, y2_array, style_2, linewidth=linewidth_2, markersize=markersize_2)
    ax.plot(x3_array, y3_array, style_3, linewidth=linewidth_3, markersize=markersize_3)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    if y_min is not None or y_max is not None:
        if y_min is None:
            y_min = min(min(y1_array), min(y2_array), min(y3_array))
        if y_max is None:
            y_max = max(max(y1_array), max(y2_array), max(y3_array))
        ax.set_ylim(y_min, y_max)
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Draw one or more 3D surfaces with a colour bar
def plot_3d_surface(x_array, y_array, matrix, xlabel='x', ylabel='y', zlabel='z', title='', fontsize=20, labelsize=15, show=1, save=0, filename='a', file_format='.jpg', dpi=300, z_min=None, z_max=None, rcount=100, ccount=100, fontfamily='Times New Roman'):
    """Draw z(x, y) surface(s) on a 3D axes with a colour bar.

    Parameters
    ----------
    x_array, y_array : array-like
        1-D coordinates combined with ``np.meshgrid(x_array, y_array)``.
    matrix : array-like
        2-D: one surface. 3-D: one surface per index of the last axis
        (``matrix[:, :, i]``). Each surface is expected to match the
        ``(len(y_array), len(x_array))`` grid produced by ``np.meshgrid``.
    z_min, z_max : float, optional
        If either is given the z-range is fixed. A missing bound falls back
        to ``matrix.min()`` / ``matrix.max()`` over all surfaces.
    rcount, ccount : int
        Sampling counts forwarded to ``ax.plot_surface``.
    show, save, filename, file_format, dpi
        ``save == 1`` writes ``filename + file_format`` at ``dpi`` (before
        showing); ``show == 1`` calls ``plt.show()``. ``plt.close('all')``
        runs at the end.
    fontfamily : str
        ``'Times New Roman'`` applies that font to labels and ticks.

    Raises
    ------
    ValueError
        If ``matrix`` is neither 2-D nor 3-D with at least one layer. This is
        checked before any figure is created; previously it surfaced later as
        an UnboundLocalError on the colour-bar mappable.
    """
    import numpy as np
    matrix = np.array(matrix)
    if matrix.ndim == 2:
        layers = [matrix]
    elif matrix.ndim == 3:
        layers = [matrix[:, :, i0] for i0 in range(matrix.shape[2])]
    else:
        layers = []
    if not layers:
        raise ValueError('matrix must be 2-D, or 3-D with at least one layer along the last axis')
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.ticker import LinearLocator
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    plt.subplots_adjust(bottom=0.1, right=0.65)
    x_grid, y_grid = np.meshgrid(x_array, y_array)
    for layer in layers:
        # The last surface drawn becomes the mappable for the colour bar.
        surf = ax.plot_surface(x_grid, y_grid, layer, rcount=rcount, ccount=ccount, cmap=cm.coolwarm, linewidth=0, antialiased=False)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    ax.set_zlabel(zlabel, fontsize=fontsize, **font_kwargs)
    ax.zaxis.set_major_locator(LinearLocator(5))
    ax.zaxis.set_major_formatter('{x:.2f}')
    if z_min is not None or z_max is not None:
        if z_min is None:
            z_min = matrix.min()
        if z_max is None:
            z_max = matrix.max()
        ax.set_zlim(z_min, z_max)
    ax.tick_params(labelsize=labelsize)
    if fontfamily == 'Times New Roman':
        for label in ax.get_xticklabels() + ax.get_yticklabels() + ax.get_zticklabels():
            label.set_fontname('Times New Roman')
    # Dedicated axes on the right (figure-fraction rectangle) for the colour bar.
    cax = plt.axes([0.8, 0.1, 0.05, 0.8])
    cbar = fig.colorbar(surf, cax=cax)
    cbar.ax.tick_params(labelsize=labelsize)
    if fontfamily == 'Times New Roman':
        for tick_label in cbar.ax.yaxis.get_ticklabels():
            tick_label.set_family('Times New Roman')
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Draw a filled contour plot with a colour bar
def plot_contour(x_array, y_array, matrix, xlabel='x', ylabel='y', title='', fontsize=20, labelsize=15, cmap='jet', levels=None, show=1, save=0, filename='a', file_format='.jpg', dpi=300, fontfamily='Times New Roman'):
    """Draw a filled contour plot of ``matrix`` over the (x, y) grid.

    Parameters
    ----------
    x_array, y_array : array-like
        1-D coordinates combined with ``np.meshgrid(x_array, y_array)``.
    matrix : array-like
        Values; expected to match the ``(len(y_array), len(x_array))`` grid
        produced by ``np.meshgrid``.
    cmap : str
        Colour map passed to ``ax.contourf``.
    levels : optional
        Forwarded to ``ax.contourf``.
    show, save, filename, file_format, dpi
        ``save == 1`` writes ``filename + file_format`` at ``dpi`` (before
        showing); ``show == 1`` calls ``plt.show()``. ``plt.close('all')``
        runs at the end.
    fontfamily : str
        ``'Times New Roman'`` applies that font to labels and ticks.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    plt.subplots_adjust(bottom=0.2, right=0.75, left=0.2)
    x_grid, y_grid = np.meshgrid(x_array, y_array)
    contour = ax.contourf(x_grid, y_grid, matrix, cmap=cmap, levels=levels)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    ax.tick_params(labelsize=labelsize)
    if fontfamily == 'Times New Roman':
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontname('Times New Roman')
    # Dedicated axes on the right (figure-fraction rectangle) for the colour bar.
    cax = plt.axes([0.8, 0.2, 0.05, 0.68])
    cbar = fig.colorbar(contour, cax=cax)
    cbar.ax.tick_params(labelsize=labelsize)
    if fontfamily == 'Times New Roman':
        for tick_label in cbar.ax.yaxis.get_ticklabels():
            tick_label.set_family('Times New Roman')
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Draw a pseudocolour (checkerboard) plot with a colour bar
def plot_pcolor(x_array, y_array, matrix, xlabel='x', ylabel='y', title='', fontsize=20, labelsize=15, cmap='jet', levels=None, show=1, save=0, filename='a', file_format='.jpg', dpi=300, fontfamily='Times New Roman'):
    """Draw ``matrix`` as a pseudocolour plot with ``ax.pcolor``.

    Parameters
    ----------
    x_array, y_array : array-like
        1-D coordinates combined with ``np.meshgrid(x_array, y_array)``.
    matrix : array-like
        Values passed to ``ax.pcolor``.
    cmap : str
        Colour map.
    levels : optional
        Accepted but unused (``pcolor`` takes no levels); kept so the
        signature matches ``plot_contour``.
    show, save, filename, file_format, dpi
        ``save == 1`` writes ``filename + file_format`` at ``dpi`` (before
        showing); ``show == 1`` calls ``plt.show()``. ``plt.close('all')``
        runs at the end.
    fontfamily : str
        ``'Times New Roman'`` applies that font to labels and ticks.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    plt.subplots_adjust(bottom=0.2, right=0.75, left=0.2)
    x_grid, y_grid = np.meshgrid(x_array, y_array)
    mesh = ax.pcolor(x_grid, y_grid, matrix, cmap=cmap)
    font_kwargs = {'fontfamily': 'Times New Roman'} if fontfamily == 'Times New Roman' else {}
    ax.set_title(title, fontsize=fontsize, **font_kwargs)
    ax.set_xlabel(xlabel, fontsize=fontsize, **font_kwargs)
    ax.set_ylabel(ylabel, fontsize=fontsize, **font_kwargs)
    ax.tick_params(labelsize=labelsize)
    if fontfamily == 'Times New Roman':
        for label in ax.get_xticklabels() + ax.get_yticklabels():
            label.set_fontname('Times New Roman')
    # Dedicated axes on the right (figure-fraction rectangle) for the colour bar.
    cax = plt.axes([0.8, 0.2, 0.05, 0.68])
    cbar = fig.colorbar(mesh, cax=cax)
    cbar.ax.tick_params(labelsize=labelsize)
    if fontfamily == 'Times New Roman':
        for tick_label in cbar.ax.yaxis.get_ticklabels():
            tick_label.set_family('Times New Roman')
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Draw dots, and lines between nearby dots, on an existing (plt, fig, ax)
def draw_dots_and_lines_without_starting_fig(plt, fig, ax, coordinate_array, draw_dots=1, draw_lines=1, max_distance=1, line_style='-k', linewidth=1, dot_style='ro', markersize=3):
    """Draw points and connect every pair no farther apart than ``max_distance``.

    ``plt`` and ``fig`` are accepted for symmetry with the other helpers but
    are not used here.

    Parameters
    ----------
    coordinate_array : array-like
        One point per row; columns 0 and 1 are used as x and y.
    draw_dots, draw_lines : int
        ``1`` to draw the points / the connecting segments.
    max_distance : float
        Two points are joined when their Euclidean distance is
        ``<= max_distance``.
    line_style, linewidth
        Format string and width of the segments.
    dot_style, markersize
        Format string and marker size of the points.
    """
    import numpy as np
    coordinate_array = np.array(coordinate_array)
    num_points = coordinate_array.shape[0]
    if draw_lines == 1:
        # Each unordered pair is visited once (i2 > i1). A full double loop
        # would draw every segment twice and also add a zero-length segment
        # from each point to itself.
        for i1 in range(num_points):
            x1, y1 = coordinate_array[i1, 0], coordinate_array[i1, 1]
            for i2 in range(i1 + 1, num_points):
                x2, y2 = coordinate_array[i2, 0], coordinate_array[i2, 1]
                # Same sqrt-of-squares formula as before, so borderline
                # distances (e.g. lattice bonds exactly at max_distance)
                # compare identically.
                if np.sqrt((x1 - x2)**2 + (y1 - y2)**2) <= max_distance:
                    ax.plot([x1, x2], [y1, y2], line_style, linewidth=linewidth)
    if draw_dots == 1:
        # Drawn after the lines, presumably so the dots render on top.
        for i in range(num_points):
            ax.plot(coordinate_array[i, 0], coordinate_array[i, 1], dot_style, markersize=markersize)
# Draw dots, and lines between nearby dots, in a new axis-free figure
def draw_dots_and_lines(coordinate_array, draw_dots=1, draw_lines=1, max_distance=1, line_style='-k', linewidth=1, dot_style='ro', markersize=3, show=1, save=0, filename='a', file_format='.eps', dpi=300):
    """Draw points and their short-range connections in a new figure.

    The figure has no axes decorations and equal x/y scale. Its height is
    6 inches and its width is scaled by the x/y extent of the points.

    Parameters
    ----------
    coordinate_array : array-like
        One point per row; columns 0 and 1 are used as x and y.
    draw_dots, draw_lines : int
        ``1`` to draw the points / the connecting segments.
    max_distance : float
        Two points are joined when their Euclidean distance is
        ``<= max_distance``.
    line_style, linewidth, dot_style, markersize
        Styling of segments and points.
    show, save : int
        ``1`` to show / to save the figure.
    filename, file_format, dpi
        Output path is ``filename + file_format``. ``dpi`` is not passed for
        ``'.eps'``.

    Notes
    -----
    The figure is saved *before* ``plt.show()``. With interactive backends,
    ``show`` blocks until the window is closed, and saving afterwards can
    produce a blank file. The figure is not closed at the end.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    coordinate_array = np.array(coordinate_array)
    x_range = coordinate_array[:, 0].max() - coordinate_array[:, 0].min()
    y_range = coordinate_array[:, 1].max() - coordinate_array[:, 1].min()
    if x_range > 0 and y_range > 0:
        figsize = (6*x_range/y_range, 6)
    else:
        # A single point, or all points on one horizontal/vertical line, would
        # give a width of 0, inf or nan, none of which is a valid figure size.
        figsize = (6, 6)
    _, ax = plt.subplots(figsize=figsize)
    ax.set_aspect('equal')  # keep the same scale on the x and y axes
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)
    plt.axis('off')
    num_points = coordinate_array.shape[0]
    if draw_lines == 1:
        # Each unordered pair is visited once (i2 > i1): no duplicate segments
        # and no zero-length self-segments.
        for i1 in range(num_points):
            x1, y1 = coordinate_array[i1, 0], coordinate_array[i1, 1]
            for i2 in range(i1 + 1, num_points):
                x2, y2 = coordinate_array[i2, 0], coordinate_array[i2, 1]
                if np.sqrt((x1 - x2)**2 + (y1 - y2)**2) <= max_distance:
                    ax.plot([x1, x2], [y1, y2], line_style, linewidth=linewidth)
    if draw_dots == 1:
        for i in range(num_points):
            ax.plot(coordinate_array[i, 0], coordinate_array[i, 1], dot_style, markersize=markersize)
    if save == 1:
        if file_format == '.eps':
            plt.savefig(filename+file_format)
        else:
            plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
# Combine two images side by side into one figure
def combine_two_images(image_path_array, figsize=(16,8), show=0, save=1, filename='a', file_format='.jpg', dpi=300):
    """Place two image files side by side (1 x 2 grid) with no margins or axes.

    Parameters
    ----------
    image_path_array : sequence
        Exactly two image paths. Otherwise an error message is printed and
        nothing is drawn.
    figsize : tuple
        Figure size in inches.
    show, save : int
        ``1`` to show / to save the combined figure.
    filename, file_format, dpi
        Output path is ``filename + file_format``.

    Notes
    -----
    The figure is saved *before* ``plt.show()``; saving after an interactive
    ``show`` can produce a blank image. ``plt.close('all')`` runs at the end.
    """
    if len(image_path_array) != 2:
        print('Error: The number of images should be two!')
        return
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    fig = plt.figure(figsize=figsize)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
    for position, image_path in enumerate(image_path_array, start=1):
        ax = fig.add_subplot(1, 2, position)
        ax.imshow(mpimg.imread(image_path))
        ax.axis('off')
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Combine three images side by side into one figure
def combine_three_images(image_path_array, figsize=(16,5), show=0, save=1, filename='a', file_format='.jpg', dpi=300):
    """Place three image files side by side (1 x 3 grid) with no margins or axes.

    Parameters
    ----------
    image_path_array : sequence
        Exactly three image paths. Otherwise an error message is printed and
        nothing is drawn.
    figsize : tuple
        Figure size in inches.
    show, save : int
        ``1`` to show / to save the combined figure.
    filename, file_format, dpi
        Output path is ``filename + file_format``.

    Notes
    -----
    The figure is saved *before* ``plt.show()``; saving after an interactive
    ``show`` can produce a blank image. ``plt.close('all')`` runs at the end.
    """
    if len(image_path_array) != 3:
        print('Error: The number of images should be three!')
        return
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    fig = plt.figure(figsize=figsize)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
    for position, image_path in enumerate(image_path_array, start=1):
        ax = fig.add_subplot(1, 3, position)
        ax.imshow(mpimg.imread(image_path))
        ax.axis('off')
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Combine four images into a 2 x 2 grid
def combine_four_images(image_path_array, figsize=(16,16), show=0, save=1, filename='a', file_format='.jpg', dpi=300):
    """Place four image files in a 2 x 2 grid (row-major) with no margins or axes.

    Parameters
    ----------
    image_path_array : sequence
        Exactly four image paths. Otherwise an error message is printed and
        nothing is drawn.
    figsize : tuple
        Figure size in inches.
    show, save : int
        ``1`` to show / to save the combined figure.
    filename, file_format, dpi
        Output path is ``filename + file_format``.

    Notes
    -----
    The figure is saved *before* ``plt.show()``; saving after an interactive
    ``show`` can produce a blank image. ``plt.close('all')`` runs at the end.
    """
    if len(image_path_array) != 4:
        print('Error: The number of images should be four!')
        return
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    fig = plt.figure(figsize=figsize)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
    for position, image_path in enumerate(image_path_array, start=1):
        ax = fig.add_subplot(2, 2, position)
        ax.imshow(mpimg.imread(image_path))
        ax.axis('off')
    if save == 1:
        plt.savefig(filename+file_format, dpi=dpi)
    if show == 1:
        plt.show()
    plt.close('all')
# Batch-read the .txt files under a directory and plot each one
def batch_reading_and_plotting(directory, xlabel='x', ylabel='y'):
    """Read every ``*.txt`` file under ``directory`` (recursively) and save a plot of it.

    Each file is read with ``guan.read_one_dimensional_data`` and plotted
    with ``guan.plot(..., show=0, save=1)``. The plot's title and output name
    are the file's base name without ``.txt``.

    NOTE(review): the plot is saved under the bare base name, i.e. relative
    to the current working directory, so same-named files in different
    sub-directories overwrite each other's plots.
    """
    import os
    import guan
    for root, _dirs, files in os.walk(directory):
        for file in files:
            # Literal '.txt' suffix. The old test re.search('^txt.', file[::-1])
            # left '.' unescaped, so names like 'notestxt' also matched.
            if file.endswith('.txt'):
                filename = file[:-4]
                # Read from the directory os.walk found the file in, not the
                # CWD. Assumes read_one_dimensional_data appends '.txt' to
                # `filename` (as the extension-less call implies) -- confirm.
                x_array, y_array = guan.read_one_dimensional_data(filename=os.path.join(root, filename))
                guan.plot(x_array, y_array, xlabel=xlabel, ylabel=ylabel, title=filename, show=0, save=1, filename=filename)
# Turn a sequence of images into a GIF animation
def make_gif(image_path_array, filename='a', duration=0.1):
    """Read each image in ``image_path_array`` and write them, in order, to ``filename + '.gif'``.

    ``duration`` is forwarded unchanged to ``imageio.mimsave``.
    NOTE(review): whether ``duration`` is in seconds or milliseconds depends
    on the installed imageio version / GIF plugin -- confirm.
    """
    import imageio
    frames = [imageio.imread(path) for path in image_path_array]
    imageio.mimsave(filename + '.gif', frames, 'GIF', duration=duration)
# Pick Matplotlib colours
def color_matplotlib():
    """Return the ten matplotlib ``'tab:*'`` colour names as a new list.

    A fresh list is built on each call, so callers may modify it freely.
    """
    base_names = ('blue', 'orange', 'green', 'red', 'purple',
                  'brown', 'pink', 'gray', 'olive', 'cyan')
    return ['tab:' + name for name in base_names]