torch.matmul 是 tensor 的乘法,输入可以是高维的。 当输入是都是二维时,就是普通的矩阵乘法,和 tensor.mm 函数用法相同。
当输入有多维时,把多出的一维作为 batch 提出来,其他部分做矩阵乘法。
下面看一个两个都是 3 维的例子。
将 b 的第 0 维 1broadcast 成 2 提出来,后两维做矩阵乘法即可。 再看一个复杂一点的,是官网的例子。
首先把 a 的第 0 维 2 作为 batch 提出来,则 a 和 b 都可看作三维。再把 a 的 1broadcat 成 5,提取公因式 5。(这样说虽然不严谨,但是便于理解。)然后 a 剩下 (3,4),b 剩下 (4,2),做矩阵乘法得到 (3,2)。