0.1.108
This commit is contained in:
parent
0363b5fcf6
commit
6f78f12717
@ -1,7 +1,7 @@
|
||||
[metadata]
|
||||
# replace with your username:
|
||||
name = guan
|
||||
version = 0.1.107
|
||||
version = 0.1.108
|
||||
author = guanjihuan
|
||||
author_email = guanjihuan@163.com
|
||||
description = An open source python package
|
||||
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: guan
|
||||
Version: 0.1.107
|
||||
Version: 0.1.108
|
||||
Summary: An open source python package
|
||||
Home-page: https://py.guanjihuan.com
|
||||
Author: guanjihuan
|
||||
|
@ -254,11 +254,18 @@ def save_model(model, filename='./model.pth'):
|
||||
import torch
|
||||
torch.save(model, filename)
|
||||
|
||||
# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问)
|
||||
def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'):
|
||||
# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问,此外还要输入模型的实例化函数)
|
||||
def save_model_with_all_information(model, model_class, model_instantiation, note='', filename='./model_with_all_information.pth'):
|
||||
import torch
|
||||
import guan
|
||||
model_class_source = guan.get_source(name=model_class)
|
||||
model_class_source = 'import torch\n'+model_class_source
|
||||
model_instantiation_source = guan.get_source(name=model_instantiation)
|
||||
checkpoint = {'model_state_dict': model.state_dict(),
|
||||
'model_class': model_class,
|
||||
'model_class_name': model_class.__name__,
|
||||
'model_class_source': model_class_source,
|
||||
'model_instantiation_name':model_instantiation.__name__,
|
||||
'model_instantiation_source': model_instantiation_source,
|
||||
'note': note,}
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
@ -274,12 +281,19 @@ def load_model(filename='./model.pth'):
|
||||
model = torch.load(filename)
|
||||
return model
|
||||
|
||||
# 加载包含所有信息的模型(包含了模型的类,返回的是对象)
|
||||
# 加载包含所有信息的模型(包含了模型的类和实例化函数等,返回的是模型对象)
|
||||
def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0):
|
||||
import torch
|
||||
checkpoint = torch.load(filename)
|
||||
model_class = checkpoint['model_class']
|
||||
model = model_class()
|
||||
model_class_source = checkpoint['model_class_source']
|
||||
exec(model_class_source, globals())
|
||||
# model_class_name = checkpoint['model_class_name']
|
||||
# model_class = globals()[model_class_name]
|
||||
model_instantiation_source = checkpoint['model_instantiation_source']
|
||||
exec(model_instantiation_source, globals())
|
||||
model_instantiation_name = checkpoint['model_instantiation_name']
|
||||
model_instantiation = globals()[model_instantiation_name]
|
||||
model = model_instantiation()
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if note_print==1:
|
||||
note = checkpoint['note']
|
||||
|
@ -40,6 +40,12 @@ def chat(prompt='你好', model=1, stream=0, top_p=0.8, temperature=0.85):
|
||||
print('\n--- End Stream Message ---\n')
|
||||
return response
|
||||
|
||||
# 获取函数或类的源码(返回字符串)
|
||||
def get_source(name):
|
||||
import inspect
|
||||
source = inspect.getsource(name)
|
||||
return source
|
||||
|
||||
# 获取当前日期字符串
|
||||
def get_date(bar=True):
|
||||
import datetime
|
||||
@ -464,12 +470,6 @@ def get_PID(name):
|
||||
id_running = ps_ef[1]
|
||||
return id_running
|
||||
|
||||
# 获取函数的源码
|
||||
def get_function_source(function_name):
|
||||
import inspect
|
||||
function_source = inspect.getsource(function_name)
|
||||
return function_source
|
||||
|
||||
# 查找文件名相同的文件
|
||||
def find_repeated_file_with_same_filename(directory='./', ignored_directory_with_words=[], ignored_file_with_words=[], num=1000):
|
||||
import os
|
||||
|
Loading…
x
Reference in New Issue
Block a user