Create torchviz_example.py
This commit is contained in:
parent
0547992527
commit
ad506ef82b
50
2024.04.26_torchviz_example/torchviz_example.py
Normal file
50
2024.04.26_torchviz_example/torchviz_example.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""
|
||||
This code is supported by the website: https://www.guanjihuan.com
|
||||
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/40353
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torchviz
|
||||
|
||||
|
||||
# 简单网络的例子
|
||||
class SimpleNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(SimpleNet, self).__init__()
|
||||
self.fc1 = torch.nn.Linear(10, 6)
|
||||
self.fc2 = torch.nn.Linear(6, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.relu(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
model = SimpleNet()
|
||||
input = torch.randn(1, 10)
|
||||
output = model(input)
|
||||
graph = torchviz.make_dot(output, params=dict(model.named_parameters()))
|
||||
graph.render("Simple_net_graph") # 保存计算图为 PDF 文件
|
||||
|
||||
|
||||
# 卷积网络的例子
|
||||
class ConvolutionalNeuralNetwork(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.convolutional_layer_1 = torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=3, stride=1, padding=1)
|
||||
self.convolutional_layer_2 = torch.nn.Conv2d(in_channels=10, out_channels=10, kernel_size=3, stride=1, padding=1)
|
||||
self.pooling_layer = torch.nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.hidden_layer_1 = torch.nn.Linear(in_features=10*7*7, out_features=10)
|
||||
self.hidden_layer_2 = torch.nn.Linear(in_features=10, out_features=10)
|
||||
self.output_layer = torch.nn.Linear(in_features=10, out_features=1)
|
||||
def forward(self, x):
|
||||
channel_output_1 = torch.nn.functional.relu(self.pooling_layer(self.convolutional_layer_1(x)))
|
||||
channel_output_2 = torch.nn.functional.relu(self.pooling_layer(self.convolutional_layer_2(channel_output_1)))
|
||||
channel_output_2 = torch.flatten(channel_output_2, 1)
|
||||
hidden_output_1 = torch.nn.functional.relu(self.hidden_layer_1(channel_output_2))
|
||||
hidden_output_2 = torch.nn.functional.relu(self.hidden_layer_2(hidden_output_1))
|
||||
output = self.output_layer(hidden_output_2)
|
||||
return output
|
||||
model = ConvolutionalNeuralNetwork()
|
||||
input = torch.randn(15, 1, 28, 28)
|
||||
output = model(input)
|
||||
graph = torchviz.make_dot(output, params=dict(model.named_parameters()))
|
||||
graph.render("CNN_graph") # 保存计算图为 PDF 文件
|
Loading…
x
Reference in New Issue
Block a user