update
This commit is contained in:
@@ -0,0 +1,14 @@
|
|||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
import torch
|
||||||
|
for i0 in range(5):
|
||||||
|
x_train = torch.randn(100, 20) # 小文件加载
|
||||||
|
y_train = torch.randn(100, 1) # 小文件加载
|
||||||
|
train_dataset = TensorDataset(x_train, y_train)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||||
|
for batch_x, batch_y in train_loader:
|
||||||
|
print(batch_x.shape)
|
||||||
|
print(batch_y.shape)
|
||||||
|
if i0 == 0:
|
||||||
|
print('Training model...')
|
||||||
|
else:
|
||||||
|
print('Continue training model...')
|
@@ -0,0 +1,9 @@
|
|||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
import torch
|
||||||
|
x_train = torch.randn(500, 20) # 全部数据加载
|
||||||
|
y_train = torch.randn(500, 1) # 全部数据加载
|
||||||
|
train_dataset = TensorDataset(x_train, y_train)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||||
|
for batch_x, batch_y in train_loader:
|
||||||
|
print(batch_x.shape)
|
||||||
|
print(batch_y.shape)
|
@@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def load_data_from_file(filename):
|
||||||
|
x_train = torch.randn(100, 20) # 小文件加载
|
||||||
|
y_train = torch.randn(100, 1) # 小文件加载
|
||||||
|
return x_train, y_train
|
||||||
|
|
||||||
|
x_train_new = [] # 新的小文件
|
||||||
|
y_train_new = [] # 新的小文件
|
||||||
|
n = 5
|
||||||
|
for i0_new in range(n):
|
||||||
|
for i0 in range(n):
|
||||||
|
x_train, y_train = load_data_from_file(filename=str(i0))
|
||||||
|
if i0 == 0:
|
||||||
|
x_train_new = x_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]
|
||||||
|
y_train_new = y_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]
|
||||||
|
else:
|
||||||
|
x_train_new = torch.cat((x_train_new, x_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]), dim=0)
|
||||||
|
y_train_new = torch.cat((y_train_new, y_train[i0_new*int(100/n):(i0_new+1)*int(100/n), :]), dim=0)
|
||||||
|
print(x_train_new.shape)
|
||||||
|
print(y_train_new.shape)
|
||||||
|
print('Save new file!')
|
Reference in New Issue
Block a user