pytorch 中的 transpose 方法的作用是交换矩阵的两个维度,transpose (dim0, dim1) → Tensor,其和 torch.transpose () 函数作用一样。 torch.transpose ():
torch.transpose(input, dim0, dim1) → Tensor |
Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped. The resulting out tensor shares it’s underlying storage with the input tensor, so changing the content of one would change the content of the other. 第二条是说输出和输入是共享一块内存的,所以两者同时改变。
Parameters | |
input (Tensor) – the input tensor. | |
dim0 (int) – the first dimension to be transposed | |
dim1 (int) – the second dimension to be transposed |
例:
>>> x = torch.randn(2, 3) | |
>>> x | |
tensor([[ 1.0028, -0.9893, 0.5809], | |
[-0.1669, 0.7299, 0.4942]]) | |
>>> torch.transpose(x, 0, 1) | |
tensor([[ 1.0028, -0.1669], | |
[-0.9893, 0.7299], | |
[ 0.5809, 0.4942]]) |
需要注意的几点: 1、transpose 中的两个维度参数的顺序是可以交换位置的,即 transpose(x, 0, 1,) 和 transpose(x, 1, 0)效果是相同的。如下:
>>> import torch | |
>>> x = torch.randn(2, 3) | |
>>> x | |
tensor([[-0.4343, 0.4643, -1.1345], | |
[-0.3667, -1.9913, 1.3485]]) | |
>>> torch.transpose(x, 1, 0) | |
tensor([[-0.4343, -0.3667], | |
[ 0.4643, -1.9913], | |
[-1.1345, 1.3485]]) | |
>>> torch.transpose(x, 0, 1) | |
tensor([[-0.4343, -0.3667], | |
[ 0.4643, -1.9913], | |
[-1.1345, 1.3485]]) |
2、transpose.() 中只有两个参数,而 torch.transpose()函数中有三个参数。