# 1. 函数介绍
torch.argmax(input, dim=None, keepdim=False)
返回指定维度最大值的序号
dim 给定的定义是:
the demention to reduce
. 也就是把 dim 这个维度的,变成这个维度的最大值的 index。dim 的不同值表示不同维度。特别的在 dim=0 表示二维中的列,dim=1 在二维矩阵中表示行。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么 dim=0 就表示对应到 d0 也就是第一个维度,dim=1 表示对应到也就是第二个维度,依次类推。
知道 dim 的值是什么意思还不行,还要知道函数中这个 dim 给出来会发生什么?
举例说明:
例子 1: torch.argmax()
函数中 dim 表示该维度会消失。
这个消失是什么意思?
官方英文解释是: dim (int) – the dimension to reduce.
我们知道 argmax 就是得到最大值的序号索引,对于一个维度为 (d0,d1) 的矩阵来说,我们想要求每一行中最大数的在该行中的列号,最后我们得到的就是一个维度为 (d0,1) 的一维矩阵。这时候,列这一维度就要消失了。
因此,我们想要求每一行最大的列标号,我们就要指定 dim=1,表示我们不要列了,保留行的 size 就可以了。
假如我们想求每一列的最大行标,就可以指定 dim=0,表示我们不要行了,求出每一列的最大值的下标,最后得到(1,d1)的一维矩阵。
# 2. 实例演示
实例 1:
import torch | |
a = torch.tensor( | |
[ | |
[1, 5, 5, 2], | |
[9, -6, 2, 8], | |
[-3, 7, -9, 1] | |
]) | |
b = torch.argmax(a, dim=0) | |
print(b) | |
print(a.shape) |
输出结果:
tensor([1, 2, 0, 1]) | |
torch.Size([3, 4]) |
dim=0
的维度为 3,即在那 3 组数据中作比较,求得是每一列中的最大行标,因此为 [1,2,0,4]。
实例 2:
import torch | |
a = torch.tensor([ | |
[ | |
[1, 5, 5, 2], | |
[9, -6, 2, 8], | |
[-3, 7, -9, 1] | |
], | |
[ | |
[-1, 7, -5, 2], | |
[9, 6, 2, 8], | |
[3, 7, 9, 1] | |
]]) | |
b = torch.argmax(a, dim=0) | |
print(b) | |
print(b.shape) | |
print(a.shape) | |
""" | |
tensor([[0, 1, 0, 1], | |
[1, 1, 1, 1], | |
[1, 1, 1, 1]]) | |
torch.Size([3, 4]) | |
torch.Size([2, 3, 4]) | |
""" | |
# dim=0, 即将第一个维度消除,也就是将两个 [3*4] 矩阵只保留一个,因此要在两组中作比较,即将上下两个 [3*4] 的矩阵分别在对应的位置上比较大小,为 0 表示上面的 [3x4] 矩阵对应的数字大,为 1 则表示下面的 [3x4] 矩阵对应的数字大,由此可类比到如果有更多的 [3*4] 矩阵进行比较时的情形。dim=0 表示的第一个维度在 a 中的值为 2,即一共有 2 个 [3x4] 的矩阵 | |
b = torch.argmax(a, dim=1) | |
""" | |
tensor([[1, 2, 0, 1], | |
[1, 2, 2, 1]]) | |
torch.Size([2, 3, 4]) | |
""" | |
# dim=1,即将第二个维度消除,这么理解:矩阵维度变为 [2*4]; | |
""" | |
[1, 5, 5, 2], | |
[9, -6, 2, 8], | |
[-3, 7, -9, 1]; | |
纵向压缩成一维,因此变为[1,2,0,1];同理得到[1,2,2,1]; | |
a的shape为torch.Size([2, 3, 4]),dim=1表示在3上,构成3的分别是第一个[3x4]矩阵中的[1, 5, 5, 2]、[9, -6, 2, 8]、[-3, 7, -9, 1]和第二个[3x4]矩阵中的[-1, 7, -5, 2]、[9, 6, 2, 8]、[3, 7, 9, 1] | |
分别在它们之间做比较,得到的值的范围是0~2 | |
""" | |
b = torch.argmax(a,dim=2) | |
""" | |
tensor([[2, 0, 1], | |
[1, 0, 2]]) | |
""" | |
# dim=2, 即将第三个维度消除,这么理解:矩阵维度变为 [2*3] | |
""" | |
[1, 5, 5, 2], | |
[9, -6, 2, 8], | |
[-3, 7, -9, 1]; | |
横向压缩成一维 | |
[2,0,1],同理得到下面的""" |