0.1.74
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| [metadata] | [metadata] | ||||||
| # replace with your username: | # replace with your username: | ||||||
| name = guan | name = guan | ||||||
| version = 0.1.73 | version = 0.1.74 | ||||||
| author = guanjihuan | author = guanjihuan | ||||||
| author_email = guanjihuan@163.com | author_email = guanjihuan@163.com | ||||||
| description = An open source python package | description = An open source python package | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| Metadata-Version: 2.1 | Metadata-Version: 2.1 | ||||||
| Name: guan | Name: guan | ||||||
| Version: 0.1.73 | Version: 0.1.74 | ||||||
| Summary: An open source python package | Summary: An open source python package | ||||||
| Home-page: https://py.guanjihuan.com | Home-page: https://py.guanjihuan.com | ||||||
| Author: guanjihuan | Author: guanjihuan | ||||||
|   | |||||||
| @@ -131,7 +131,6 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr | |||||||
|     elif criterion == 'CrossEntropyLoss': |     elif criterion == 'CrossEntropyLoss': | ||||||
|         criterion = torch.nn.CrossEntropyLoss() |         criterion = torch.nn.CrossEntropyLoss() | ||||||
|     losses = [] |     losses = [] | ||||||
|  |  | ||||||
|     for epoch in range(num_epochs): |     for epoch in range(num_epochs): | ||||||
|         output = model.forward(x_data) |         output = model.forward(x_data) | ||||||
|         loss = criterion(output, y_data) |         loss = criterion(output, y_data) | ||||||
| @@ -141,7 +140,34 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr | |||||||
|         losses.append(loss.item()) |         losses.append(loss.item()) | ||||||
|         if print_show == 1: |         if print_show == 1: | ||||||
|             if (epoch + 1) % 100 == 0: |             if (epoch + 1) % 100 == 0: | ||||||
|                 print(epoch) |                 print(epoch, loss.item()) | ||||||
|  |     return model, losses | ||||||
|  |  | ||||||
|  | # 使用优化器批量训练模型 | ||||||
|  | @guan.statistics_decorator | ||||||
|  | def batch_train_model(model, train_loader, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1): | ||||||
|  |     import torch | ||||||
|  |     if optimizer == 'Adam': | ||||||
|  |         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) | ||||||
|  |     elif optimizer == 'SGD': | ||||||
|  |         optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) | ||||||
|  |      | ||||||
|  |     if criterion == 'MSELoss': | ||||||
|  |         criterion = torch.nn.MSELoss() | ||||||
|  |     elif criterion == 'CrossEntropyLoss': | ||||||
|  |         criterion = torch.nn.CrossEntropyLoss() | ||||||
|  |     losses = [] | ||||||
|  |     for epoch in range(num_epochs): | ||||||
|  |         for batch_x, batch_y in train_loader: | ||||||
|  |             output = model.forward(batch_x) | ||||||
|  |             loss = criterion(output, batch_y) | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |             losses.append(loss.item()) | ||||||
|  |             if print_show == 1: | ||||||
|  |                 if (epoch + 1) % 100 == 0: | ||||||
|  |                     print(epoch, loss.item()) | ||||||
|     return model, losses |     return model, losses | ||||||
|  |  | ||||||
| # 保存模型参数到文件 | # 保存模型参数到文件 | ||||||
| @@ -169,3 +195,11 @@ def load_model(filename='./model.pth'): | |||||||
|     import torch |     import torch | ||||||
|     model = torch.load(filename) |     model = torch.load(filename) | ||||||
|     return model |     return model | ||||||
|  |  | ||||||
|  | # 加载训练数据,用于批量加载训练 | ||||||
|  | @guan.statistics_decorator | ||||||
|  | def load_train_data(x_train, y_train, batch_size=32): | ||||||
|  |     from torch.utils.data import DataLoader, TensorDataset | ||||||
|  |     train_dataset = TensorDataset(x_train, y_train) | ||||||
|  |     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | ||||||
|  |     return train_loader | ||||||
		Reference in New Issue
	
	Block a user