One of the most common mistakes in PyTorch is a mismatched tensor dimension. To avoid it, we need to know how to modify the dimensions of a tensor so that it fits what the model requires.
Unsqueeze
Unsqueeze adds a dimension of size 1 to a tensor at the position we specify.
# Unsqueeze demo
import torch
# Unsqueeze
x = torch.tensor([1, 2, 3, 4])
print(x.shape)
print(x)
print()
x = x.unsqueeze(0)
print(x.shape)
print(x)
print()
x = x.unsqueeze(1)
print(x.shape)
print(x)
print()
torch.Size([4])
tensor([1, 2, 3, 4])
torch.Size([1, 4])
tensor([[1, 2, 3, 4]])
torch.Size([1, 1, 4])
tensor([[[1, 2, 3, 4]]])
Unsqueeze takes the position of the new dimension as a required argument; there is no default. With unsqueeze(0) the dimension is added at the beginning: if the input tensor has shape (3, 4), the output tensor will have shape (1, 3, 4).
Any other valid position works too, for example the end:
import torch
# Unsqueeze
x = torch.tensor([1, 2, 3, 4])
print(x.shape)
print(x)
print()
x = x.unsqueeze(0)
print(x.shape)
print(x)
print()
x = x.unsqueeze(2)
print(x.shape)
print(x)
print()
torch.Size([4])
tensor([1, 2, 3, 4])
torch.Size([1, 4])
tensor([[1, 2, 3, 4]])
torch.Size([1, 4, 1])
tensor([[[1],
         [2],
         [3],
         [4]]])
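As a small supplementary sketch (the variable names `col`, `row`, and `batch` are illustrative and not part of the demos above), the position may also be negative, counted from the end, and indexing with `None` gives the same result as `unsqueeze`:

```python
import torch

x = torch.tensor([1, 2, 3, 4])          # shape (4,)

# dim=-1 inserts the new dimension at the last position
col = x.unsqueeze(-1)
print(col.shape)                         # torch.Size([4, 1])

# Indexing with None is an equivalent way to insert a dimension
row = x[None, :]
print(row.shape)                         # torch.Size([1, 4])

# unsqueeze(0) on a (3, 4) tensor gives (1, 3, 4),
# which is a common way to add a batch dimension of size 1
batch = torch.ones(3, 4).unsqueeze(0)
print(batch.shape)                       # torch.Size([1, 3, 4])
```

Which spelling to use is mostly a matter of taste; `unsqueeze` states the intent explicitly, while `None` indexing is shorter when several dimensions are added at once.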
Squeeze
Squeeze removes a dimension of size 1 from a tensor. It’s like squeezing a bottle of water: the bottle becomes smaller, but what is inside stays the same.
# Squeeze demo
import torch
x = torch.tensor([[[1, 2, 3, 4]]])
print(x.shape)
print(x)
print()
x = x.squeeze(0)
print(x.shape)
print(x)
print()
torch.Size([1, 1, 4])
tensor([[[1, 2, 3, 4]]])
torch.Size([1, 4])
tensor([[1, 2, 3, 4]])
If we don’t specify the dimension to be removed, squeeze will remove all the dimensions with size 1. For example, a tensor of shape A x B x 1 x C x 1 x D will become A x B x C x D after squeeze.
import torch
x = torch.tensor([[[1, 2, 3, 4]]])
print(x.shape)
print(x)
print()
x = x.squeeze()
print(x.shape)
print(x)
print()
torch.Size([1, 1, 4])
tensor([[[1, 2, 3, 4]]])
torch.Size([4])
tensor([1, 2, 3, 4])
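To make the multi-dimensional case concrete, here is a minimal sketch with made-up sizes (2, 3, 1, 4, 1, 5). It also shows a pitfall worth keeping in mind: calling `squeeze()` without a dimension removes a batch dimension of size 1 too, which may not be what we want. The tensor names are illustrative.

```python
import torch

# Every dimension of size 1 is removed when no dimension is given
x = torch.zeros(2, 3, 1, 4, 1, 5)
y = x.squeeze()
print(y.shape)                   # torch.Size([2, 3, 4, 5])

# A "batch" holding a single sample: shape (batch=1, features=3, 1)
batch = torch.zeros(1, 3, 1)
all_squeezed = batch.squeeze()   # also drops the batch dimension
last_only = batch.squeeze(-1)    # drops only the trailing dimension
print(all_squeezed.shape)        # torch.Size([3])
print(last_only.shape)           # torch.Size([1, 3])
```

Passing the dimension explicitly keeps the result predictable when the batch size happens to be 1.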
When a dimension is specified, only that dimension is squeezed, and only if its size is 1. For example, if the input is of shape A x 1 x B, then squeeze(1) gives an output tensor of shape A x B.
import torch
x = torch.ones(2, 1, 2)
print(x.shape)
print(x)
print()
x = x.squeeze(1)
print(x.shape)
print(x)
print()
torch.Size([2, 1, 2])
tensor([[[1., 1.]],
        [[1., 1.]]])
torch.Size([2, 2])
tensor([[1., 1.],
        [1., 1.]])
If the specified dimension does not have size 1, the input tensor is returned unchanged.
import torch
x = torch.ones(2, 1, 2)
print(x.shape)
print(x)
print()
x = x.squeeze(2)
print(x.shape)
print(x)
print()
torch.Size([2, 1, 2])
tensor([[[1., 1.]],
        [[1., 1.]]])
torch.Size([2, 1, 2])
tensor([[[1., 1.]],
        [[1., 1.]]])
Reshape
Reshape changes the shape of a tensor while keeping the same elements in the same order, so the total number of elements must stay the same. It’s used very often, so it is worth understanding well.
# Reshape demo
import torch
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
print(x.shape)
print(x)
print()
x = x.reshape(4, 2)
print(x.shape)
print(x)
print()
torch.Size([2, 4])
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
torch.Size([4, 2])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
Let’s explore more:
import torch
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
x = x.reshape(2, 2, 2)
print(x.shape)
print(x)
print()
x = x.reshape(2, 2, 2, 1)
print(x.shape)
print(x)
print()
x = x.reshape(8, 1)
print(x.shape)
print(x)
print()
torch.Size([2, 2, 2])
tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])
torch.Size([2, 2, 2, 1])
tensor([[[[1],
          [2]],
         [[3],
          [4]]],
        [[[5],
          [6]],
         [[7],
          [8]]]])
torch.Size([8, 1])
tensor([[1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8]])
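One feature the demos above do not show: a single entry of the target shape may be `-1`, in which case PyTorch infers that size from the number of elements. A minimal sketch (the names `a`, `b`, and `c` are illustrative):

```python
import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])   # 8 elements

a = x.reshape(-1, 2)      # 8 / 2 = 4 rows are inferred
b = x.reshape(2, -1, 2)   # 8 / (2 * 2) = 2 is inferred for the middle
c = x.reshape(-1)         # everything goes into one dimension

print(a.shape)            # torch.Size([4, 2])
print(b.shape)            # torch.Size([2, 2, 2])
print(c.shape)            # torch.Size([8])
```

This is handy when one size, such as the batch size, is not known in advance.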
What would be the output of the following?
import torch
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
x = x.reshape(1, 8)
print(x.shape)
print(x)
print()
How about this:
import torch
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
x = x.reshape(4)
print(x.shape)
print(x)
print()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[26], line 4
1 import torch
3 x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
----> 4 x = x.reshape(4)
5 print(x.shape)
6 print(x)
RuntimeError: shape '[4]' is invalid for input of size 8
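The error arises because reshape must preserve the total number of elements: the tensor holds 2 x 4 = 8 values, while shape (4) has room for only 4. Below is a hedged sketch of checking this up front. `can_reshape` is a hypothetical helper written for illustration, not a PyTorch function, and it does not handle `-1` placeholders.

```python
import math
import torch

def can_reshape(t, shape):
    """Hypothetical helper: True if `shape` holds exactly t.numel() elements."""
    return math.prod(shape) == t.numel()

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

print(can_reshape(x, (4,)))        # False: 4 != 8
print(can_reshape(x, (8,)))        # True
print(can_reshape(x, (2, 2, 2)))   # True

# Alternatively, catch the error PyTorch raises for an invalid shape
try:
    x.reshape(4)
except RuntimeError as err:
    print("reshape failed:", err)
```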
Flatten
Flatten collapses the dimensions of a tensor into a single dimension, keeping the elements in the same order.
# Flatten demo
import torch
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
x = x.flatten()
print(x.shape)
print(x)
print()
torch.Size([8])
tensor([1, 2, 3, 4, 5, 6, 7, 8])
Called without arguments, flatten will always convert a tensor into a 1D tensor, however many dimensions the input has.
# Multi-dimension tensor
import torch
x = torch.ones(2, 2, 2, 2)
print(x.shape)
print(x)
print()
x = x.flatten()
print(x.shape)
print(x)
print()
torch.Size([2, 2, 2, 2])
tensor([[[[1., 1.],
          [1., 1.]],
         [[1., 1.],
          [1., 1.]]],
        [[[1., 1.],
          [1., 1.]],
         [[1., 1.],
          [1., 1.]]]])
torch.Size([16])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
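Flatten also accepts optional `start_dim` and `end_dim` arguments that restrict which dimensions are merged. A minimal sketch, assuming the first dimension is a batch dimension that should be kept (the names `y` and `z` are illustrative):

```python
import torch

x = torch.ones(2, 2, 2, 2)

# Keep dimension 0 and merge everything after it: 2 * 2 * 2 = 8
y = x.flatten(start_dim=1)
print(y.shape)    # torch.Size([2, 8])

# Merge only dimensions 1 and 2, leaving the first and last untouched
z = x.flatten(start_dim=1, end_dim=2)
print(z.shape)    # torch.Size([2, 4, 2])
```

The first form is the usual way to feed the output of convolutional layers into a linear layer without mixing samples of a batch.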
Permute
Permute changes the order of the dimensions of a tensor.
The parameter is the new order of the dimensions. So, if we have a tensor of shape (A, B, C, D), the output tensor will have shape (D, C, B, A) given parameter (3, 2, 1, 0).
# Permute demo
import torch
x = torch.ones(2, 3, 4)
print(x.shape)
print(x)
print()
x = x.permute(2, 0, 1)
print(x.shape)
print(x)
print()
torch.Size([2, 3, 4])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],
        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
torch.Size([4, 2, 3])
tensor([[[1., 1., 1.],
         [1., 1., 1.]],
        [[1., 1., 1.],
         [1., 1., 1.]],
        [[1., 1., 1.],
         [1., 1., 1.]],
        [[1., 1., 1.],
         [1., 1., 1.]]])
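Because the demo uses a tensor of ones, it shows the new shape but not where each element ends up. The sketch below uses `torch.arange` instead (an illustrative choice, not part of the original demo) and contrasts permute with reshape, which gives the same shape here but a different element order:

```python
import torch

x = torch.arange(6).reshape(2, 3)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

p = x.permute(1, 0)    # swaps the axes: p[i, j] == x[j, i]
r = x.reshape(3, 2)    # keeps the row-major element order

print(p)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])
print(r)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

# permute returns a view with rearranged strides, not a copy
print(p.is_contiguous())   # False
```

So permute and reshape are not interchangeable, even when the resulting shapes match.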