What is Data Augmentation in a CNN? Python Examples

Algorithms can use machine learning to identify and classify different objects in images. This evolving technology includes using Data Augmentation to produce better-performing models. Machine learning models need to identify an object in any condition, even if it is rotated, zoomed in, or captured in a grainy image. Researchers needed an artificial way to add training data with realistic modifications.

Data augmentation is the addition of new data artificially derived from existing training data. Techniques include resizing, flipping, rotating, cropping, padding, etc. It helps address issues like overfitting and data scarcity, and it makes the model more robust and better performing.

Data Augmentation provides many possibilities to alter the original image and can be useful to add enough data for larger models. It is important to learn the techniques of Data Augmentation and its advantages and disadvantages. In this post, I’ll cover all the details you need and show you a Python example using PyTorch.

Data Augmentation in a CNN

Convolutional Neural Networks (CNNs) can do amazing things if there is sufficient data. However, knowing how much training data is needed to learn all of the required features is a difficult question. If the user does not have enough, the network can overfit on the training data. Real-world images vary in size, pose, zoom, lighting, noise, etc.

To make the network robust to these commonly encountered factors, Data Augmentation is used. When input images are rotated to different angles, flipped along different axes, or translated and cropped, the network encounters these variations during training.

As more parameters are added to a CNN, the model needs more training examples. Deeper networks can perform better, but the tradeoff is a greater need for training data and longer training times. [1]

Data Augmentation Technique | Data Augmentation Factor
Flipping | 2-4x (in each direction)
Salt and Pepper Noise Addition | At least 2x (depends on the implementation)
A table outlining the factor by which different methods multiply the existing training data.

For this reason, it is also convenient to not have to hunt for or create more images that are suitable for an experiment. Data Augmentation can reduce the cost and effort of increasing the set of available training samples.

Data Augmentation Techniques

Some libraries implement Data Augmentation by actually copying the training images and saving these copies as part of the total. This produces new training examples to feed to the machine learning model. Other libraries simply define a set of transforms to perform on the input training data. These transforms are applied randomly each time an image is loaded, so the space of inputs the optimizer searches grows. This has the advantage of not requiring extra disk space to augment the training data. [2]
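As a rough illustration of the copying approach, here is a small NumPy sketch. The helper `augment_offline` is hypothetical, not part of any library: it saves every flipped variant of every image, so storage grows with the augmentation factor, which is exactly the cost the transform-based approach avoids.

```python
import numpy as np


def augment_offline(images):
    """Return the originals plus three flipped copies of every image.

    `images` has shape (N, H, W). The result has shape (4 * N, H, W), so the
    stored data set is four times larger than the original.
    """
    copies = [
        images,                 # originals
        images[:, :, ::-1],     # horizontal flip (mirror left-right)
        images[:, ::-1, :],     # vertical flip (mirror top-bottom)
        images[:, ::-1, ::-1],  # both flips at once
    ]
    return np.concatenate(copies, axis=0)


# Random stand-in data: 100 grayscale 28x28 "images"
rng = np.random.default_rng(0)
images = rng.random((100, 28, 28))

offline = augment_offline(images)
print(images.shape, '->', offline.shape)  # (100, 28, 28) -> (400, 28, 28)
```

Every extra variant has to be written out in full, so a 4x factor means 4x the memory or disk space.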

Image Data Augmentation is now a well-known and widely used method with CNNs and involves techniques such as:

  • Flips
  • Rotation (at 90 degrees and finer angles)
  • Translation
  • Scaling
  • Salt and Pepper noise addition

Data Augmentation has even been used in applications like sound recognition. [4] In the next sections, I’ll cover these Data Augmentation methods in detail.


Flipping

Flipping images keeps the optimizer from becoming biased toward particular features appearing on only one side of an image. To do this augmentation, the original training image is flipped vertically or horizontally over one axis of the image. As a result, the features face different directions across the training samples.

Stella the Puppy sitting on a car seat
Stella the Puppy Flipped over the vertical axis.

Flipping is similar to rotation, but it produces mirror images. A particular feature, such as a person's head, ends up at the top, left, right, or bottom of the image. [3]
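The mirror-image point can be checked on a tiny example. This NumPy sketch uses a made-up 3x3 L-shaped `feature` array as an asymmetric "image", flips it, and compares the flip with all four 90-degree rotations. None of them match, because a flip changes the shape's handedness while a rotation does not.

```python
import numpy as np

# A small asymmetric "image": an L-shaped feature in a 3x3 grid
feature = np.array([
    [1, 0, 0],
    [1, 0, 0],
    [1, 1, 0],
])

flipped = np.fliplr(feature)                           # mirror over the vertical axis
rotations = [np.rot90(feature, k) for k in range(4)]   # 0, 90, 180, 270 degrees

# No rotation of the L reproduces its mirror image (a J shape)
print(any(np.array_equal(flipped, r) for r in rotations))  # False
```

In torchvision, random flips of this kind are provided by `transforms.RandomHorizontalFlip` and `transforms.RandomVerticalFlip`.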


Rotation

Rotation is commonly performed at 90-degree angles, but it can also be done at smaller angles if more data is needed. For rotation, the background color is commonly fixed so that the areas uncovered by rotating the image blend in. Otherwise, the model may learn the background change as a distinct feature. This works best when the background is the same in all rotated images. [1]

Stella the Puppy sitting on a car seat
Stella the Puppy rotated 90 degrees.

Specific features move when an image is rotated. For example, a person's head might be rotated by 10, 22.7, or -8 degrees. Unlike flipping, rotation does not mirror the feature, so its left-right structure is preserved. This helps the model avoid treating the angle as a distinct feature of the person.
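Rotation by an arbitrary angle with a fixed background color can be sketched in a few lines of NumPy. `rotate_with_fill` is a hypothetical helper written for this illustration: it uses nearest-neighbor inverse mapping and sets every pixel that has no source in the original frame to the constant `fill` value, which stands in for the fixed background color described above.

```python
import numpy as np


def rotate_with_fill(image, degrees, fill=0.0):
    """Rotate a 2-D image counterclockwise about its center by `degrees`
    using nearest-neighbor sampling. Output pixels whose source lies outside
    the original frame are set to the fixed background value `fill`."""
    height, width = image.shape
    theta = np.deg2rad(degrees)
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0

    # Coordinates of every output pixel, relative to the image center
    rows, cols = np.indices((height, width))
    y, x = rows - cy, cols - cx

    # Inverse mapping: rotate output coordinates back into the source image
    src_c = np.rint(np.cos(theta) * x - np.sin(theta) * y + cx).astype(int)
    src_r = np.rint(np.sin(theta) * x + np.cos(theta) * y + cy).astype(int)

    inside = (src_r >= 0) & (src_r < height) & (src_c >= 0) & (src_c < width)
    out = np.full_like(image, fill, dtype=float)
    out[inside] = image[src_r[inside], src_c[inside]]
    return out


demo = np.arange(9, dtype=float).reshape(3, 3)
print(np.array_equal(rotate_with_fill(demo, 90), np.rot90(demo)))  # True
print(rotate_with_fill(np.ones((5, 5)), 30, fill=0.0))  # corners become background
```

torchvision's `transforms.RandomRotation` exposes a similar `fill` argument, and a library implementation will usually offer smoother interpolation than nearest-neighbor.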


Translation

Translation of an image means shifting the main object in the image in various directions. For example, take an image of a person in the center of the frame, with their whole body visible, as the base image. Shifting the person toward one corner so that their legs are cut off at the bottom produces one translated image.

Stella the Puppy sitting on a car seat
Stella the Puppy translated and cropped so she’s only partly visible.

Translation helps the network recognize the object in any part of the image, not just in the center or at one side. By making a variety of these translations, the training data can be augmented so that the network recognizes translated objects. [2]
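A translation that crops the object instead of wrapping it around can be sketched as follows. `translate` and `_window` are hypothetical helpers for this example. Note that `np.roll` would not fit here, because it wraps pixels pushed off one edge back onto the opposite edge instead of cropping them.

```python
import numpy as np


def _window(shift, size):
    """Destination and source slices for a 1-D shift of `shift` pixels."""
    shift = max(-size, min(size, shift))  # larger shifts leave nothing visible
    dst = slice(max(shift, 0), size + min(shift, 0))
    src = slice(max(-shift, 0), size - max(shift, 0))
    return dst, src


def translate(image, shift_rows, shift_cols, fill=0.0):
    """Shift a 2-D image down by `shift_rows` and right by `shift_cols`
    (negative values shift up/left). Content pushed past the border is
    cropped, and the uncovered area is filled with `fill`."""
    height, width = image.shape
    dst_r, src_r = _window(shift_rows, height)
    dst_c, src_c = _window(shift_cols, width)

    out = np.full((height, width), fill, dtype=float)
    out[dst_r, dst_c] = image[src_r, src_c]
    return out


img = np.arange(16, dtype=float).reshape(4, 4)
print(translate(img, 1, 2))  # top row and left two columns become background
```

In torchvision, random shifts of this kind come from the `translate` argument of `transforms.RandomAffine`, which the PyTorch example later in this post uses.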


Scaling

Scaling adds more diversity to the training data of a machine learning model. Scaling the image helps the network recognize the object regardless of how zoomed in or out the image is. Sometimes the object is tiny in the center. Other times, the object fills the frame and is even partly cropped. [3]

Stella the Puppy sitting on a car seat
Stella the Puppy scaled up to be even larger than she is in real life.
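Zooming in or out while keeping the frame size fixed can be sketched the same way. `scale_about_center` is a hypothetical NumPy helper using nearest-neighbor sampling: a factor above 1 enlarges the object and crops the edges, while a factor below 1 shrinks it and pads the border with a fixed `fill` value.

```python
import numpy as np


def scale_about_center(image, factor, fill=0.0):
    """Zoom a 2-D image by `factor` about its center, keeping the original
    frame size. factor > 1 zooms in (edges are cropped); factor < 1 zooms
    out (the uncovered border is set to `fill`)."""
    height, width = image.shape
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0

    # Inverse mapping: each output pixel samples the source at distance / factor
    rows, cols = np.indices((height, width))
    src_r = np.rint((rows - cy) / factor + cy).astype(int)
    src_c = np.rint((cols - cx) / factor + cx).astype(int)

    inside = (src_r >= 0) & (src_r < height) & (src_c >= 0) & (src_c < width)
    out = np.full((height, width), fill, dtype=float)
    out[inside] = image[src_r[inside], src_c[inside]]
    return out


img = np.arange(25, dtype=float).reshape(5, 5)
print(scale_about_center(img, 2.0))  # zoomed in: only the central pixels remain
print(scale_about_center(img, 0.5))  # zoomed out: border filled with 0.0
```

torchvision covers this case with the `scale` argument of `transforms.RandomAffine`.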

Salt and Pepper Noise Addition

Salt and pepper noise addition is the addition of random black and white dots (resembling salt and pepper) to the image. This simulates dust and imperfections in real photos, so the model can still recognize the image even if the photographer's lens has dust or spots on it. The training data set is augmented to train the model with more realistic images.

Stella the Puppy sitting on a car seat
Stella the Puppy with Salt and Pepper noise added to the image
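The noise itself is simple to generate. This NumPy sketch, `salt_and_pepper`, is a hypothetical helper (not a library function) that assumes a grayscale image with values in [0, 1]: it picks a fraction `amount` of pixels at random and sets some to white and the rest to black.

```python
import numpy as np


def salt_and_pepper(image, amount=0.05, salt_vs_pepper=0.5, rng=None):
    """Return a copy of a [0, 1] grayscale image in which a fraction `amount`
    of pixels is replaced by white (salt, 1.0) or black (pepper, 0.0) dots.
    `salt_vs_pepper` is the share of corrupted pixels that become salt."""
    rng = np.random.default_rng() if rng is None else rng
    noisy = np.array(image, dtype=float).reshape(-1)  # flattened copy

    # Pick exactly round(amount * size) distinct pixels to corrupt
    num_noisy = int(round(amount * image.size))
    flat_idx = rng.choice(image.size, size=num_noisy, replace=False)
    num_salt = int(round(salt_vs_pepper * num_noisy))

    noisy[flat_idx[:num_salt]] = 1.0   # salt: white dots
    noisy[flat_idx[num_salt:]] = 0.0   # pepper: black dots
    return noisy.reshape(image.shape)


gray = np.full((10, 10), 0.5)
noisy = salt_and_pepper(gray, amount=0.1, rng=np.random.default_rng(0))
print((noisy == 1.0).sum(), (noisy == 0.0).sum())  # 5 5
```

scikit-image's `random_noise` function offers a library version through `mode='s&p'`, with similar `amount` and `salt_vs_pepper` parameters.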

Benefits of Data Augmentation in a CNN

There are many benefits of using Data Augmentation:

  • Predictions can become more accurate because Data Augmentation helps the model generalize to samples it has never seen before.
  • The model has enough data to train all of its parameters. This can be essential in applications where data collection is difficult.
  • It helps prevent overfitting by adding more variety to the data.
  • It can save time in areas where collecting more data is time-consuming.
  • It can reduce costs when collecting a variety of data is expensive.

Data Augmentation can be used with other types of Neural Networks and not just CNNs! Want to learn more about Bayesian Neural Networks? Check out my complete tutorial about Bayesian Neural Networks in Python!

Drawbacks of Data Augmentation

Data Augmentation is not useful when the variety required by the application cannot be artificially generated. For example, suppose one were training a bird recognition model and the training data contained only red birds. The training data could be augmented by generating pictures in which the bird's color is varied.

However, the artificial augmentation method may not capture the realistic color details of birds when there is not enough variety of data to start with. For example, the augmentation method might simply swap red for blue, green, and so on. Real non-red birds may have more complex color variations, and the model may fail to recognize them. Having sufficient data is still important if one wants Data Augmentation to work properly.

Deep artificial neural networks require a large corpus of training data in order to effectively learn, where collection of such training data is often expensive and laborious. Data augmentation overcomes this issue by artificially inflating the training set with label preserving transformations.

Taylor and Nitschke [1]

Moreover, Data Augmentation that is not done right can cause underfitting. The number of training epochs should be increased to account for the additional training data. If the optimizer does not see enough samples, the network may end up in a sub-optimal configuration.
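To make the extra training cost concrete, here is a back-of-the-envelope sketch. The helper `optimizer_steps` is hypothetical and assumes the augmented data set is a fixed multiple of the original. With a transform-based library, the same extra exposure comes from running more epochs instead.

```python
import math


def optimizer_steps(num_images, batch_size, epochs, augmentation_factor=1):
    """Total number of optimizer steps (batches) for a whole training run.

    `augmentation_factor` is how many times larger the augmented data set is
    than the original one (1 means no augmentation).
    """
    steps_per_epoch = math.ceil(num_images * augmentation_factor / batch_size)
    return steps_per_epoch * epochs


# Fashion MNIST has 60,000 training images; batch size 64 and 5 epochs
# match the PyTorch demo later in this post
print(optimizer_steps(60_000, 64, 5))                         # 4690
print(optimizer_steps(60_000, 64, 5, augmentation_factor=4))  # 18750
```

Quadrupling the data roughly quadruples the number of optimizer steps; keeping the old step budget means each augmented variant is seen far less often.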

Data Augmentation also does not address biases present in the existing data set. Using the same bird example, it would be difficult to create an artificial augmentation method that generates different species of birds if the training data contains only eagles.

One famous example of network designers using Data Augmentation is when researchers created the VGG network. Want to learn more about this influential CNN design? Check out my post covering why VGG is so commonly used!

Example Data Augmentation Using PyTorch

In PyTorch, the torchvision library implements Data Augmentation using defined transforms. The main advantage of this method is that it allows the user to increase the space of inputs searched by the optimizer during training without increasing the amount of data stored on disk. For very large amounts of augmentation, the space savings are substantial.

These transforms are attached to the data set and applied randomly each time a sample is loaded, so every new training epoch produces a new variant of each input for the optimizer to search. Once I learned how this was accomplished, I thought it was a clever way to get the benefits of Data Augmentation without some of the drawbacks. Let’s implement a PyTorch example with some built-in transforms and also define a custom transform.
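The mechanism can be sketched without PyTorch at all. `AugmentedDataset` below is a hypothetical class, not torchvision code; it only mimics the `__len__`/`__getitem__` protocol that PyTorch map-style datasets follow, and re-applies a random flip every time an item is read. Reading the same index in three simulated epochs can return different variants, while the stored data never grows.

```python
import numpy as np


class AugmentedDataset:
    """Hypothetical map-style dataset: stores only the original images and
    applies a random transform every time an item is requested."""

    def __init__(self, images, labels, transform):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # The length never grows: augmentation costs no extra storage
        return len(self.images)

    def __getitem__(self, index):
        # A fresh random transform is drawn on every access
        return self.transform(self.images[index]), self.labels[index]


rng = np.random.default_rng(1)


def random_flip(image):
    """Mirror the image left-right with probability 0.5."""
    return image[:, ::-1] if rng.random() < 0.5 else image


images = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)
labels = np.array([0, 1])
dataset = AugmentedDataset(images, labels, random_flip)

# Each simulated epoch re-reads index 0, but the flip is drawn again
for epoch in range(3):
    image, label = dataset[0]
    print(epoch, label, image[0])
```

In the real demo, torchvision's datasets play this role through their `transform` argument, and the DataLoader calls `__getitem__` for every sample in every epoch.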

First, import the packages the demo needs. These allow us to load the Fashion MNIST Dataset using the PyTorch DataLoader. Fashion MNIST is a common dataset for demonstrating Convolutional Neural Networks. The example uses common data augmentation transforms and shows how to create a custom transform that performs Salt and Pepper augmentation.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

from skimage.util import random_noise

Next, the demo code automatically selects the GPU if one is available and falls back to the CPU otherwise.

# Get cpu or gpu device for training.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(DEVICE))

The next step is to define the custom NeuralNetwork object that will be the network we are building and testing. The example shows how the network would be modified if you are following this guide for a Linear network instead of a CNN. The example also shows where Pooling and Dropout should be added if you need those features in your own code.

Want to learn more about Dropout in CNNs? Check out my post about whether you should always use dropout!

class NeuralNetwork(nn.Module):
    """Define a demo neural network"""

    def __init__(self):
        super().__init__()

        # If using a Linear network
        # self.flatten = nn.Flatten()
        # self.linear_relu_stack = nn.Sequential(
        #     nn.Linear(28 * 28, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, 10)
        # )

        # If using a CNN: two conv -> ReLU -> pool stages reduce the
        # 1x28x28 input to 64x7x7
        self.conv_relu_stack = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # nn.Dropout2d(p=0.5),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # nn.Dropout2d(p=0.5)
        )
        self.linear = nn.Linear(7 * 7 * 64, 10, bias=True)

    def forward(self, sample):
        """Computes the outputs of the network from the input sample

        Parameters
        ----------
        sample : tensor
            The input sample

        Returns
        -------
        out : tensor
            The output of the network
        """
        # For a Linear network
        # sample = self.flatten(sample)
        # out = self.linear_relu_stack(sample)

        # For a CNN
        out = self.conv_relu_stack(sample)
        out = out.view(out.size(0), -1)  # flatten to [N, 7 * 7 * 64]
        out = self.linear(out)
        return out

PyTorch Data Augmentation is carried out by defining transforms on the dataset that the DataLoader reads from, which expands the space of inputs for the optimizer to search and test. The custom method used in this example is Salt and Pepper noise.

The transform is defined as a custom class. In the __init__ method, any parameters the transform needs are passed into the object and saved. In the __call__ method, the image is converted to a tensor, the noise is added, and the result is converted back to a PIL image so the next transform in the pipeline can process it.

class SaltPepperTransform:
    """Define a custom PyTorch transform to implement
    Salt and Pepper Data Augmentation
    """

    def __init__(self, amount):
        """Pass custom parameters to the transform in init

        Parameters
        ----------
        amount : float
            The amount of salt and pepper noise to add to the image sample
        """
        self.amount = amount

        # conversion transforms we will use
        self.to_tensor = transforms.ToTensor()
        self.to_pil = transforms.ToPILImage()

    def __call__(self, sample):
        """Transform the sample when called

        Parameters
        ----------
        sample : PIL.Image
            The image to augment with noise

        Returns
        -------
        noise_img : PIL.Image
            The image with noise added
        """
        # ToTensor scales pixels to [0, 1]; mode='s&p' adds both white
        # (salt) and black (pepper) pixels
        noise = random_noise(self.to_tensor(sample).numpy(),
                             mode='s&p', amount=self.amount)
        noise_img = torch.tensor(noise, dtype=torch.float32)

        return self.to_pil(noise_img)

With PyTorch, torchvision datasets load the training and test (validation) sets, and DataLoaders serve them in batches. For common sets like Fashion MNIST, the library can automatically download the sets if they are not already in the specified folder. If the set is there, the local copy is used.

Any Data Augmentation transforms are specified when the dataset is created. Here all the techniques from earlier in the post are used at once: flips, rotations, scaling and translation (both through RandomAffine), and our custom Salt and Pepper transform. Finally, the image is transformed into a Tensor.

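def get_dataloaders():
    """Gets the DataLoaders that will be used in the demo

    Returns
    -------
    train_dataloader : torch.utils.data.DataLoader
        The DataLoader containing the training data
    test_dataloader : torch.utils.data.DataLoader
        The DataLoader containing the test data
    """
    # Download training data from open datasets. The 'data' root folder
    # and the 0.05 noise amount are assumed values for this demo.
    training_data = datasets.FashionMNIST(
        root='data',
        train=True,
        download=True,
        # transform=ToTensor(),
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            # Rotation up to +/-90 degrees, translation up to 30%,
            # and scaling between 1x and 2x
            transforms.RandomAffine(90, (0.3, 0.3), (1.0, 2.0)),
            SaltPepperTransform(0.05),
            transforms.ToTensor(),
        ]),
    )

    # Download test data from open datasets (no augmentation)
    test_data = datasets.FashionMNIST(
        root='data',
        train=False,
        download=True,
        transform=ToTensor(),
    )

    train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE)
    test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE)

    first_item = test_dataloader.dataset[0]
    test_sample = first_item[0]
    test_score = first_item[1]
    print('Shape of Test Samples: [C, H, W]: {}'.format(test_sample.shape))
    print('Test Score y: {} type: {}'.format(test_score, type(test_score)))

    return train_dataloader, test_dataloader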

The next crucial function is our training method. This function uses Stochastic Gradient Descent as the optimizer with a learning rate of 1e-3. The function sets the network into training mode, which enables features that are only used in training, like Dropout. During training, the function computes the error (loss) between the predicted output and the label of each training sample.

Afterward, Backpropagation is performed and the network’s weights are updated by the optimizer. The function prints progress updates to the user periodically.

def train_network(dataloader, training_model, loss_function):
    """Trains the demo network

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        The DataLoader with the training data
    training_model : nn.Module
        The network being trained
    loss_function : CrossEntropyLoss
        The loss function used in training
    """
    # Define the optimizer that will train the network model
    optimizer = torch.optim.SGD(training_model.parameters(), lr=1e-3)

    # Set the network in training mode that enables some training-only
    # layers like dropout
    training_model.train()

    # Perform the training
    total_training_size = len(dataloader.dataset)
    for batch, (sample, score) in enumerate(dataloader):

        # Copy the sample to the device doing the calculations
        sample, score = sample.to(DEVICE), score.to(DEVICE)

        # Compute prediction error between the model output and the
        # training score
        prediction = training_model(sample)
        loss = loss_function(prediction, score)

        # Perform backpropagation and update the network weights to minimize
        # the error
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Output the score every 100th training iteration so output
        # window is not spammed
        if batch % 100 == 0:
            loss, current_batch = loss.item(), batch * len(sample)
            print('Current Loss: {} Training Progress: [{} / {}]'.format(
                loss, current_batch, total_training_size))

The final helper function the PyTorch example needs is the test and validation function. This function uses the validation dataset to calculate an out-of-sample metric that measures how well the network was trained. If the scores are not sufficient, the network may need to be adjusted and re-trained.

This is an important factor when training a network for applications like Style Transfers. Want to learn more about content loss and how it’s measured? Check out my post about content loss!

The function sets the network into evaluation mode, which disables some training-only features like Dropout. In addition, the gradient calculation is disabled, which saves resources. The error in prediction for testing samples is calculated, and the accuracy and average loss are measured.

def test_network(dataloader, test_model, loss_function):
    """Runs the demo network in validation mode

    Parameters
    ----------
    dataloader : torch.utils.data.DataLoader
        The DataLoader with the validation testing data
    test_model : nn.Module
        The network being validated
    loss_function : CrossEntropyLoss
        The loss function used in validation
    """
    # Set the network into evaluation mode
    test_model.eval()

    # Initializes the counts to zero
    num_correct = 0.0
    total_loss = 0.0

    # Disable the gradient calculation since not training
    with torch.no_grad():
        for sample, score in dataloader:

            # Copy the sample to the device doing the calculations
            sample, score = sample.to(DEVICE), score.to(DEVICE)

            # Compute prediction error between the model output and the
            # test score, accumulating it over all batches
            prediction = test_model(sample)
            total_loss += loss_function(prediction, score).item()

            # Count the samples whose highest-scoring class matches the label
            num_correct += (prediction.argmax(1) == score).type(
                torch.float).sum().item()

    # Calculate the average loss by dividing by the number of batches
    average_loss = total_loss / len(dataloader)

    # Calculate the fraction of correct outputs
    percent_correct = 100.0 * (num_correct / len(dataloader.dataset))
    print('Accuracy: {}%, Average Loss: {}'.format(
        percent_correct, average_loss))

The final step of the example is to implement the main loop. First, the training batch size and number of training epochs are picked. If the performance scores are not sufficient, these parameters can be altered to search for a better result. Next, the model and loss function are defined. Here, Cross-Entropy Loss, a common choice for classification, is used.

Finally, the helper functions we made earlier are used. The data sets are loaded, and training and testing are performed for each epoch. After running, measure the performance on your own application to compare different networks and find which methods and architectures perform well.

After training the network, CNNs can be used to output images for tasks like style transfer and generating artwork. Check out my post about CNNs and neural networks outputting images!

if __name__ == '__main__':

    # Set the training batch size and number of training epochs
    BATCH_SIZE = 64
    EPOCHS = 5

    # Get the network model
    model = NeuralNetwork().to(DEVICE)

    # Define the loss function to use
    loss_func = nn.CrossEntropyLoss()

    # Get the training and testing data. If the data does not exist on 
    # the local disk, it will download and save it. Local copy will 
    # be used on future runs
    train_dl, test_dl = get_dataloaders()

    for i in range(EPOCHS):
        print('Training Epoch: {}'.format(i + 1))
        train_network(train_dl, model, loss_func)
        test_network(test_dl, model, loss_func)


Final Thoughts

Now that you know how data augmentation works in a CNN, you can use this technique in your own CNN implementations. Data augmentation can help prevent your model from overfitting and improve its performance. By exposing the optimizer to a wider input space, it can also help training find a better solution.

There are multiple techniques in data augmentation to choose from. Depending on the application, it can help build sufficient data to feed the machine learning model. Image data augmentation is one of the best ways to add artificial data without having to spend time and money manually collecting more data.

Data Augmentation is an excellent tool for training networks to do Style Transfers. Want to learn more about Style Transfers and digital art? Check out my beginner guide!



    1. Taylor, Luke, and Geoff Nitschke. “Improving deep learning with generic data augmentation.” 2018 IEEE Symposium Series on Computational Intelligence (SSCI). IEEE, 2018.
    2. Li, Wei, et al. “Data augmentation for hyperspectral image classification with deep CNN.” IEEE Geoscience and Remote Sensing Letters 16.4 (2018): 593-597.
    3. Pham, Tri-Cong, et al. “Deep CNN and data augmentation for skin lesion classification.” Asian Conference on Intelligent Information and Database Systems. Springer, Cham, 2018.
    4. Salamon, Justin, and Juan Pablo Bello. “Deep convolutional neural networks and data augmentation for environmental sound classification.” IEEE Signal Processing Letters 24.3 (2017): 279-283.
