diff --git a/PyPI/setup.cfg b/PyPI/setup.cfg index cfeb048..58ab677 100644 --- a/PyPI/setup.cfg +++ b/PyPI/setup.cfg @@ -1,7 +1,7 @@ [metadata] # replace with your username: name = guan -version = 0.1.73 +version = 0.1.74 author = guanjihuan author_email = guanjihuan@163.com description = An open source python package diff --git a/PyPI/src/guan.egg-info/PKG-INFO b/PyPI/src/guan.egg-info/PKG-INFO index fba844d..3c202b2 100644 --- a/PyPI/src/guan.egg-info/PKG-INFO +++ b/PyPI/src/guan.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: guan -Version: 0.1.73 +Version: 0.1.74 Summary: An open source python package Home-page: https://py.guanjihuan.com Author: guanjihuan diff --git a/PyPI/src/guan/machine_learning.py b/PyPI/src/guan/machine_learning.py index 2c8b2f7..722d5fa 100644 --- a/PyPI/src/guan/machine_learning.py +++ b/PyPI/src/guan/machine_learning.py @@ -131,7 +131,6 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr elif criterion == 'CrossEntropyLoss': criterion = torch.nn.CrossEntropyLoss() losses = [] - for epoch in range(num_epochs): output = model.forward(x_data) loss = criterion(output, y_data) @@ -141,7 +140,34 @@ def train_model(model, x_data, y_data, optimizer='Adam', learning_rate=0.001, cr losses.append(loss.item()) if print_show == 1: if (epoch + 1) % 100 == 0: - print(epoch) + print(epoch, loss.item()) + return model, losses + +# 使用优化器批量训练模型 +@guan.statistics_decorator +def batch_train_model(model, train_loader, optimizer='Adam', learning_rate=0.001, criterion='MSELoss', num_epochs=1000, print_show=1): + import torch + if optimizer == 'Adam': + optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) + elif optimizer == 'SGD': + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + + if criterion == 'MSELoss': + criterion = torch.nn.MSELoss() + elif criterion == 'CrossEntropyLoss': + criterion = torch.nn.CrossEntropyLoss() + losses = [] + for epoch in range(num_epochs): + for batch_x, batch_y in train_loader: + output = model.forward(batch_x) + loss = criterion(output, batch_y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses.append(loss.item()) + if print_show == 1: + if (epoch + 1) % 100 == 0: + print(epoch, loss.item()) return model, losses # 保存模型参数到文件 @@ -168,4 +194,12 @@ def load_model_parameters(model, filename='./model_parameters.pth'): def load_model(filename='./model.pth'): import torch model = torch.load(filename) - return model \ No newline at end of file + return model + +# 加载训练数据,用于批量加载训练 +@guan.statistics_decorator +def load_train_data(x_train, y_train, batch_size=32): + from torch.utils.data import DataLoader, TensorDataset + train_dataset = TensorDataset(x_train, y_train) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + return train_loader \ No newline at end of file