This commit is contained in:
guanjihuan 2024-04-03 00:06:01 +08:00
parent 7a049ccc5e
commit 5e3e22e25d
2 changed files with 23 additions and 0 deletions

View File

@ -0,0 +1,10 @@
import torch
input_data = torch.randn(1, 1, 28, 28)
conv_layer = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1)
output_data = conv_layer(input_data)
print("输出数据的形状:", output_data.shape)
conv_layer = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)
output_data = conv_layer(input_data)
print("输出数据的形状:", output_data.shape)

View File

@ -0,0 +1,13 @@
import torch
input_data = torch.tensor([[7, 3, 5, 2],
[8, 7, 1, 6],
[4, 9, 3, 9],
[0, 8, 4, 5]], dtype=torch.float32).unsqueeze(0).unsqueeze(0) # 两次 .unsqueeze(0) 分别是添加批次和通道维度
max_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
output_data = max_pool(input_data)
print("Input:\n", input_data)
print("Output after max pooling:\n", output_data)
print('输入数据的形状: ', input_data.shape)
print('输出数据的形状: ', output_data.shape)