0.1.52
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| [metadata] | [metadata] | ||||||
| # replace with your username: | # replace with your username: | ||||||
| name = guan | name = guan | ||||||
| version = 0.1.51 | version = 0.1.52 | ||||||
| author = guanjihuan | author = guanjihuan | ||||||
| author_email = guanjihuan@163.com | author_email = guanjihuan@163.com | ||||||
| description = An open source python package | description = An open source python package | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| Metadata-Version: 2.1 | Metadata-Version: 2.1 | ||||||
| Name: guan | Name: guan | ||||||
| Version: 0.1.51 | Version: 0.1.52 | ||||||
| Summary: An open source python package | Summary: An open source python package | ||||||
| Home-page: https://py.guanjihuan.com | Home-page: https://py.guanjihuan.com | ||||||
| Author: guanjihuan | Author: guanjihuan | ||||||
|   | |||||||
| @@ -113,3 +113,56 @@ def fully_connected_neural_network_with_three_hidden_layers(input_size=1, hidden | |||||||
|             return output |             return output | ||||||
|     model = model_class() |     model = model_class() | ||||||
|     return model |     return model | ||||||
|  |  | ||||||
|  | # 使用优化器训练模型 | ||||||
|  | @guan.function_decorator | ||||||
|  | def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1): | ||||||
|  |     import torch | ||||||
|  |     if optimizer == 'Adam': | ||||||
|  |         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) | ||||||
|  |     elif optimizer == 'SGD': | ||||||
|  |         optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) | ||||||
|  |      | ||||||
|  |     if criterion == 'MSELoss': | ||||||
|  |         criterion = torch.nn.MSELoss() | ||||||
|  |     elif criterion == 'CrossEntropyLoss': | ||||||
|  |         criterion = torch.nn.CrossEntropyLoss() | ||||||
|  |     losses = [] | ||||||
|  |  | ||||||
|  |     for epoch in range(num_epochs): | ||||||
|  |         output = model.forward(x_data) | ||||||
|  |         loss = criterion(output, y_data) | ||||||
|  |         optimizer.zero_grad() | ||||||
|  |         loss.backward() | ||||||
|  |         optimizer.step() | ||||||
|  |         losses.append(loss.item()) | ||||||
|  |         if print_show == 1: | ||||||
|  |             if (epoch + 1) % 100 == 0: | ||||||
|  |                 print(epoch) | ||||||
|  |     return model, losses | ||||||
|  |  | ||||||
|  | # 保存完整模型到文件 | ||||||
|  | @guan.function_decorator | ||||||
|  | def save_model(model, filename='./model.pth'): | ||||||
|  |     import torch | ||||||
|  |     torch.save(model, filename) | ||||||
|  |  | ||||||
|  | # 保存模型参数到文件 | ||||||
|  | @guan.function_decorator | ||||||
|  | def save_model_parameters(model, filename='./model_parameters.pth'): | ||||||
|  |     import torch | ||||||
|  |     torch.save(model.state_dict(), filename) | ||||||
|  |  | ||||||
|  | # 加载完整模型 | ||||||
|  | @guan.function_decorator | ||||||
|  | def load_model(filename='./model.pth'): | ||||||
|  |     import torch | ||||||
|  |     model = torch.load(filename) | ||||||
|  |     return model | ||||||
|  |  | ||||||
|  | # 加载模型参数(需要输入模型) | ||||||
|  | @guan.function_decorator | ||||||
|  | def load_model_parameters(model, filename='./model_parameters.pth'): | ||||||
|  |     import torch | ||||||
|  |     model.load_state_dict(torch.load(filename)) | ||||||
|  |     return model | ||||||
		Reference in New Issue
	
	Block a user