# Pytorch 函数 .expand ( )
其将单个维度扩大成更大维度,返回一个新的 tensor,具体看下例:
import torch | |
a = torch.Tensor([[1], [2], [3],[4]]) | |
# 未使用 expand()函数前的 a | |
print('a.size: ', a.size()) | |
print('a: ', a) | |
b = a.expand(4, 2) | |
# 使用 expand()函数后的输出 | |
print('a.size: ', a.size()) | |
print('a: ', a) | |
print('b.size: ', b.size()) | |
print('b: ', b) |
expand()函数使用前后 a 没有发生变化,输出都是:
a.size:>torch.Size([4, 1])
a:>1
2
3
4
[torch.FloatTensor of size 4x1]
b 的输出为:
b.size:>torch.Size([4, 2])
b:>1 1
2 2
3 3
4 4
[torch.FloatTensor of size 4x2]
由此得出结论,a 通过 expand()函数扩展某一维度后自身不会发生变化
a = torch.Tensor([[[[1,2], [2,3], [3,4],[4,5]]]]) | |
b = a.expand(2, 1, 4, 2) | |
c = a.expand(1, 2, 4, 2) | |
# 使用 expand()函数后的输出 | |
print('a.size: ', a.size()) | |
print('b.size: ', b.size()) | |
print('b: ', b) | |
print('c.size: ', c.size()) | |
print('c: ', c) | |
b2 = b.expand(3, 1, 4, 2) # b: torch.Size([2, 1, 4, 2]) | |
print('b2.size: ', b2.size()) |
输出:
a.size:>torch.Size([1, 1, 4, 2])
b.size:>torch.Size([2, 1, 4, 2])
b:>(0 ,0 ,.,.) =
1 2
2 3
3 4
4 5
(1 ,0 ,.,.) =
1 2
2 3
3 4
4 5
[torch.FloatTensor of size 2x1x4x2]
c.size:>torch.Size([1, 2, 4, 2])
c:>(0 ,0 ,.,.) =
1 2
2 3
3 4
4 5
(0 ,1 ,.,.) =
1 2
2 3
3 4
4 5
[torch.FloatTensor of size 1x2x4x2]
b2 输出:
Traceback (most recent call last):
File “”, line 1, in
RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0. at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/TH/generic/THTensor.c:309
由此可见,只要是单维度均可进行扩展,但是若非单维度会报错
上面的单维度即维度是 1