update
This commit is contained in:
40
2024.06.03_pytorch_tensor_cat/pytorch_cat.py
Normal file
40
2024.06.03_pytorch_tensor_cat/pytorch_cat.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
This code is supported by the website: https://www.guanjihuan.com
|
||||
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/41194
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
# 定义两个张量
|
||||
tensor1 = torch.randn(2, 3)
|
||||
tensor2 = torch.randn(2, 3)
|
||||
print(tensor1)
|
||||
print(tensor2)
|
||||
print()
|
||||
|
||||
# 第一维度的数据合并(需要其他的维度保持一致)
|
||||
result1 = torch.cat((tensor1, tensor2), dim=0)
|
||||
print(result1)
|
||||
print(result1.shape)
|
||||
print()
|
||||
|
||||
# 第二维度的数据合并(需要其他的维度保持一致)
|
||||
result2 = torch.cat((tensor1, tensor2), dim=1)
|
||||
print(result2)
|
||||
print(result2.shape)
|
||||
print()
|
||||
|
||||
|
||||
|
||||
# 定义多个张量
|
||||
tensor1 = torch.randn(2, 10)
|
||||
tensor2 = torch.randn(2, 20)
|
||||
tensor3 = torch.randn(2, 30)
|
||||
tensor4 = torch.randn(2, 50)
|
||||
|
||||
# 将这些张量放在一个列表中
|
||||
tensors = [tensor1, tensor2, tensor3, tensor4]
|
||||
|
||||
# 第二维度的数据合并(确保所有张量的第一维度相同)
|
||||
result3 = torch.cat(tensors, dim=1)
|
||||
print(result3.shape)
|
Reference in New Issue
Block a user