0.1.107
This commit is contained in:
parent
c91c34f76b
commit
0363b5fcf6
@ -1,7 +1,7 @@
|
|||||||
[metadata]
|
[metadata]
|
||||||
# replace with your username:
|
# replace with your username:
|
||||||
name = guan
|
name = guan
|
||||||
version = 0.1.106
|
version = 0.1.107
|
||||||
author = guanjihuan
|
author = guanjihuan
|
||||||
author_email = guanjihuan@163.com
|
author_email = guanjihuan@163.com
|
||||||
description = An open source python package
|
description = An open source python package
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
Metadata-Version: 2.1
|
Metadata-Version: 2.1
|
||||||
Name: guan
|
Name: guan
|
||||||
Version: 0.1.106
|
Version: 0.1.107
|
||||||
Summary: An open source python package
|
Summary: An open source python package
|
||||||
Home-page: https://py.guanjihuan.com
|
Home-page: https://py.guanjihuan.com
|
||||||
Author: guanjihuan
|
Author: guanjihuan
|
||||||
|
@ -254,6 +254,14 @@ def save_model(model, filename='./model.pth'):
|
|||||||
import torch
|
import torch
|
||||||
torch.save(model, filename)
|
torch.save(model, filename)
|
||||||
|
|
||||||
|
# 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问)
|
||||||
|
def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'):
|
||||||
|
import torch
|
||||||
|
checkpoint = {'model_state_dict': model.state_dict(),
|
||||||
|
'model_class': model_class,
|
||||||
|
'note': note,}
|
||||||
|
torch.save(checkpoint, filename)
|
||||||
|
|
||||||
# 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变)
|
# 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变)
|
||||||
def load_model_parameters(model, filename='./model_parameters.pth'):
|
def load_model_parameters(model, filename='./model_parameters.pth'):
|
||||||
import torch
|
import torch
|
||||||
@ -266,6 +274,18 @@ def load_model(filename='./model.pth'):
|
|||||||
model = torch.load(filename)
|
model = torch.load(filename)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
# 加载包含所有信息的模型(包含了模型的类,返回的是对象)
|
||||||
|
def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0):
|
||||||
|
import torch
|
||||||
|
checkpoint = torch.load(filename)
|
||||||
|
model_class = checkpoint['model_class']
|
||||||
|
model = model_class()
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
if note_print==1:
|
||||||
|
note = checkpoint['note']
|
||||||
|
print(note)
|
||||||
|
return model
|
||||||
|
|
||||||
# 加载训练数据,用于批量加载训练
|
# 加载训练数据,用于批量加载训练
|
||||||
def load_train_data(x_train, y_train, batch_size=32):
|
def load_train_data(x_train, y_train, batch_size=32):
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
Loading…
x
Reference in New Issue
Block a user