目录
一. torch.stack()函数解析1. 函数说明:2. 代码举例总结一. torch.stack()函数解析
1. 函数说明:
1.1 官网:torch.stack(),函数定义及参数说明如下图所示:
(资料图)
1.2 函数功能
沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。
1.3 参数列表
tensors :为一系列输入张量,类型为turple和Listdim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接返回值:输出新增维度后的张量2. 代码举例
2.1 dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)
import torch #二维输入张量a,b a = torch.tensor([1, 2, 3]) b = torch.tensor([11, 22, 33]) c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维) print(a) print(b) print(c)
输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 2, 3],
[11, 22, 33]])
2.2 dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)
import torch #二维输入张量a,b a = torch.tensor([1, 2, 3]) b = torch.tensor([11, 22, 33]) c = torch.stack([a, b],dim=1)#在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维) print(a) print(b) print(c)
输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 11],
[ 2, 22],
[ 3, 33]])
2.3 dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。
import torch #二维输入张量a,b a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维) print(a) print(b) print(c)
输出结果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],[[11, 22, 33],
[44, 55, 66],
[77, 88, 99]]])
2.4 dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。
import torch #二维输入张量a,b a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) c = torch.stack([a, b], 1)#在第1维进行连接,相当于对相应通道中每个行进行组合 print(a) print(b) print(c)
输出结果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[11, 22, 33]],[[ 4, 5, 6],
[44, 55, 66]],[[ 7, 8, 9],
[77, 88, 99]]])
2.5 dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。
import torch #二维输入张量a,b a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) c = torch.stack([a, b], 2)#在第2维进行连接,相当于对相应行中每个列元素进行组合 print(a) print(b) print(c)
输出结果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 11],
[ 2, 22],
[ 3, 33]],[[ 4, 44],
[ 5, 55],
[ 6, 66]],[[ 7, 77],
[ 8, 88],
[ 9, 99]]])
2.6 dim=3:表示在第3维进行连接,相当于对相应行中每个列元素进行组合(输入维度大小为3维,因此dim=3最后一维始终代表为列),注意:此处输入张量维度为三维,因此dim最大只能为3。
import torch #三维输入张量a,b a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) c = torch.stack([a, b], 3)#表示在第3维进行连接,相当于对相应行中每个列元素进行组合(最后一维是第三维,始终代表为列) print(a) print(b) print(c)
输出结果如下所示:
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
tensor([[[ 11, 22, 33],
[ 44, 55, 66],
[ 77, 88, 99]],[[110, 220, 330],
[440, 550, 660],
[770, 880, 990]]])
tensor([[[[ 1, 11],
[ 2, 22],
[ 3, 33]],[[ 4, 44],
[ 5, 55],
[ 6, 66]],[[ 7, 77],
[ 8, 88],
[ 9, 99]]],
[[[ 10, 110],
[ 20, 220],
[ 30, 330]],[[ 40, 440],
[ 50, 550],
[ 60, 660]],[[ 70, 770],
[ 80, 880],
[ 90, 990]]]])
2.7 dim=4(错误维度:因为此处输入张量维度为三维,所以dim最大只能为3,此处维度为4,因此会报错)
import torch #三维输入张量a,b a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) c = torch.stack([a, b], 4) print(a) print(b) print(c)
输出错误:
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)
总结
到此这篇关于Pytorch中torch.stack()函数的文章就介绍到这了,更多相关Pytorchtorch.stack()函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
X 关闭
X 关闭
- 1联想拯救者Y70发布最新预告:售价2970元起 迄今最便宜的骁龙8+旗舰
- 2亚马逊开始大规模推广掌纹支付技术 顾客可使用“挥手付”结账
- 3现代和起亚上半年出口20万辆新能源汽车同比增长30.6%
- 4如何让居民5分钟使用到各种设施?沙特“线性城市”来了
- 5AMD实现连续8个季度的增长 季度营收首次突破60亿美元利润更是翻倍
- 6转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 7充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 8好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 9名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?
- 10亚马逊云科技成立量子网络中心致力解决量子计算领域的挑战