From 6f78f1271723470662ecc518411326bdb393dc9b Mon Sep 17 00:00:00 2001 From: guanjihuan Date: Sun, 9 Jun 2024 07:56:07 +0800 Subject: [PATCH] 0.1.108 --- PyPI/setup.cfg | 2 +- PyPI/src/guan.egg-info/PKG-INFO | 2 +- PyPI/src/guan/machine_learning.py | 26 ++++++++++++++++++++------ PyPI/src/guan/others.py | 12 ++++++------ 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/PyPI/setup.cfg b/PyPI/setup.cfg index b3c0a85..556ff54 100644 --- a/PyPI/setup.cfg +++ b/PyPI/setup.cfg @@ -1,7 +1,7 @@ [metadata] # replace with your username: name = guan -version = 0.1.107 +version = 0.1.108 author = guanjihuan author_email = guanjihuan@163.com description = An open source python package diff --git a/PyPI/src/guan.egg-info/PKG-INFO b/PyPI/src/guan.egg-info/PKG-INFO index 39f5c1f..6f8b933 100644 --- a/PyPI/src/guan.egg-info/PKG-INFO +++ b/PyPI/src/guan.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: guan -Version: 0.1.107 +Version: 0.1.108 Summary: An open source python package Home-page: https://py.guanjihuan.com Author: guanjihuan diff --git a/PyPI/src/guan/machine_learning.py b/PyPI/src/guan/machine_learning.py index f47fd53..455f084 100644 --- a/PyPI/src/guan/machine_learning.py +++ b/PyPI/src/guan/machine_learning.py @@ -254,11 +254,18 @@ def save_model(model, filename='./model.pth'): import torch torch.save(model, filename) -# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问) -def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'): +# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问,此外还要输入模型的实例化函数) +def save_model_with_all_information(model, model_class, model_instantiation, note='', filename='./model_with_all_information.pth'): import torch + import guan + model_class_source = guan.get_source(name=model_class) + model_class_source = 'import torch\n'+model_class_source + model_instantiation_source = guan.get_source(name=model_instantiation) checkpoint = {'model_state_dict': model.state_dict(), - 'model_class': model_class, + 'model_class_name': model_class.__name__, + 'model_class_source': model_class_source, + 'model_instantiation_name':model_instantiation.__name__, + 'model_instantiation_source': model_instantiation_source, 'note': note,} torch.save(checkpoint, filename) @@ -274,12 +281,19 @@ def load_model(filename='./model.pth'): model = torch.load(filename) return model -# 加载包含所有信息的模型(包含了模型的类,返回的是对象) +# 加载包含所有信息的模型(包含了模型的类和实例化函数等,返回的是模型对象) def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0): import torch checkpoint = torch.load(filename) - model_class = checkpoint['model_class'] - model = model_class() + model_class_source = checkpoint['model_class_source'] + exec(model_class_source, globals()) + # model_class_name = checkpoint['model_class_name'] + # model_class = globals()[model_class_name] + model_instantiation_source = checkpoint['model_instantiation_source'] + exec(model_instantiation_source, globals()) + model_instantiation_name = checkpoint['model_instantiation_name'] + model_instantiation = globals()[model_instantiation_name] + model = model_instantiation() model.load_state_dict(checkpoint['model_state_dict']) if note_print==1: note = checkpoint['note'] diff --git a/PyPI/src/guan/others.py b/PyPI/src/guan/others.py index c1d1c25..3577399 100644 --- a/PyPI/src/guan/others.py +++ b/PyPI/src/guan/others.py @@ -40,6 +40,12 @@ def chat(prompt='你好', model=1, stream=0, top_p=0.8, temperature=0.85): print('\n--- End Stream Message ---\n') return response +# 获取函数或类的源码(返回字符串) +def get_source(name): + import inspect + source = inspect.getsource(name) + return source + # 获取当前日期字符串 def get_date(bar=True): import datetime @@ -464,12 +470,6 @@ def get_PID(name): id_running = ps_ef[1] return id_running -# 获取函数的源码 -def get_function_source(function_name): - import inspect - function_source = inspect.getsource(function_name) - return function_source - # 查找文件名相同的文件 def find_repeated_file_with_same_filename(directory='./', ignored_directory_with_words=[], ignored_file_with_words=[], num=1000): import os