0.1.108
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| [metadata] | [metadata] | ||||||
| # replace with your username: | # replace with your username: | ||||||
| name = guan | name = guan | ||||||
| version = 0.1.107 | version = 0.1.108 | ||||||
| author = guanjihuan | author = guanjihuan | ||||||
| author_email = guanjihuan@163.com | author_email = guanjihuan@163.com | ||||||
| description = An open source python package | description = An open source python package | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| Metadata-Version: 2.1 | Metadata-Version: 2.1 | ||||||
| Name: guan | Name: guan | ||||||
| Version: 0.1.107 | Version: 0.1.108 | ||||||
| Summary: An open source python package | Summary: An open source python package | ||||||
| Home-page: https://py.guanjihuan.com | Home-page: https://py.guanjihuan.com | ||||||
| Author: guanjihuan | Author: guanjihuan | ||||||
|   | |||||||
| @@ -254,11 +254,18 @@ def save_model(model, filename='./model.pth'): | |||||||
|     import torch |     import torch | ||||||
|     torch.save(model, filename) |     torch.save(model, filename) | ||||||
|  |  | ||||||
| # 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问) | # 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问,此外还要输入模型的实例化函数) | ||||||
| def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'): | def save_model_with_all_information(model, model_class, model_instantiation, note='', filename='./model_with_all_information.pth'): | ||||||
|     import torch |     import torch | ||||||
|  |     import guan | ||||||
|  |     model_class_source = guan.get_source(name=model_class) | ||||||
|  |     model_class_source = 'import torch\n'+model_class_source | ||||||
|  |     model_instantiation_source = guan.get_source(name=model_instantiation) | ||||||
|     checkpoint = {'model_state_dict': model.state_dict(), |     checkpoint = {'model_state_dict': model.state_dict(), | ||||||
|                   'model_class': model_class, |                   'model_class_name': model_class.__name__, | ||||||
|  |                   'model_class_source': model_class_source, | ||||||
|  |                   'model_instantiation_name':model_instantiation.__name__, | ||||||
|  |                   'model_instantiation_source': model_instantiation_source, | ||||||
|                   'note': note,} |                   'note': note,} | ||||||
|     torch.save(checkpoint, filename) |     torch.save(checkpoint, filename) | ||||||
|  |  | ||||||
| @@ -274,12 +281,19 @@ def load_model(filename='./model.pth'): | |||||||
|     model = torch.load(filename) |     model = torch.load(filename) | ||||||
|     return model |     return model | ||||||
|  |  | ||||||
| # 加载包含所有信息的模型(包含了模型的类,返回的是对象) | # 加载包含所有信息的模型(包含了模型的类和实例化函数等,返回的是模型对象) | ||||||
| def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0): | def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0): | ||||||
|     import torch |     import torch | ||||||
|     checkpoint = torch.load(filename) |     checkpoint = torch.load(filename) | ||||||
|     model_class = checkpoint['model_class'] |     model_class_source = checkpoint['model_class_source'] | ||||||
|     model = model_class() |     exec(model_class_source, globals()) | ||||||
|  |     # model_class_name = checkpoint['model_class_name'] | ||||||
|  |     # model_class = globals()[model_class_name] | ||||||
|  |     model_instantiation_source = checkpoint['model_instantiation_source'] | ||||||
|  |     exec(model_instantiation_source, globals()) | ||||||
|  |     model_instantiation_name = checkpoint['model_instantiation_name'] | ||||||
|  |     model_instantiation = globals()[model_instantiation_name] | ||||||
|  |     model = model_instantiation() | ||||||
|     model.load_state_dict(checkpoint['model_state_dict']) |     model.load_state_dict(checkpoint['model_state_dict']) | ||||||
|     if note_print==1: |     if note_print==1: | ||||||
|         note = checkpoint['note'] |         note = checkpoint['note'] | ||||||
|   | |||||||
| @@ -40,6 +40,12 @@ def chat(prompt='你好', model=1, stream=0, top_p=0.8, temperature=0.85): | |||||||
|             print('\n--- End Stream Message ---\n') |             print('\n--- End Stream Message ---\n') | ||||||
|     return response |     return response | ||||||
|  |  | ||||||
|  | # 获取函数或类的源码(返回字符串) | ||||||
|  | def get_source(name): | ||||||
|  |     import inspect | ||||||
|  |     source = inspect.getsource(name) | ||||||
|  |     return source | ||||||
|  |  | ||||||
| # 获取当前日期字符串 | # 获取当前日期字符串 | ||||||
| def get_date(bar=True): | def get_date(bar=True): | ||||||
|     import datetime |     import datetime | ||||||
| @@ -464,12 +470,6 @@ def get_PID(name): | |||||||
|     id_running = ps_ef[1] |     id_running = ps_ef[1] | ||||||
|     return id_running |     return id_running | ||||||
|  |  | ||||||
| # 获取函数的源码 |  | ||||||
| def get_function_source(function_name): |  | ||||||
|     import inspect |  | ||||||
|     function_source = inspect.getsource(function_name) |  | ||||||
|     return function_source |  | ||||||
|  |  | ||||||
| # 查找文件名相同的文件 | # 查找文件名相同的文件 | ||||||
| def find_repeated_file_with_same_filename(directory='./', ignored_directory_with_words=[], ignored_file_with_words=[], num=1000): | def find_repeated_file_with_same_filename(directory='./', ignored_directory_with_words=[], ignored_file_with_words=[], num=1000): | ||||||
|     import os |     import os | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user