This commit is contained in:
guanjihuan 2024-01-09 10:58:24 +08:00
parent 42054690a2
commit aca7b1ebe5
3 changed files with 39 additions and 5 deletions

View File

@ -1,7 +1,7 @@
[metadata]
# replace with your username:
name = guan
version = 0.1.73
version = 0.1.74
author = guanjihuan
author_email = guanjihuan@163.com
description = An open source python package

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: guan
Version: 0.1.73
Version: 0.1.74
Summary: An open source python package
Home-page: https://py.guanjihuan.com
Author: guanjihuan

View File

@ -131,7 +131,6 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr
elif criterion == 'CrossEntropyLoss':
criterion = torch.nn.CrossEntropyLoss()
losses = []
for epoch in range(num_epochs):
output = model.forward(x_data)
loss = criterion(output, y_data)
@ -141,7 +140,34 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr
losses.append(loss.item())
if print_show == 1:
if (epoch + 1) % 100 == 0:
print(epoch)
print(epoch, loss.item())
return model, losses
# 使用优化器批量训练模型
@guan.statistics_decorator
def batch_train_model(model, train_loader, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1):
import torch
if optimizer == 'Adam':
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
elif optimizer == 'SGD':
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
if criterion == 'MSELoss':
criterion = torch.nn.MSELoss()
elif criterion == 'CrossEntropyLoss':
criterion = torch.nn.CrossEntropyLoss()
losses = []
for epoch in range(num_epochs):
for batch_x, batch_y in train_loader:
output = model.forward(batch_x)
loss = criterion(output, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
if print_show == 1:
if (epoch + 1) % 100 == 0:
print(epoch, loss.item())
return model, losses
# 保存模型参数到文件
@ -168,4 +194,12 @@ def load_model_parameters(model, filename='./model_parameters.pth'):
def load_model(filename='./model.pth'):
import torch
model = torch.load(filename)
return model
return model
# 加载训练数据,用于批量加载训练
@guan.statistics_decorator
def load_train_data(x_train, y_train, batch_size=32):
from torch.utils.data import DataLoader, TensorDataset
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
return train_loader