Update pytorch_module_class_method.py
This commit is contained in:
parent
a9254b09eb
commit
ba96316e25
@ -47,7 +47,7 @@ print()
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class LinearRegressionModel(torch.nn.Module): # 定义模型,继承torch.nn.Module类
|
||||
class One_Model(torch.nn.Module): # 定义模型,继承torch.nn.Module类
|
||||
def __init__(self, input_size, hidden_size, output_size):
|
||||
super().__init__() # 调用父类的的初始化内容
|
||||
self.hidden_layer = torch.nn.Linear(input_size, hidden_size) # 定义一个隐藏层
|
||||
@ -85,7 +85,7 @@ plt.show() # 显示图像
|
||||
input_size = 1
|
||||
hidden_size = 50
|
||||
output_size = 1
|
||||
model = LinearRegressionModel(input_size, hidden_size, output_size) # 创建模型
|
||||
model = One_Model(input_size, hidden_size, output_size) # 创建模型
|
||||
criterion = torch.nn.MSELoss() # 定义损失函数
|
||||
learning_rate = 0.01 # 梯度下降的学习速率
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 定义优化器。其中.parameters()是torch.nn.Module类中的方法
|
||||
@ -121,10 +121,18 @@ plt.ylabel('Loss')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
print("\n\n\nModel Parameters:\n") # 查看参数
|
||||
for param in model.parameters():
|
||||
print(param)
|
||||
|
||||
print("\n\n\nGradients:\n") # 查看梯度
|
||||
for param in model.parameters():
|
||||
print(param.grad)
|
||||
|
||||
torch.save(model.state_dict(), 'model.pth') # 使用 torch.save 函数来保存模型。其中,model.state_dict()返回模型的权重字典
|
||||
torch.save(model, 'full_model.pth') # 保存整个模型,包括模型的结构和权重
|
||||
|
||||
model_2 = LinearRegressionModel(input_size, hidden_size, output_size) # 创建模型
|
||||
model_2 = One_Model(input_size, hidden_size, output_size) # 创建模型
|
||||
model_2.load_state_dict(torch.load('model.pth')) # 加载模型参数
|
||||
with torch.no_grad():
|
||||
predictions_2 = model_2(x_data)
|
||||
|
Loading…
x
Reference in New Issue
Block a user