0.1.74
This commit is contained in:
		| @@ -1,7 +1,7 @@ | ||||
| [metadata] | ||||
| # replace with your username: | ||||
| name = guan | ||||
| version = 0.1.73 | ||||
| version = 0.1.74 | ||||
| author = guanjihuan | ||||
| author_email = guanjihuan@163.com | ||||
| description = An open source python package | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| Metadata-Version: 2.1 | ||||
| Name: guan | ||||
| Version: 0.1.73 | ||||
| Version: 0.1.74 | ||||
| Summary: An open source python package | ||||
| Home-page: https://py.guanjihuan.com | ||||
| Author: guanjihuan | ||||
|   | ||||
| @@ -131,7 +131,6 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr | ||||
|     elif criterion == 'CrossEntropyLoss': | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|     losses = [] | ||||
|  | ||||
|     for epoch in range(num_epochs): | ||||
|         output = model.forward(x_data) | ||||
|         loss = criterion(output, y_data) | ||||
| @@ -141,7 +140,34 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr | ||||
|         losses.append(loss.item()) | ||||
|         if print_show == 1: | ||||
|             if (epoch + 1) % 100 == 0: | ||||
|                 print(epoch) | ||||
|                 print(epoch, loss.item()) | ||||
|     return model, losses | ||||
|  | ||||
| # 使用优化器批量训练模型 | ||||
| @guan.statistics_decorator | ||||
| def batch_train_model(model, train_loader, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1): | ||||
|     import torch | ||||
|     if optimizer == 'Adam': | ||||
|         optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) | ||||
|     elif optimizer == 'SGD': | ||||
|         optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) | ||||
|      | ||||
|     if criterion == 'MSELoss': | ||||
|         criterion = torch.nn.MSELoss() | ||||
|     elif criterion == 'CrossEntropyLoss': | ||||
|         criterion = torch.nn.CrossEntropyLoss() | ||||
|     losses = [] | ||||
|     for epoch in range(num_epochs): | ||||
|         for batch_x, batch_y in train_loader: | ||||
|             output = model.forward(batch_x) | ||||
|             loss = criterion(output, batch_y) | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             losses.append(loss.item()) | ||||
|             if print_show == 1: | ||||
|                 if (epoch + 1) % 100 == 0: | ||||
|                     print(epoch, loss.item()) | ||||
|     return model, losses | ||||
|  | ||||
| # 保存模型参数到文件 | ||||
| @@ -168,4 +194,12 @@ def load_model_parameters(model, filename='./model_parameters.pth'): | ||||
| def load_model(filename='./model.pth'): | ||||
|     import torch | ||||
|     model = torch.load(filename) | ||||
|     return model | ||||
|     return model | ||||
|  | ||||
| # 加载训练数据,用于批量加载训练 | ||||
| @guan.statistics_decorator | ||||
| def load_train_data(x_train, y_train, batch_size=32): | ||||
|     from torch.utils.data import DataLoader, TensorDataset | ||||
|     train_dataset = TensorDataset(x_train, y_train) | ||||
|     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | ||||
|     return train_loader | ||||
		Reference in New Issue
	
	Block a user