0.1.52
This commit is contained in:
		| @@ -1,6 +1,6 @@ | ||||
| Metadata-Version: 2.1 | ||||
| Name: guan | ||||
| Version: 0.1.51 | ||||
| Version: 0.1.52 | ||||
| Summary: An open source python package | ||||
| Home-page: https://py.guanjihuan.com | ||||
| Author: guanjihuan | ||||
|   | ||||
| @@ -112,4 +112,57 @@ def fully_connected_neural_network_with_three_hidden_layers(input_size=1, hidden | ||||
|             output = self.output_layer(hidden_output_3) | ||||
|             return output | ||||
|     model = model_class() | ||||
|     return model | ||||
|  | ||||
| # 使用优化器训练模型 | ||||
| @guan.function_decorator | ||||
| def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1): | ||||
|     import torch | ||||
|     if optimizer == 'Adam': | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) | ||||
|     elif optimizer == 'SGD': | ||||
|         optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) | ||||
|      | ||||
|     if criterion == 'MSELoss': | ||||
|         criterion = torch.nn.MSELoss() | ||||
|     elif criterion == 'CrossEntropyLoss': | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|     losses = [] | ||||
|  | ||||
|     for epoch in range(num_epochs): | ||||
|         output = model.forward(x_data) | ||||
|         loss = criterion(output, y_data) | ||||
|         optimizer.zero_grad() | ||||
|         loss.backward() | ||||
|         optimizer.step() | ||||
|         losses.append(loss.item()) | ||||
|         if print_show == 1: | ||||
|             if (epoch + 1) % 100 == 0: | ||||
|                 print(epoch) | ||||
|     return model, losses | ||||
|  | ||||
| # 保存完整模型到文件 | ||||
| @guan.function_decorator | ||||
| def save_model(model, filename='./model.pth'): | ||||
|     import torch | ||||
|     torch.save(model, filename) | ||||
|  | ||||
| # 保存模型参数到文件 | ||||
| @guan.function_decorator | ||||
| def save_model_parameters(model, filename='./model_parameters.pth'): | ||||
|     import torch | ||||
|     torch.save(model.state_dict(), filename) | ||||
|  | ||||
| # 加载完整模型 | ||||
| @guan.function_decorator | ||||
| def load_model(filename='./model.pth'): | ||||
|     import torch | ||||
|     model = torch.load(filename) | ||||
|     return model | ||||
|  | ||||
| # 加载模型参数(需要输入模型) | ||||
| @guan.function_decorator | ||||
| def load_model_parameters(model, filename='./model_parameters.pth'): | ||||
|     import torch | ||||
|     model.load_state_dict(torch.load(filename)) | ||||
|     return model | ||||
		Reference in New Issue
	
	Block a user