0.1.96
This commit is contained in:
parent
e7208f38c0
commit
194b90a0d8
@ -1,7 +1,7 @@
|
||||
[metadata]
|
||||
# replace with your username:
|
||||
name = guan
|
||||
version = 0.1.95
|
||||
version = 0.1.96
|
||||
author = guanjihuan
|
||||
author_email = guanjihuan@163.com
|
||||
description = An open source python package
|
||||
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: guan
|
||||
Version: 0.1.95
|
||||
Version: 0.1.96
|
||||
Summary: An open source python package
|
||||
Home-page: https://py.guanjihuan.com
|
||||
Author: guanjihuan
|
||||
|
@ -221,6 +221,29 @@ def load_train_data(x_train, y_train, batch_size=32):
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
return train_loader
|
||||
|
||||
# 从pickle文件中读取输入数据和输出数据,用于训练或预测
|
||||
def load_input_data_and_output_data_as_torch_tensors_with_pickle(index_range=[1, 2, 3], directory='./', input_filename='input_index=', output_filename='output_index=', type=None):
|
||||
import guan
|
||||
import numpy as np
|
||||
import torch
|
||||
input_data = []
|
||||
for index in index_range:
|
||||
input = guan.load_data(filename=directory+input_filename+str(index))
|
||||
input_data.append(input)
|
||||
output_data = []
|
||||
for index in index_range:
|
||||
output = guan.load_data(filename=directory+output_filename+str(index))
|
||||
output_data.append(output)
|
||||
if type == None:
|
||||
input_data = np.array(input_data)
|
||||
output_data= np.array(output_data)
|
||||
else:
|
||||
input_data = np.array(input_data).astype(type)
|
||||
output_data= np.array(output_data).astype(type)
|
||||
input_data = torch.from_numpy(input_data)
|
||||
output_data = torch.from_numpy(output_data)
|
||||
return input_data, output_data
|
||||
|
||||
# 数据的主成分分析PCA
|
||||
def pca_of_data(data, n_components=None, standard=1):
|
||||
from sklearn.decomposition import PCA
|
||||
|
Loading…
x
Reference in New Issue
Block a user