This commit is contained in:
guanjihuan 2024-06-09 05:26:20 +08:00
parent c91c34f76b
commit 0363b5fcf6
3 changed files with 22 additions and 2 deletions

View File

@ -1,7 +1,7 @@
[metadata]
# replace with your username:
name = guan
version = 0.1.106
version = 0.1.107
author = guanjihuan
author_email = guanjihuan@163.com
description = An open source python package

View File

@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: guan
Version: 0.1.106
Version: 0.1.107
Summary: An open source python package
Home-page: https://py.guanjihuan.com
Author: guanjihuan

View File

@ -254,6 +254,14 @@ def save_model(model, filename='./model.pth'):
import torch
torch.save(model, filename)
# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问)
def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'):
import torch
checkpoint = {'model_state_dict': model.state_dict(),
'model_class': model_class,
'note': note,}
torch.save(checkpoint, filename)
# 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变)
def load_model_parameters(model, filename='./model_parameters.pth'):
import torch
@ -266,6 +274,18 @@ def load_model(filename='./model.pth'):
model = torch.load(filename)
return model
# 加载包含所有信息的模型(包含了模型的类,返回的是对象)
def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0):
import torch
checkpoint = torch.load(filename)
model_class = checkpoint['model_class']
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
if note_print==1:
note = checkpoint['note']
print(note)
return model
# 加载训练数据,用于批量加载训练
def load_train_data(x_train, y_train, batch_size=32):
from torch.utils.data import DataLoader, TensorDataset