Update pytorch_module_class_method.py

This commit is contained in:
guanjihuan 2023-11-22 16:12:25 +08:00
parent a9254b09eb
commit ba96316e25

View File

@ -47,7 +47,7 @@ print()
import torch
import matplotlib.pyplot as plt
class LinearRegressionModel(torch.nn.Module): # 定义模型继承torch.nn.Module类
class One_Model(torch.nn.Module): # 定义模型继承torch.nn.Module类
def __init__(self, input_size, hidden_size, output_size):
super().__init__() # 调用父类的的初始化内容
self.hidden_layer = torch.nn.Linear(input_size, hidden_size) # 定义一个隐藏层
@ -85,7 +85,7 @@ plt.show() # 显示图像
input_size = 1
hidden_size = 50
output_size = 1
model = LinearRegressionModel(input_size, hidden_size, output_size) # 创建模型
model = One_Model(input_size, hidden_size, output_size) # 创建模型
criterion = torch.nn.MSELoss() # 定义损失函数
learning_rate = 0.01 # 梯度下降的学习速率
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 定义优化器。其中.parameters()是torch.nn.Module类中的方法
@ -121,10 +121,18 @@ plt.ylabel('Loss')
plt.legend()
plt.show()
print("\n\n\nModel Parameters:\n") # 查看参数
for param in model.parameters():
print(param)
print("\n\n\nGradients:\n") # 查看梯度
for param in model.parameters():
print(param.grad)
torch.save(model.state_dict(), 'model.pth') # 使用 torch.save 函数来保存模型。其中model.state_dict()返回模型的权重字典
torch.save(model, 'full_model.pth') # 保存整个模型,包括模型的结构和权重
model_2 = LinearRegressionModel(input_size, hidden_size, output_size) # 创建模型
model_2 = One_Model(input_size, hidden_size, output_size) # 创建模型
model_2.load_state_dict(torch.load('model.pth')) # 加载模型参数
with torch.no_grad():
predictions_2 = model_2(x_data)