0.1.107
This commit is contained in:
parent
c91c34f76b
commit
0363b5fcf6
@ -1,7 +1,7 @@
|
||||
[metadata]
|
||||
# replace with your username:
|
||||
name = guan
|
||||
version = 0.1.106
|
||||
version = 0.1.107
|
||||
author = guanjihuan
|
||||
author_email = guanjihuan@163.com
|
||||
description = An open source python package
|
||||
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: guan
|
||||
Version: 0.1.106
|
||||
Version: 0.1.107
|
||||
Summary: An open source python package
|
||||
Home-page: https://py.guanjihuan.com
|
||||
Author: guanjihuan
|
||||
|
@ -254,6 +254,14 @@ def save_model(model, filename='./model.pth'):
|
||||
import torch
|
||||
torch.save(model, filename)
|
||||
|
||||
# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问)
|
||||
def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'):
|
||||
import torch
|
||||
checkpoint = {'model_state_dict': model.state_dict(),
|
||||
'model_class': model_class,
|
||||
'note': note,}
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
# 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变)
|
||||
def load_model_parameters(model, filename='./model_parameters.pth'):
|
||||
import torch
|
||||
@ -266,6 +274,18 @@ def load_model(filename='./model.pth'):
|
||||
model = torch.load(filename)
|
||||
return model
|
||||
|
||||
# 加载包含所有信息的模型(包含了模型的类,返回的是对象)
|
||||
def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0):
|
||||
import torch
|
||||
checkpoint = torch.load(filename)
|
||||
model_class = checkpoint['model_class']
|
||||
model = model_class()
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if note_print==1:
|
||||
note = checkpoint['note']
|
||||
print(note)
|
||||
return model
|
||||
|
||||
# 加载训练数据,用于批量加载训练
|
||||
def load_train_data(x_train, y_train, batch_size=32):
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
Loading…
x
Reference in New Issue
Block a user