13 lines
561 B
Python
13 lines
561 B
Python
import torch
|
|
input_data = torch.tensor([[7, 3, 5, 2],
|
|
[8, 7, 1, 6],
|
|
[4, 9, 3, 9],
|
|
[0, 8, 4, 5]], dtype=torch.float32).unsqueeze(0).unsqueeze(0) # 两次 .unsqueeze(0) 分别是添加批次和通道维度
|
|
|
|
max_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
|
|
output_data = max_pool(input_data)
|
|
|
|
print("Input:\n", input_data)
|
|
print("Output after max pooling:\n", output_data)
|
|
print('输入数据的形状: ', input_data.shape)
|
|
print('输出数据的形状: ', output_data.shape) |