From 0363b5fcf6655df5af4932df238023b649dc4337 Mon Sep 17 00:00:00 2001 From: guanjihuan Date: Sun, 9 Jun 2024 05:26:20 +0800 Subject: [PATCH] 0.1.107 --- PyPI/setup.cfg | 2 +- PyPI/src/guan.egg-info/PKG-INFO | 2 +- PyPI/src/guan/machine_learning.py | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/PyPI/setup.cfg b/PyPI/setup.cfg index 59f7c13..b3c0a85 100644 --- a/PyPI/setup.cfg +++ b/PyPI/setup.cfg @@ -1,7 +1,7 @@ [metadata] # replace with your username: name = guan -version = 0.1.106 +version = 0.1.107 author = guanjihuan author_email = guanjihuan@163.com description = An open source python package diff --git a/PyPI/src/guan.egg-info/PKG-INFO b/PyPI/src/guan.egg-info/PKG-INFO index 79b2fed..39f5c1f 100644 --- a/PyPI/src/guan.egg-info/PKG-INFO +++ b/PyPI/src/guan.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: guan -Version: 0.1.106 +Version: 0.1.107 Summary: An open source python package Home-page: https://py.guanjihuan.com Author: guanjihuan diff --git a/PyPI/src/guan/machine_learning.py b/PyPI/src/guan/machine_learning.py index d0c60f1..f47fd53 100644 --- a/PyPI/src/guan/machine_learning.py +++ b/PyPI/src/guan/machine_learning.py @@ -254,6 +254,14 @@ def save_model(model, filename='./model.pth'): import torch torch.save(model, filename) +# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问) +def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'): + import torch + checkpoint = {'model_state_dict': model.state_dict(), + 'model_class': model_class, + 'note': note,} + torch.save(checkpoint, filename) + # 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变) def load_model_parameters(model, filename='./model_parameters.pth'): import torch @@ -266,6 +274,18 @@ def load_model(filename='./model.pth'): model = torch.load(filename) return model +# 加载包含所有信息的模型(包含了模型的类,返回的是对象) +def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0): + import torch + checkpoint = torch.load(filename) + model_class = checkpoint['model_class'] + model = model_class() + model.load_state_dict(checkpoint['model_state_dict']) + if note_print==1: + note = checkpoint['note'] + print(note) + return model + # 加载训练数据,用于批量加载训练 def load_train_data(x_train, y_train, batch_size=32): from torch.utils.data import DataLoader, TensorDataset