Update pytorch_module_class_method.py
This commit is contained in:
parent
a9254b09eb
commit
ba96316e25
@ -47,7 +47,7 @@ print()
|
|||||||
import torch
|
import torch
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
class LinearRegressionModel(torch.nn.Module): # 定义模型,继承torch.nn.Module类
|
class One_Model(torch.nn.Module): # 定义模型,继承torch.nn.Module类
|
||||||
def __init__(self, input_size, hidden_size, output_size):
|
def __init__(self, input_size, hidden_size, output_size):
|
||||||
super().__init__() # 调用父类的的初始化内容
|
super().__init__() # 调用父类的的初始化内容
|
||||||
self.hidden_layer = torch.nn.Linear(input_size, hidden_size) # 定义一个隐藏层
|
self.hidden_layer = torch.nn.Linear(input_size, hidden_size) # 定义一个隐藏层
|
||||||
@ -85,7 +85,7 @@ plt.show() # 显示图像
|
|||||||
input_size = 1
|
input_size = 1
|
||||||
hidden_size = 50
|
hidden_size = 50
|
||||||
output_size = 1
|
output_size = 1
|
||||||
model = LinearRegressionModel(input_size, hidden_size, output_size) # 创建模型
|
model = One_Model(input_size, hidden_size, output_size) # 创建模型
|
||||||
criterion = torch.nn.MSELoss() # 定义损失函数
|
criterion = torch.nn.MSELoss() # 定义损失函数
|
||||||
learning_rate = 0.01 # 梯度下降的学习速率
|
learning_rate = 0.01 # 梯度下降的学习速率
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 定义优化器。其中.parameters()是torch.nn.Module类中的方法
|
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 定义优化器。其中.parameters()是torch.nn.Module类中的方法
|
||||||
@ -121,10 +121,18 @@ plt.ylabel('Loss')
|
|||||||
plt.legend()
|
plt.legend()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
print("\n\n\nModel Parameters:\n") # 查看参数
|
||||||
|
for param in model.parameters():
|
||||||
|
print(param)
|
||||||
|
|
||||||
|
print("\n\n\nGradients:\n") # 查看梯度
|
||||||
|
for param in model.parameters():
|
||||||
|
print(param.grad)
|
||||||
|
|
||||||
torch.save(model.state_dict(), 'model.pth') # 使用 torch.save 函数来保存模型。其中,model.state_dict()返回模型的权重字典
|
torch.save(model.state_dict(), 'model.pth') # 使用 torch.save 函数来保存模型。其中,model.state_dict()返回模型的权重字典
|
||||||
torch.save(model, 'full_model.pth') # 保存整个模型,包括模型的结构和权重
|
torch.save(model, 'full_model.pth') # 保存整个模型,包括模型的结构和权重
|
||||||
|
|
||||||
model_2 = LinearRegressionModel(input_size, hidden_size, output_size) # 创建模型
|
model_2 = One_Model(input_size, hidden_size, output_size) # 创建模型
|
||||||
model_2.load_state_dict(torch.load('model.pth')) # 加载模型参数
|
model_2.load_state_dict(torch.load('model.pth')) # 加载模型参数
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
predictions_2 = model_2(x_data)
|
predictions_2 = model_2(x_data)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user