0.1.74
This commit is contained in:
parent
42054690a2
commit
aca7b1ebe5
@ -1,7 +1,7 @@
|
||||
[metadata]
|
||||
# replace with your username:
|
||||
name = guan
|
||||
version = 0.1.73
|
||||
version = 0.1.74
|
||||
author = guanjihuan
|
||||
author_email = guanjihuan@163.com
|
||||
description = An open source python package
|
||||
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: guan
|
||||
Version: 0.1.73
|
||||
Version: 0.1.74
|
||||
Summary: An open source python package
|
||||
Home-page: https://py.guanjihuan.com
|
||||
Author: guanjihuan
|
||||
|
@ -131,7 +131,6 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr
|
||||
elif criterion == 'CrossEntropyLoss':
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
losses = []
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
output = model.forward(x_data)
|
||||
loss = criterion(output, y_data)
|
||||
@ -141,7 +140,34 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr
|
||||
losses.append(loss.item())
|
||||
if print_show == 1:
|
||||
if (epoch + 1) % 100 == 0:
|
||||
print(epoch)
|
||||
print(epoch, loss.item())
|
||||
return model, losses
|
||||
|
||||
# 使用优化器批量训练模型
|
||||
@guan.statistics_decorator
|
||||
def batch_train_model(model, train_loader, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1):
|
||||
import torch
|
||||
if optimizer == 'Adam':
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||
elif optimizer == 'SGD':
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
|
||||
|
||||
if criterion == 'MSELoss':
|
||||
criterion = torch.nn.MSELoss()
|
||||
elif criterion == 'CrossEntropyLoss':
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
losses = []
|
||||
for epoch in range(num_epochs):
|
||||
for batch_x, batch_y in train_loader:
|
||||
output = model.forward(batch_x)
|
||||
loss = criterion(output, batch_y)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
losses.append(loss.item())
|
||||
if print_show == 1:
|
||||
if (epoch + 1) % 100 == 0:
|
||||
print(epoch, loss.item())
|
||||
return model, losses
|
||||
|
||||
# 保存模型参数到文件
|
||||
@ -168,4 +194,12 @@ def load_model_parameters(model, filename='./model_parameters.pth'):
|
||||
def load_model(filename='./model.pth'):
|
||||
import torch
|
||||
model = torch.load(filename)
|
||||
return model
|
||||
return model
|
||||
|
||||
# 加载训练数据,用于批量加载训练
|
||||
@guan.statistics_decorator
|
||||
def load_train_data(x_train, y_train, batch_size=32):
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
train_dataset = TensorDataset(x_train, y_train)
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
return train_loader
|
Loading…
x
Reference in New Issue
Block a user