This commit is contained in:
guanjihuan 2023-11-28 20:28:43 +08:00
parent daf6372453
commit 0bfffa7ceb
3 changed files with 55 additions and 2 deletions

View File

@ -1,7 +1,7 @@
[metadata]
# replace with your username:
name = guan
version = 0.1.51
version = 0.1.52
author = guanjihuan
author_email = guanjihuan@163.com
description = An open source python package

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: guan
Version: 0.1.51
Version: 0.1.52
Summary: An open source python package
Home-page: https://py.guanjihuan.com
Author: guanjihuan

View File

@ -112,4 +112,57 @@ def fully_connected_neural_network_with_three_hidden_layers(input_size=1, hidden
output = self.output_layer(hidden_output_3)
return output
model = model_class()
return model
# 使用优化器训练模型
@guan.function_decorator
def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1):
import torch
if optimizer == 'Adam':
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
elif optimizer == 'SGD':
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
if criterion == 'MSELoss':
criterion = torch.nn.MSELoss()
elif criterion == 'CrossEntropyLoss':
criterion = torch.nn.CrossEntropyLoss()
losses = []
for epoch in range(num_epochs):
output = model.forward(x_data)
loss = criterion(output, y_data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
if print_show == 1:
if (epoch + 1) % 100 == 0:
print(epoch)
return model, losses
# 保存完整模型到文件
@guan.function_decorator
def save_model(model, filename='./model.pth'):
import torch
torch.save(model, filename)
# 保存模型参数到文件
@guan.function_decorator
def save_model_parameters(model, filename='./model_parameters.pth'):
import torch
torch.save(model.state_dict(), filename)
# 加载完整模型
@guan.function_decorator
def load_model(filename='./model.pth'):
import torch
model = torch.load(filename)
return model
# 加载模型参数(需要输入模型)
@guan.function_decorator
def load_model_parameters(model, filename='./model_parameters.pth'):
import torch
model.load_state_dict(torch.load(filename))
return model