0.1.59
This commit is contained in:
parent
1608ae7437
commit
5a568076c7
@ -1,7 +1,7 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
# replace with your username:
|
# replace with your username:
|
||||||
name = guan
|
name = guan
|
||||||
version = 0.1.58
|
version = 0.1.59
|
||||||
author = guanjihuan
|
author = guanjihuan
|
||||||
author_email = guanjihuan@163.com
|
author_email = guanjihuan@163.com
|
||||||
description = An open source python package
|
description = An open source python package
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
Metadata-Version: 2.1
|
Metadata-Version: 2.1
|
||||||
Name: guan
|
Name: guan
|
||||||
Version: 0.1.58
|
Version: 0.1.59
|
||||||
Summary: An open source python package
|
Summary: An open source python package
|
||||||
Home-page: https://py.guanjihuan.com
|
Home-page: https://py.guanjihuan.com
|
||||||
Author: guanjihuan
|
Author: guanjihuan
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
# Module: machine_learning
|
# Module: machine_learning
|
||||||
import guan
|
import guan
|
||||||
|
|
||||||
# 全连接神经网络模型(包含一个隐藏层)
|
# 全连接神经网络模型(包含一个隐藏层)(模型的类定义成全局的)
|
||||||
@guan.function_decorator
|
@guan.function_decorator
|
||||||
def fully_connected_neural_network_with_one_hidden_layer(input_size=1, hidden_size=10, output_size=1, activation='relu'):
|
def fully_connected_neural_network_with_one_hidden_layer(input_size=1, hidden_size=10, output_size=1, activation='relu'):
|
||||||
import torch
|
import torch
|
||||||
# 如果在函数中定义模型类,尽量定义成全局的,这样可以防止在保存完整模型到文件时,无法访问函数中的模型类。
|
|
||||||
global model_class_of_fully_connected_neural_network_with_one_hidden_layer
|
global model_class_of_fully_connected_neural_network_with_one_hidden_layer
|
||||||
class model_class_of_fully_connected_neural_network_with_one_hidden_layer(torch.nn.Module):
|
class model_class_of_fully_connected_neural_network_with_one_hidden_layer(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -28,7 +27,7 @@ def fully_connected_neural_network_with_one_hidden_layer(input_size=1, hidden_si
|
|||||||
model = model_class_of_fully_connected_neural_network_with_one_hidden_layer()
|
model = model_class_of_fully_connected_neural_network_with_one_hidden_layer()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# 全连接神经网络模型(包含两个隐藏层)
|
# 全连接神经网络模型(包含两个隐藏层)(模型的类定义成全局的)
|
||||||
@guan.function_decorator
|
@guan.function_decorator
|
||||||
def fully_connected_neural_network_with_two_hidden_layers(input_size=1, hidden_size_1=10, hidden_size_2=10, output_size=1, activation_1='relu', activation_2='relu'):
|
def fully_connected_neural_network_with_two_hidden_layers(input_size=1, hidden_size_1=10, hidden_size_2=10, output_size=1, activation_1='relu', activation_2='relu'):
|
||||||
import torch
|
import torch
|
||||||
@ -67,7 +66,7 @@ def fully_connected_neural_network_with_two_hidden_layers(input_size=1, hidden_s
|
|||||||
model = model_class_of_fully_connected_neural_network_with_two_hidden_layers()
|
model = model_class_of_fully_connected_neural_network_with_two_hidden_layers()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# 全连接神经网络模型(包含三个隐藏层)
|
# 全连接神经网络模型(包含三个隐藏层)(模型的类定义成全局的)
|
||||||
@guan.function_decorator
|
@guan.function_decorator
|
||||||
def fully_connected_neural_network_with_three_hidden_layers(input_size=1, hidden_size_1=10, hidden_size_2=10, hidden_size_3=10, output_size=1, activation_1='relu', activation_2='relu', activation_3='relu'):
|
def fully_connected_neural_network_with_three_hidden_layers(input_size=1, hidden_size_1=10, hidden_size_2=10, hidden_size_3=10, output_size=1, activation_1='relu', activation_2='relu', activation_3='relu'):
|
||||||
import torch
|
import torch
|
||||||
@ -145,28 +144,28 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr
|
|||||||
print(epoch)
|
print(epoch)
|
||||||
return model, losses
|
return model, losses
|
||||||
|
|
||||||
# 保存完整模型到文件
|
|
||||||
@guan.function_decorator
|
|
||||||
def save_model(model, filename='./model.pth'):
|
|
||||||
import torch
|
|
||||||
torch.save(model, filename)
|
|
||||||
|
|
||||||
# 保存模型参数到文件
|
# 保存模型参数到文件
|
||||||
@guan.function_decorator
|
@guan.function_decorator
|
||||||
def save_model_parameters(model, filename='./model_parameters.pth'):
|
def save_model_parameters(model, filename='./model_parameters.pth'):
|
||||||
import torch
|
import torch
|
||||||
torch.save(model.state_dict(), filename)
|
torch.save(model.state_dict(), filename)
|
||||||
|
|
||||||
# 加载完整模型
|
# 保存完整模型到文件(保存时需要模型的类可访问)
|
||||||
@guan.function_decorator
|
@guan.function_decorator
|
||||||
def load_model(filename='./model.pth'):
|
def save_model(model, filename='./model.pth'):
|
||||||
import torch
|
import torch
|
||||||
model = torch.load(filename)
|
torch.save(model, filename)
|
||||||
return model
|
|
||||||
|
|
||||||
# 加载模型参数(需要输入模型)
|
# 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变)
|
||||||
@guan.function_decorator
|
@guan.function_decorator
|
||||||
def load_model_parameters(model, filename='./model_parameters.pth'):
|
def load_model_parameters(model, filename='./model_parameters.pth'):
|
||||||
import torch
|
import torch
|
||||||
model.load_state_dict(torch.load(filename))
|
model.load_state_dict(torch.load(filename))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
# 加载完整模型(不需要输入模型,但加载时需要原定义的模型的类可访问)
|
||||||
|
@guan.function_decorator
|
||||||
|
def load_model(filename='./model.pth'):
|
||||||
|
import torch
|
||||||
|
model = torch.load(filename)
|
||||||
|
return model
|
Loading…
x
Reference in New Issue
Block a user