在pytorch中,我们对张量Tensor的维度进行压缩或者扩充(被压缩或者扩充的维度为1),经常使用的是squeeze()
函数和unsqueeze()
函数
1. torch.squeeze(input, dim=None)
用于降维。将 input 中维度为1的部分去除,当维度大于等于2时,squeeze()无作用。
也可通过 input.squeeze( dim=None, out=None)调用。
- input(Tensor):输入张量,即被操作目标
- dim(int, optional):在指定维去掉一个维度。若不指定则自动寻找,指定则当指定的维度为1时去掉,不为1时则不改变
注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
示例
# 示例1
a = torch.Tensor(1,3)
>>
tensor([[-1.37,4.56,-3.57]])
print a.squeeze(0) #第一个维度大小是1,所以去除
>>
tensor([-1.37,4.56,-3.57])
print a.squeeze(1) ##第二个维度大小是3,所以不去除
>>
tensor([[-1.37,4.56,-3.57]])
# 示例2
c = torch.Tensor(3,1)
print c
>>
tensor([[-3.54],
[3.09],
[0.00]])
print c.squeeze(0)##第一个维度大小不是1,所以不去除
>>
tensor([[-3.54],
[3.09],
[0.00]])
print c.squeeze(1)#第二个维度大小是1,所以去除
>>
tensor([-3.54,3.09,0.00])
# 示例3
x = torch.zeros(2, 1, 2, 1, 2)
x.size()
>>
torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x)
y.size()
>>
torch.Size([2, 2, 2])
y = torch.squeeze(x, 0)
y.size()
>>
torch.Size([2, 1, 2, 1, 2])
2. torch.unsqueeze(input, dim)
为pytorch中的tensor增加一个维度。
也可通过 input.unsqueeze( dim=None, out=None)调用。
- input(Tensor):输入张量,即被操作目标
- dim(int, optional):在哪一维增加一个维度,dim必须被指定
示例
import torch
a = torch.arange(12).reshape([3,4])
print(a)
b = a.unsqueeze(1)
print(b)
>>
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
tensor([[[ 0, 1, 2, 3]],
[[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11]]])
参考官方文档: