PyTorch Dimension Modification

A common source of errors in PyTorch is a tensor with the wrong shape. To avoid this, we need to know how to modify a tensor's dimensions so that they match what the model expects.

Unsqueeze

Unsqueeze adds a dimension of size 1 to a tensor at the position we specify.

# Unsqueeze demo
import torch

# Unsqueeze
x = torch.tensor([1, 2, 3, 4])
print(x.shape)
print(x)
print()

x = x.unsqueeze(0)
print(x.shape)
print(x)
print()

x = x.unsqueeze(1)
print(x.shape)
print(x)
print()

Output:

torch.Size([4])
tensor([1, 2, 3, 4])

torch.Size([1, 4])
tensor([[1, 2, 3, 4]])

torch.Size([1, 1, 4])
tensor([[[1, 2, 3, 4]]])

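The dim argument is required: it is the index at which the new size-1 dimension is inserted. With unsqueeze(0), the dimension is added at the beginning. For example, if the input tensor has shape (3, 4), the output tensor will have shape (1, 3, 4) after unsqueeze(0).

We can also insert the new dimension at another position, such as the end: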

import torch

# Unsqueeze
x = torch.tensor([1, 2, 3, 4])
print(x.shape)
print(x)
print()

x = x.unsqueeze(0)
print(x.shape)
print(x)
print()

x = x.unsqueeze(2)
print(x.shape)
print(x)
print()

Output:

torch.Size([4])
tensor([1, 2, 3, 4])

torch.Size([1, 4])
tensor([[1, 2, 3, 4]])

torch.Size([1, 4, 1])
tensor([[[1],
         [2],
         [3],
         [4]]])
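As a small supplementary sketch (not part of the original demos), the snippet below shows two other common ways to write the same operation: a negative dim, which counts from the end, and indexing with None. The variable names col and row are chosen here for illustration only.

```python
import torch

x = torch.tensor([1, 2, 3, 4])        # shape (4,)

# A negative dim counts from the end; -1 appends a trailing dimension.
col = x.unsqueeze(-1)
print(col.shape)                       # torch.Size([4, 1])

# Indexing with None inserts a size-1 dimension at that position.
row = x[None, :]
print(row.shape)                       # torch.Size([1, 4])

# Both spellings match the explicit unsqueeze calls.
print(torch.equal(col, x.unsqueeze(1)))   # True
print(torch.equal(row, x.unsqueeze(0)))   # True
```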

Squeeze

Squeeze removes a dimension of size 1 from a tensor. It’s like squeezing a bottle of water: the bottle becomes smaller.

# Squeeze demo
import torch

x = torch.tensor([[[1, 2, 3, 4]]])

print(x.shape)
print(x)
print()

x = x.squeeze(0)
print(x.shape)
print(x)
print()

Output:

torch.Size([1, 1, 4])
tensor([[[1, 2, 3, 4]]])

torch.Size([1, 4])
tensor([[1, 2, 3, 4]])

If we don’t specify the dimension to be removed, squeeze will remove all the dimensions with size 1.

For example, a tensor of shape A x B x 1 x C x 1 x D becomes A x B x C x D after squeeze().

import torch

x = torch.tensor([[[1, 2, 3, 4]]])

print(x.shape)
print(x)
print()

x = x.squeeze()
print(x.shape)
print(x)
print()

Output:

torch.Size([1, 1, 4])
tensor([[[1, 2, 3, 4]]])

torch.Size([4])
tensor([1, 2, 3, 4])
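To make the shape rule concrete, here is a minimal sketch with arbitrary sizes (2, 3, 4, 5) standing in for A, B, C, D. It also illustrates a caveat worth keeping in mind: squeeze() with no argument removes every size-1 dimension, including a batch dimension that happens to hold a single sample, so passing an explicit dim is often the safer choice.

```python
import torch

# Sizes 2, 3, 4, 5 are arbitrary stand-ins for A, B, C, D.
x = torch.zeros(2, 3, 1, 4, 1, 5)
y = x.squeeze()
print(y.shape)                        # torch.Size([2, 3, 4, 5])

# A batch holding one sample: shape (batch=1, features=3, 1).
batch = torch.zeros(1, 3, 1)
squeezed_all = batch.squeeze()        # batch dimension is lost as well
squeezed_last = batch.squeeze(-1)     # only the last dimension is removed
print(squeezed_all.shape)             # torch.Size([3])
print(squeezed_last.shape)            # torch.Size([1, 3])
```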

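When a dimension is given, only that dimension is removed, and only if its size is 1. For example, squeezing dimension 1 of an input of shape A x 1 x B x C x 1 x D gives an output tensor of shape A x B x C x 1 x D.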

import torch

x = torch.ones(2, 1, 2)

print(x.shape)
print(x)
print()

x = x.squeeze(1)

print(x.shape)
print(x)
print()

Output:

torch.Size([2, 1, 2])
tensor([[[1., 1.]],

        [[1., 1.]]])

torch.Size([2, 2])
tensor([[1., 1.],
        [1., 1.]])

If the specified dimension does not have size 1, the input tensor is returned unchanged.

import torch

x = torch.ones(2, 1, 2)

print(x.shape)
print(x)
print()

x = x.squeeze(2)

print(x.shape)
print(x)
print()

Output:

torch.Size([2, 1, 2])
tensor([[[1., 1.]],

        [[1., 1.]]])

torch.Size([2, 1, 2])
tensor([[[1., 1.]],

        [[1., 1.]]])
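One way to tie the two operations together: squeeze(dim) undoes unsqueeze(dim) at the same position. The short sketch below checks this round trip; the names y and z are illustrative only.

```python
import torch

x = torch.ones(2, 2)

y = x.unsqueeze(1)         # insert a size-1 dimension at index 1
z = y.squeeze(1)           # remove it again

print(y.shape)             # torch.Size([2, 1, 2])
print(z.shape)             # torch.Size([2, 2])
print(torch.equal(x, z))   # True
```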

Reshape

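Reshape changes the shape of a tensor while keeping its elements and their order. The new shape must hold exactly the same number of elements as the original. It’s commonly used, so understanding it is very important.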

# Reshape demo
import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
print(x.shape)
print(x)
print()

x = x.reshape(4, 2)
print(x.shape)
print(x)
print()

Output:

torch.Size([2, 4])
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

torch.Size([4, 2])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

Let’s explore more:

import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

x = x.reshape(2, 2, 2)
print(x.shape)
print(x)
print()

x = x.reshape(2, 2, 2, 1)
print(x.shape)
print(x)
print()

x = x.reshape(8, 1)
print(x.shape)
print(x)
print()

Output:

torch.Size([2, 2, 2])
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])

torch.Size([2, 2, 2, 1])
tensor([[[[1],
          [2]],

         [[3],
          [4]]],


        [[[5],
          [6]],

         [[7],
          [8]]]])

torch.Size([8, 1])
tensor([[1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8]])
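The demos above spell out every size by hand. As a supplementary sketch, reshape also accepts -1 for at most one dimension, in which case PyTorch infers that size from the total number of elements. The variable names below are illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])   # 8 elements

# -1 asks PyTorch to infer the size of that dimension.
a = x.reshape(-1, 2)      # 8 / 2 = 4 rows
b = x.reshape(2, -1, 2)   # 8 / (2 * 2) = 2 in the middle
c = x.reshape(-1)         # all 8 elements in one dimension

print(a.shape)            # torch.Size([4, 2])
print(b.shape)            # torch.Size([2, 2, 2])
print(c.shape)            # torch.Size([8])
```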

What would be the output of the following?

import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
x = x.reshape(1, 8)
print(x.shape)
print(x)
print()

How about this one? The tensor has 8 elements, but the requested shape holds only 4, so reshape raises an error:

import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
x = x.reshape(4)
print(x.shape)
print(x)
print()

Output:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[26], line 4
      1 import torch
      3 x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
----> 4 x = x.reshape(4)
      5 print(x.shape)
      6 print(x)

RuntimeError: shape '[4]' is invalid for input of size 8
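The error comes down to an element-count mismatch. A minimal sketch of how one might check this before reshaping is shown below; can_reshape is a hypothetical helper written for this example, not a PyTorch function, and it assumes the target shape contains no -1 entries.

```python
import math
import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

def can_reshape(t, shape):
    """Hypothetical helper: True if `shape` (no -1 entries) holds t.numel() elements."""
    return math.prod(shape) == t.numel()

print(x.numel())                # 8
print(can_reshape(x, (4,)))     # False: 4 != 8
print(can_reshape(x, (1, 8)))   # True:  1 * 8 == 8

# Without the check, the mismatch surfaces as a RuntimeError.
failed = False
try:
    x.reshape(4)
except RuntimeError as err:
    failed = True
    print("reshape failed:", err)
```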

Flatten

Flatten collapses the dimensions of a tensor into a single dimension.

# Flatten demo

import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
x = x.flatten()
print(x.shape)
print(x)
print()

Output:

torch.Size([8])
tensor([1, 2, 3, 4, 5, 6, 7, 8])

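Called without arguments, flatten always returns a 1D tensor, however many dimensions the input has. The optional start_dim and end_dim arguments restrict the flattening to a range of dimensions.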

# Multi-dimension tensor

import torch
x = torch.ones(2, 2, 2, 2)
print(x.shape)
print(x)
print()

x = x.flatten()
print(x.shape)
print(x)
print()

Output:

torch.Size([2, 2, 2, 2])
tensor([[[[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]]],


        [[[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]]]])

torch.Size([16])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
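A common case where a fully 1D result is not wanted is a batch of samples, where each sample should be flattened but the batch dimension kept. The sketch below uses flatten's start_dim and end_dim arguments for this; treating the first dimension as the batch is an assumption made for the example.

```python
import torch

# Assume the first dimension is the batch: 2 samples of shape (2, 2, 2).
batch = torch.ones(2, 2, 2, 2)

flat_all = batch.flatten()                        # everything into one dimension
per_sample = batch.flatten(start_dim=1)           # keep dim 0, flatten the rest
middle = batch.flatten(start_dim=1, end_dim=2)    # flatten only dims 1 and 2

print(flat_all.shape)     # torch.Size([16])
print(per_sample.shape)   # torch.Size([2, 8])
print(middle.shape)       # torch.Size([2, 4, 2])
```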

Permute

Permute is used to change the order of the dimensions of a tensor.

The arguments give the new order of the dimensions: dimension i of the output is the input dimension named by the i-th argument.

So, if we have a tensor of shape (A, B, C, D), the output tensor will have shape (D, C, B, A) given parameter (3, 2, 1, 0).

# Permute demo

import torch

x = torch.ones(2, 3, 4)
print(x.shape)
print(x)
print()

x = x.permute(2, 0, 1)
print(x.shape)
print(x)
print()

Output:

torch.Size([2, 3, 4])
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

torch.Size([4, 2, 3])
tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]])
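With a tensor of ones it is hard to see where each element ends up, so here is a supplementary sketch with distinct values. It suggests that permute and reshape are not interchangeable even when they produce the same shape: permute moves the axes, while reshape keeps the elements in their original row-major order. It also shows that the permuted result is generally not contiguous in memory.

```python
import torch

x = torch.arange(6).reshape(2, 3)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

p = x.permute(1, 0)    # swap the two axes (a transpose)
r = x.reshape(3, 2)    # same shape, but elements stay in row-major order

print(p)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])
print(r)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

print(torch.equal(p, r))     # False
print(p.is_contiguous())     # False

# The (A, B, C, D) -> (D, C, B, A) case, with arbitrary sizes 2, 3, 4, 5.
q = torch.ones(2, 3, 4, 5).permute(3, 2, 1, 0)
print(q.shape)               # torch.Size([5, 4, 3, 2])
```

If a later operation needs contiguous memory, calling p.contiguous() returns a contiguous copy.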