0.1.52
This commit is contained in:
parent
daf6372453
commit
0bfffa7ceb
@ -1,7 +1,7 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
# replace with your username:
|
# replace with your username:
|
||||||
name = guan
|
name = guan
|
||||||
version = 0.1.51
|
version = 0.1.52
|
||||||
author = guanjihuan
|
author = guanjihuan
|
||||||
author_email = guanjihuan@163.com
|
author_email = guanjihuan@163.com
|
||||||
description = An open source python package
|
description = An open source python package
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
Metadata-Version: 2.1
|
Metadata-Version: 2.1
|
||||||
Name: guan
|
Name: guan
|
||||||
Version: 0.1.51
|
Version: 0.1.52
|
||||||
Summary: An open source python package
|
Summary: An open source python package
|
||||||
Home-page: https://py.guanjihuan.com
|
Home-page: https://py.guanjihuan.com
|
||||||
Author: guanjihuan
|
Author: guanjihuan
|
||||||
|
@ -113,3 +113,56 @@ def fully_connected_neural_network_with_three_hidden_layers(input_size=1, hidden
|
|||||||
return output
|
return output
|
||||||
model = model_class()
|
model = model_class()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
# 使用优化器训练模型
|
||||||
|
@guan.function_decorator
|
||||||
|
def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1):
|
||||||
|
import torch
|
||||||
|
if optimizer == 'Adam':
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||||
|
elif optimizer == 'SGD':
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
|
||||||
|
|
||||||
|
if criterion == 'MSELoss':
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
elif criterion == 'CrossEntropyLoss':
|
||||||
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
losses = []
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
output = model.forward(x_data)
|
||||||
|
loss = criterion(output, y_data)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
losses.append(loss.item())
|
||||||
|
if print_show == 1:
|
||||||
|
if (epoch + 1) % 100 == 0:
|
||||||
|
print(epoch)
|
||||||
|
return model, losses
|
||||||
|
|
||||||
|
# 保存完整模型到文件
|
||||||
|
@guan.function_decorator
|
||||||
|
def save_model(model, filename='./model.pth'):
|
||||||
|
import torch
|
||||||
|
torch.save(model, filename)
|
||||||
|
|
||||||
|
# 保存模型参数到文件
|
||||||
|
@guan.function_decorator
|
||||||
|
def save_model_parameters(model, filename='./model_parameters.pth'):
|
||||||
|
import torch
|
||||||
|
torch.save(model.state_dict(), filename)
|
||||||
|
|
||||||
|
# 加载完整模型
|
||||||
|
@guan.function_decorator
|
||||||
|
def load_model(filename='./model.pth'):
|
||||||
|
import torch
|
||||||
|
model = torch.load(filename)
|
||||||
|
return model
|
||||||
|
|
||||||
|
# 加载模型参数(需要输入模型)
|
||||||
|
@guan.function_decorator
|
||||||
|
def load_model_parameters(model, filename='./model_parameters.pth'):
|
||||||
|
import torch
|
||||||
|
model.load_state_dict(torch.load(filename))
|
||||||
|
return model
|
Loading…
x
Reference in New Issue
Block a user