diff --git a/PyPI/setup.cfg b/PyPI/setup.cfg index 11b822d..ca0b124 100644 --- a/PyPI/setup.cfg +++ b/PyPI/setup.cfg @@ -1,7 +1,7 @@ [metadata] # replace with your username: name = guan -version = 0.1.51 +version = 0.1.52 author = guanjihuan author_email = guanjihuan@163.com description = An open source python package diff --git a/PyPI/src/guan.egg-info/PKG-INFO b/PyPI/src/guan.egg-info/PKG-INFO index 10bfe32..fb4c059 100644 --- a/PyPI/src/guan.egg-info/PKG-INFO +++ b/PyPI/src/guan.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: guan -Version: 0.1.51 +Version: 0.1.52 Summary: An open source python package Home-page: https://py.guanjihuan.com Author: guanjihuan diff --git a/PyPI/src/guan/machine_learning.py b/PyPI/src/guan/machine_learning.py index d2ec692..3091100 100644 --- a/PyPI/src/guan/machine_learning.py +++ b/PyPI/src/guan/machine_learning.py @@ -112,4 +112,57 @@ def fully_connected_neural_network_with_three_hidden_layers(input_size=1, hidden output = self.output_layer(hidden_output_3) return output model = model_class() + return model + +# 使用优化器训练模型 +@guan.function_decorator +def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1): + import torch + if optimizer == 'Adam': + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + elif optimizer == 'SGD': + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + if criterion == 'MSELoss': + criterion = torch.nn.MSELoss() + elif criterion == 'CrossEntropyLoss': + criterion = torch.nn.CrossEntropyLoss() + losses = [] + + for epoch in range(num_epochs): + output = model.forward(x_data) + loss = criterion(output, y_data) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses.append(loss.item()) + if print_show == 1: + if (epoch + 1) % 100 == 0: + print(epoch) + return model, losses + +# 保存完整模型到文件 +@guan.function_decorator +def save_model(model, filename='./model.pth'): + import torch + torch.save(model, filename) + +# 保存模型参数到文件 +@guan.function_decorator +def save_model_parameters(model, filename='./model_parameters.pth'): + import torch + torch.save(model.state_dict(), filename) + +# 加载完整模型 +@guan.function_decorator +def load_model(filename='./model.pth'): + import torch + model = torch.load(filename) + return model + +# 加载模型参数(需要输入模型) +@guan.function_decorator +def load_model_parameters(model, filename='./model_parameters.pth'): + import torch + model.load_state_dict(torch.load(filename)) return model \ No newline at end of file