0.1.107
This commit is contained in:
		| @@ -1,7 +1,7 @@ | |||||||
| [metadata] | [metadata] | ||||||
| # replace with your username: | # replace with your username: | ||||||
| name = guan | name = guan | ||||||
| version = 0.1.106 | version = 0.1.107 | ||||||
| author = guanjihuan | author = guanjihuan | ||||||
| author_email = guanjihuan@163.com | author_email = guanjihuan@163.com | ||||||
| description = An open source python package | description = An open source python package | ||||||
|   | |||||||
| @@ -1,6 +1,6 @@ | |||||||
| Metadata-Version: 2.1 | Metadata-Version: 2.1 | ||||||
| Name: guan | Name: guan | ||||||
| Version: 0.1.106 | Version: 0.1.107 | ||||||
| Summary: An open source python package | Summary: An open source python package | ||||||
| Home-page: https://py.guanjihuan.com | Home-page: https://py.guanjihuan.com | ||||||
| Author: guanjihuan | Author: guanjihuan | ||||||
|   | |||||||
| @@ -254,6 +254,14 @@ def save_model(model, filename='./model.pth'): | |||||||
|     import torch |     import torch | ||||||
|     torch.save(model, filename) |     torch.save(model, filename) | ||||||
|  |  | ||||||
|  | # 以字典的形式保存模型的所有信息到文件(保存时需要模型的类可访问) | ||||||
|  | def save_model_with_all_information(model, model_class, note='', filename='./model_with_all_information.pth'): | ||||||
|  |     import torch | ||||||
|  |     checkpoint = {'model_state_dict': model.state_dict(), | ||||||
|  |                   'model_class': model_class, | ||||||
|  |                   'note': note,} | ||||||
|  |     torch.save(checkpoint, filename) | ||||||
|  |  | ||||||
| # 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变) | # 加载模型参数(需要输入模型,加载后,原输入的模型参数也会改变) | ||||||
| def load_model_parameters(model, filename='./model_parameters.pth'): | def load_model_parameters(model, filename='./model_parameters.pth'): | ||||||
|     import torch |     import torch | ||||||
| @@ -266,6 +274,18 @@ def load_model(filename='./model.pth'): | |||||||
|     model = torch.load(filename) |     model = torch.load(filename) | ||||||
|     return model |     return model | ||||||
|  |  | ||||||
|  | # 加载包含所有信息的模型(包含了模型的类,返回的是对象) | ||||||
|  | def load_model_with_all_information(filename='./model_with_all_information.pth', note_print=0): | ||||||
|  |     import torch | ||||||
|  |     checkpoint = torch.load(filename) | ||||||
|  |     model_class = checkpoint['model_class'] | ||||||
|  |     model = model_class() | ||||||
|  |     model.load_state_dict(checkpoint['model_state_dict']) | ||||||
|  |     if note_print==1: | ||||||
|  |         note = checkpoint['note'] | ||||||
|  |         print(note) | ||||||
|  |     return model | ||||||
|  |  | ||||||
| # 加载训练数据,用于批量加载训练 | # 加载训练数据,用于批量加载训练 | ||||||
| def load_train_data(x_train, y_train, batch_size=32): | def load_train_data(x_train, y_train, batch_size=32): | ||||||
|     from torch.utils.data import DataLoader, TensorDataset |     from torch.utils.data import DataLoader, TensorDataset | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user