This commit is contained in:
guanjihuan 2024-06-03 16:18:33 +08:00
parent cd876316e2
commit fe2d947179
3 changed files with 45 additions and 0 deletions

View File

@ -0,0 +1,14 @@
from torch.utils.data import DataLoader, TensorDataset
import torch
for i0 in range(5):
x_train = torch.randn(100, 20) # 小文件加载
y_train = torch.randn(100, 1) # 小文件加载
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
for batch_x, batch_y in train_loader:
print(batch_x.shape)
print(batch_y.shape)
if i0 == 0:
print('Training model...')
else:
print('Continue training model...')

View File

@ -0,0 +1,9 @@
from torch.utils.data import DataLoader, TensorDataset
import torch
x_train = torch.randn(500, 20) # 全部数据加载
y_train = torch.randn(500, 1) # 全部数据加载
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
for batch_x, batch_y in train_loader:
print(batch_x.shape)
print(batch_y.shape)

View File

@ -0,0 +1,22 @@
import torch
def load_data_from_file(filename):
x_train = torch.randn(100, 20) # 小文件加载
y_train = torch.randn(100, 1) # 小文件加载
return x_train, y_train
x_train_new = [] # 新的小文件
y_train_new = [] # 新的小文件
n = 5
for i0_new in range(n):
for i0 in range(n):
x_train, y_train = load_data_from_file(filename=str(i0))
if i0 == 0:
x_train_new = x_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]
y_train_new = y_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]
else:
x_train_new = torch.cat((x_train_new, x_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]), dim=0)
y_train_new = torch.cat((y_train_new, y_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]), dim=0)
print(x_train_new.shape)
print(y_train_new.shape)
print('Save new file!')