# Pytorch 函数 .expand ( )

其将单个维度扩大成更大维度,返回一个新的 tensor,具体看下例:

import torch
a = torch.Tensor([[1], [2], [3],[4]])
# 未使用 expand()函数前的 a
print('a.size: ', a.size())
print('a: ', a)
b = a.expand(4, 2)
# 使用 expand()函数后的输出
print('a.size: ', a.size())
print('a: ', a)
print('b.size: ', b.size())
print('b: ', b)

expand()函数使用前后 a 没有发生变化,输出都是:

a.size:>torch.Size([4, 1])

a:>1

2
3
4
[torch.FloatTensor of size 4x1]

b 的输出为:

b.size:>torch.Size([4, 2])

b:>1 1

2 2
3 3
4 4
[torch.FloatTensor of size 4x2]

由此得出结论,a 通过 expand()函数扩展某一维度后自身不会发生变化

a = torch.Tensor([[[[1,2], [2,3], [3,4],[4,5]]]])
b = a.expand(2, 1, 4, 2)
c = a.expand(1, 2, 4, 2)
# 使用 expand()函数后的输出
print('a.size: ', a.size())
print('b.size: ', b.size())
print('b: ', b)
print('c.size: ', c.size())
print('c: ', c)
 b2 = b.expand(3, 1, 4, 2)  # b: torch.Size([2, 1, 4, 2])
 print('b2.size: ', b2.size())

输出:

a.size:>torch.Size([1, 1, 4, 2])

b.size:>torch.Size([2, 1, 4, 2])

b:>(0 ,0 ,.,.) =

1 2
2 3
3 4
4 5
(1 ,0 ,.,.) =
1 2
2 3
3 4
4 5
[torch.FloatTensor of size 2x1x4x2]

c.size:>torch.Size([1, 2, 4, 2])

c:>(0 ,0 ,.,.) =

1 2
2 3
3 4
4 5
(0 ,1 ,.,.) =
1 2
2 3
3 4
4 5
[torch.FloatTensor of size 1x2x4x2]

b2 输出:

Traceback (most recent call last):
File “”, line 1, in
RuntimeError: The expanded size of the tensor (3) must match the existing size (2) at non-singleton dimension 0. at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/TH/generic/THTensor.c:309

由此可见,只要是单维度均可进行扩展,但是若非单维度会报错

上面的单维度即维度是 1