
MNIST Digit Prediction using Batch Normalization, Group Normalization, Layer Normalization and L1-L2 Regularizations


Arijit-datascience/CNN_BatchNormalization_Regularization


EVA6-Normalization-Regularization

Welcome! To learn more about implementing Normalization and Regularization in PyTorch, read on.

Objective

  1. Write a single model.py file that includes Group Normalization/Layer Normalization/Batch Normalization and takes an argument to decide which Normalization to include.
  2. Write a single notebook file that uses model.py to run all 3 models above for 20 epochs each.
  3. Create these graphs:
    Graph 1: Test/Validation Loss for all 3 models together.
    Graph 2: Test/Validation Accuracy for all 3 models together.
    Graphs must have proper annotation.
  4. Find 10 misclassified images for each of the 3 models, and show them as a 5x2 image matrix in 3 separately annotated images.

Let's begin!

Let's take a brief look at the 3 Normalizations we used, namely Batch Normalization, Layer Normalization and Group Normalization.

Consider the following setup:

image

We have two layers and a batch size of 4, meaning 4 images in each batch. Under each layer, each of the four 2x2 matrices represents one channel.

image

For Batch Normalization, the mean and variance are calculated per channel, across all the images in the batch, as highlighted by the blocks of the same colour in the image above. Since there are 4 channels, we get 4 means and 4 variances, one pair per channel.
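The per-channel statistics can be checked in PyTorch. This sketch builds a random tensor shaped like the setup above (4 images, 4 channels, 2x2 each, assuming the usual (N, C, H, W) layout), computes the statistics by hand and compares the result with nn.BatchNorm2d in training mode. affine=False is used so that no learnable scale or shift is involved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumed toy layer output matching the setup: 4 images, 4 channels, 2x2 each.
x = torch.randn(4, 4, 2, 2)  # (N, C, H, W)
eps = 1e-5

# Batch Normalization: one mean/variance per channel, pooled over N, H and W.
mean = x.mean(dim=(0, 2, 3), keepdim=True)                  # shape (1, 4, 1, 1)
var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)   # biased variance, as BN uses
manual = (x - mean) / torch.sqrt(var + eps)

# In training mode, nn.BatchNorm2d normalizes with the same batch statistics.
bn = nn.BatchNorm2d(4, affine=False, eps=eps)
bn.train()
reference = bn(x)

print(mean.numel())                                  # 4 means, one per channel
print(torch.allclose(manual, reference, atol=1e-5))
```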

image

For Layer Normalization, the mean and variance are calculated per image, across all the channels of the layer, as highlighted by the red block that spans horizontally across all channels. Here too we get 4 means and 4 variances, but this time because there are 4 images, each normalized over all of its channels.
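The Layer Normalization statistics for the same kind of toy batch can be sketched the same way: pool over (C, H, W) for each image and compare with torch.nn.functional.layer_norm. The data is assumed random and no learnable parameters are used.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 4, 2, 2)  # (N, C, H, W): assumed toy batch of 4 images with 4 channels
eps = 1e-5

# Layer Normalization: one mean/variance per image, pooled over C, H and W.
mean = x.mean(dim=(1, 2, 3), keepdim=True)                  # shape (4, 1, 1, 1)
var = ((x - mean) ** 2).mean(dim=(1, 2, 3), keepdim=True)   # biased variance
manual = (x - mean) / torch.sqrt(var + eps)

# F.layer_norm normalizes over the trailing dimensions given as normalized_shape.
reference = F.layer_norm(x, normalized_shape=x.shape[1:], eps=eps)

print(mean.numel())                                  # 4 means, one per image
print(torch.allclose(manual, reference, atol=1e-5))
```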

image

In Group Normalization, the channels of each image are divided into groups, and a mean and variance are calculated for each group, as highlighted by the dotted rectangles. In our case the channels are grouped in pairs, so each image has 2 groups. With 4 images, we end up with 8 groups in all, and hence 8 means and 8 variances.
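For Group Normalization, the sketch below reshapes the channel axis into 2 groups of 2 channels, pools each group per image, and compares the result with nn.GroupNorm(2, 4, affine=False) on an assumed random toy batch of the same shape. nn.GroupNorm groups consecutive channels, which is what the reshape reproduces.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 4, 2, 2)  # (N, C, H, W)
groups = 2                   # channels grouped in pairs, as in the example
eps = 1e-5

N, C, H, W = x.shape
# Split the channel axis into (groups, channels_per_group) and pool each group per image.
xg = x.view(N, groups, C // groups, H, W)
mean = xg.mean(dim=(2, 3, 4), keepdim=True)                  # shape (4, 2, 1, 1, 1)
var = ((xg - mean) ** 2).mean(dim=(2, 3, 4), keepdim=True)   # biased variance
manual = ((xg - mean) / torch.sqrt(var + eps)).view(N, C, H, W)

gn = nn.GroupNorm(num_groups=groups, num_channels=C, eps=eps, affine=False)
reference = gn(x)

print(mean.numel())                                  # 8: 4 images x 2 groups
print(torch.allclose(manual, reference, atol=1e-5))
```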

If you are interested, you can find a complete implementation of what's explained above in an Excel sheet HERE.

Let's now move on to the implementation.
We used the MNIST dataset for our Normalization experiments.

The PyTorch implementation of our experiment is split across two files:

  1. The models with each of the 3 Normalizations are implemented in model.py.
  2. A Jupyter notebook with the complete end-to-end implementation of the 3 experiments, which calls model.py for the network. Click HERE to view the code.
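One way for model.py to take an argument that selects the normalization is a small factory function used inside each convolution block. This is a minimal sketch under assumptions: the names make_norm and ConvBlock, the "BN"/"GN"/"LN" strings and the default of 2 groups are hypothetical and not taken from the repository.

```python
import torch
import torch.nn as nn


def make_norm(kind: str, num_channels: int, num_groups: int = 2) -> nn.Module:
    """Hypothetical helper: return the normalization layer selected by `kind`."""
    if kind == "BN":
        return nn.BatchNorm2d(num_channels)
    if kind == "GN":
        return nn.GroupNorm(num_groups, num_channels)
    if kind == "LN":
        # Layer Normalization expressed as Group Normalization with a single group.
        return nn.GroupNorm(1, num_channels)
    raise ValueError(f"unknown normalization: {kind!r}")


class ConvBlock(nn.Module):
    """Hypothetical Conv -> Norm -> ReLU block parameterized by the normalization type."""

    def __init__(self, in_ch: int, out_ch: int, norm: str):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.norm = make_norm(norm, out_ch)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.norm(self.conv(x)))


x = torch.randn(2, 1, 28, 28)  # two fake MNIST-sized images
for kind in ("BN", "GN", "LN"):
    block = ConvBlock(1, 8, norm=kind)
    print(kind, tuple(block(x).shape))
```

The conv layer drops its bias because every one of these normalizations subtracts a mean right after it, which would cancel a per-channel bias anyway.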

MNIST Digit Recognition

Number of training samples: 60000
Number of test samples: 10000

Transformations Used

  1. Random Rotations
  2. Color Jitter
  3. Image Normalization

Normalization Techniques

  1. Batch Normalization
  2. Group Normalization
  3. Layer Normalization

Regularization

  1. L1 Regularization
    Regularization factor: 0.0001. Applied only to the Batch Normalization model.
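L1 regularization adds the sum of absolute parameter values, scaled by the regularization factor, to the training loss. Below is a minimal sketch using the stated factor of 0.0001. The helper name l1_penalty and the choice to penalize every trainable parameter (including biases) are assumptions, since the README does not say which parameters were included. The tiny linear model and random batch are only stand-ins.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LAMBDA_L1 = 1e-4  # regularization factor stated in the README


def l1_penalty(model: nn.Module) -> torch.Tensor:
    """Hypothetical helper: sum of absolute values of all trainable parameters."""
    return sum(p.abs().sum() for p in model.parameters() if p.requires_grad)


# Tiny stand-in model and batch, just to show where the penalty enters the loss.
torch.manual_seed(0)
model = nn.Linear(4, 3)
x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))

data_loss = F.cross_entropy(model(x), y)
loss = data_loss + LAMBDA_L1 * l1_penalty(model)  # total loss = data loss + lambda * ||w||_1
loss.backward()
print(data_loss.item(), loss.item())
```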

Observations

  1. Model 1 - Group Normalization
    Train Accuracy: 99.60%
    Test Accuracy: 99.54%

image

  2. Model 2 - Layer Normalization
    Train Accuracy: 99.61%
    Test Accuracy: 99.48%

image

  3. Model 3 - Batch Normalization + L1
    Train Accuracy: 99.46%
    Test Accuracy: 99.47%

image
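The conclusions compare the models by the gap between train and test accuracy. Using only the numbers reported in the Observations, the gaps can be computed as follows (the dictionary layout is just for illustration):

```python
# Reported accuracies (%) from the Observations: (train, test).
results = {
    "Group Norm": (99.60, 99.54),
    "Layer Norm": (99.61, 99.48),
    "Batch Norm + L1": (99.46, 99.47),
}

# Generalization gap = train accuracy - test accuracy (positive means overfitting).
gaps = {name: round(train - test, 2) for name, (train, test) in results.items()}
for name, gap in gaps.items():
    print(f"{name:16s} gap = {gap:+.2f}")
```

This gives 0.06 percentage points for Group Normalization, 0.13 for Layer Normalization and -0.01 for Batch Normalization + L1, where test accuracy is slightly above train.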

Conclusions and notes

  1. The best test accuracy (99.54%) was achieved with Group Normalization. Its train accuracy (99.60%) was essentially tied with Layer Normalization, which was marginally higher at 99.61%.
  2. The smallest difference between train and test accuracy was achieved by Batch Normalization with L1 Regularization (99.46% train vs 99.47% test). The added regularization appears to have helped reduce overfitting, though the effect is minor.
  3. The most overfitted of the 3 models was the one with Layer Normalization (99.61% train vs 99.48% test), although not by much.
  4. Layer Normalization is a special case of Group Normalization in which the number of groups is set to 1. As a result, all the channels of the layer are normalized together.
    Here, we use nn.GroupNorm(1, num_channels) after each Conv2d layer to implement Layer Normalization.
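A quick way to check this point is to compare nn.GroupNorm(1, C) with nn.LayerNorm over (C, H, W). The sketch below uses arbitrary tensor sizes. At initialization (unit scale, zero shift) the normalized outputs match. The learnable parameters differ in shape, though: GroupNorm learns one scale and shift per channel, while LayerNorm learns one per element of (C, H, W). So the two normalize over the same values but are not identical once the affine parameters are trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 8, 6, 6)  # (N, C, H, W), arbitrary example sizes

# Layer Normalization written as Group Normalization with one group, as in the README.
gn_as_ln = nn.GroupNorm(1, 8)
# Standard Layer Normalization over (C, H, W) for comparison.
ln = nn.LayerNorm([8, 6, 6])

# Both start with unit scale and zero shift, so the outputs agree.
print(torch.allclose(gn_as_ln(x), ln(x), atol=1e-5))

# Per-channel affine parameters (GroupNorm) vs per-element ones (LayerNorm).
print(gn_as_ln.weight.shape, ln.weight.shape)
```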

Training and Validation - Loss & Accuracy

image

image
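Graphs 1 and 2 from the objectives overlay one metric for all three models on a single chart. Below is a sketch of such a plotting helper with matplotlib, assuming the per-epoch losses and accuracies are kept in a dict per model. The function name plot_metric and the history layout are hypothetical, and the demo values are placeholders, not the experiment's results.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # off-screen rendering; not needed inside a notebook
import matplotlib.pyplot as plt


def plot_metric(histories, metric, ylabel, path):
    """Overlay one per-epoch metric for several models on a single annotated chart.

    `histories` maps a model name to a dict of per-epoch lists, e.g. {"GN": {"test_loss": [...]}}.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    for name, hist in histories.items():
        values = hist[metric]
        ax.plot(range(1, len(values) + 1), values, marker="o", label=name)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(f"{ylabel} per epoch")
    ax.legend()  # one labelled line per model
    ax.grid(True, alpha=0.3)
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)


# Placeholder numbers only, to exercise the function; these are NOT the experiment's results.
demo = {
    "Group Norm": {"test_loss": [0.30, 0.20, 0.15]},
    "Layer Norm": {"test_loss": [0.32, 0.22, 0.16]},
    "Batch Norm + L1": {"test_loss": [0.28, 0.21, 0.17]},
}
out_path = os.path.join(tempfile.gettempdir(), "test_loss_demo.png")
plot_metric(demo, "test_loss", "Test loss", out_path)
```

Calling it a second time with a "test_acc" key and an accuracy label would produce the second graph.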

Misclassified Images

  1. Group Normalization

image

  2. Layer Normalization

image

  3. Batch Normalization + L1

image
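The misclassified images can be collected by comparing each argmax prediction with its label and stopping after 10 misses, then drawn as a 5x2 image matrix. This sketch assumes a 5-row by 2-column layout and single-channel 28x28 tensors. The helper names find_misclassified and show_grid are hypothetical, and the demo uses an untrained stand-in model on random tensors purely to exercise the code.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # off-screen rendering; not needed inside a notebook
import matplotlib.pyplot as plt
import torch
import torch.nn as nn


@torch.no_grad()
def find_misclassified(model, loader, limit=10, device="cpu"):
    """Return up to `limit` (image, true_label, predicted_label) triples the model gets wrong."""
    model.eval()
    found = []
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        for img, y, p in zip(images, labels, preds):
            if y.item() != p.item():
                found.append((img.cpu(), y.item(), p.item()))
                if len(found) == limit:
                    return found
    return found


def show_grid(samples, title, path):
    """Draw the samples as a 5x2 grid (5 rows, 2 columns), each titled with true/predicted label."""
    fig, axes = plt.subplots(5, 2, figsize=(4, 10))
    for ax in axes.flat:
        ax.axis("off")  # hide axes, including any empty cells
    for ax, (img, y, p) in zip(axes.flat, samples):
        ax.imshow(img.squeeze(0).numpy(), cmap="gray")
        ax.set_title(f"true: {y}, pred: {p}", fontsize=8)
    fig.suptitle(title)
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)


# Demo with an untrained stand-in model and random "images" (not MNIST, not the README's models).
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loader = [(torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))) for _ in range(4)]
found = find_misclassified(model, loader)
grid_path = os.path.join(tempfile.gettempdir(), "misclassified_demo.png")
show_grid(found, "Misclassified (demo)", grid_path)
```

Running it once per trained model, with a different title each time, gives the three separately annotated images.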

Collaborators

Abhiram Gurijala
Arijit Ganguly
Rohin Sequeira
