May 30, 2017 · dsp2017 machine learning

Image Augmentation for Deep Learning with Keras

The most common problem when doing deep learning:

How can I find enough data to learn my models?

Indeed, that's a big problem. Gathering a high volume of training data for deep learning is hard. But maybe we can "cheat" a bit?

Image Augmentation

Let's talk about images. When doing deep learning, we want as diverse a set of images as possible, to prevent overfitting. Instead of trying to acquire more of them, we can generate additional images from the existing ones by applying various transformations (rotations, shifts, zooms, flips and so on).
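To make "transformations" concrete, here is a minimal NumPy sketch of two of them, a horizontal flip and a random shift, applied to an image stored as a (height, width, channels) array. The helpers flip_horizontal and random_shift are hypothetical illustrations, not Keras functions; ImageDataGenerator implements these (and more) internally with its own interpolation and fill modes.

```python
import numpy as np

def flip_horizontal(image: np.ndarray) -> np.ndarray:
    """Mirror an (H, W, C) image left-to-right."""
    return image[:, ::-1, :]

def random_shift(image: np.ndarray, max_fraction: float,
                 rng: np.random.Generator) -> np.ndarray:
    """Shift an (H, W, C) image by up to max_fraction of its height/width.
    Pixels that move in from outside the frame are filled with zeros."""
    h, w = image.shape[:2]
    max_dy, max_dx = int(h * max_fraction), int(w * max_fraction)
    dy = int(rng.integers(-max_dy, max_dy + 1))
    dx = int(rng.integers(-max_dx, max_dx + 1))
    shifted = np.zeros_like(image)
    # copy the overlapping region from its source to its shifted destination
    src_y = slice(max(0, -dy), h - max(0, dy))
    dst_y = slice(max(0, dy), h - max(0, -dy))
    src_x = slice(max(0, -dx), w - max(0, dx))
    dst_x = slice(max(0, dx), w - max(0, -dx))
    shifted[dst_y, dst_x] = image[src_y, src_x]
    return shifted

rng = np.random.default_rng(seed=0)
image = np.arange(4 * 6 * 3, dtype=np.float32).reshape(4, 6, 3)  # toy 4x6 RGB "image"
flipped = flip_horizontal(image)
shifted = random_shift(image, max_fraction=0.25, rng=rng)
```

Each call produces a new, slightly different training sample from the same source image, which is exactly the "cheat" described above.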

There's a great tool made for that. It's called ImageDataGenerator and can be found in the Keras library, in the keras.preprocessing.image module. Let's assume that we have a single image called dog.jpg.


Now we want to generate additional samples based on it. The easiest way to achieve this is to run the following code (all available options are listed in the ImageDataGenerator documentation):

from keras.preprocessing.image import ImageDataGenerator  
from keras.preprocessing.image import array_to_img, img_to_array, load_img

input_path = 'dog.jpg'  
output_path = 'dog_random{}.jpg'  
count = 10

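gen = ImageDataGenerator(
    # illustrative options; pick the transformations that make sense for your data
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# load image to array
image = img_to_array(load_img(input_path))

# reshape to a rank-4 array: a batch containing a single image
image = image.reshape((1,) + image.shape)

# let's create an infinite flow of images
images_flow = gen.flow(image, batch_size=1)
for i, new_images in enumerate(images_flow):
    # we access only the first image because of batch_size=1
    new_image = array_to_img(new_images[0], scale=True)
    new_image.save(output_path.format(i + 1))
    # the flow never ends on its own, so stop after `count` images
    if i + 1 >= count:
        break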
After running this code, you'll have 10 different images of dogs, like this:

[Image: random dogs]
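The generation loop above relies on two details: flow() works on batches, so the single image has to become a rank-4 array of shape (samples, height, width, channels), and the flow is infinite, so we must stop it ourselves. The sketch below shows both with plain NumPy and itertools.islice. The fake_flow generator is a hypothetical stand-in for gen.flow(image, batch_size=1) that only applies random horizontal flips, and the 150x200 random array is a stand-in for dog.jpg.

```python
import itertools
import numpy as np

def fake_flow(batch, rng):
    """Hypothetical stand-in for gen.flow(image, batch_size=1): yields a
    randomly perturbed copy of the rank-4 batch, forever."""
    while True:
        # flip left-to-right (axis 2 = width) with probability 0.5
        if rng.random() < 0.5:
            yield batch[:, :, ::-1, :].copy()
        else:
            yield batch.copy()

# stand-in for img_to_array(load_img('dog.jpg')): (height, width, channels)
image = np.random.default_rng(1).random((150, 200, 3), dtype=np.float32)
batch = image.reshape((1,) + image.shape)  # same as np.expand_dims(image, axis=0)

count = 10
rng = np.random.default_rng(seed=0)
# islice takes exactly `count` batches from the infinite generator;
# [0] picks the single image out of each batch of size 1
new_images = [b[0] for b in itertools.islice(fake_flow(batch, rng), count)]
```

Using islice instead of a manual index check makes it harder to get the stopping condition off by one.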

Usage in deep learning

For training our model, we don't need to save images to disk. There's a very convenient method, flow_from_directory, that loads all images from a directory and generates random samples based on them:

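gen = ImageDataGenerator(
    rescale=1. / 255,          # illustrative options
    horizontal_flip=True)

train_generator = gen.flow_from_directory(
        'data/train',          # one subdirectory per class
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

# now we can use our generator in the model's fit method, e.g. (Keras 2, `model` is your compiled model):
# model.fit_generator(train_generator, steps_per_epoch=..., epochs=...)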

Our images must be structured in such a way that the label can be read from their location, for example data/train/dog/dog1.png or data/train/cat/cat1.png: each subdirectory of data/train is one class. If we have two classes, we should pass class_mode='binary'; if more, class_mode='categorical'.
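A rough plain-Python sketch of this labeling convention, assuming the data/train/<class>/ layout described above. The helpers infer_labels and encode are hypothetical and only approximate what flow_from_directory does: the Keras documentation says classes are inferred from subdirectory names in alphanumeric order (exposed as class_indices), while the real implementation also filters files by image extension and returns NumPy arrays rather than Python lists.

```python
import os
import tempfile

def infer_labels(root):
    """Hypothetical helper: every subdirectory of `root` is one class;
    classes are sorted by name and numbered from 0.
    Returns (class_indices, [(path, label), ...])."""
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    class_indices = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, fname), class_indices[name]))
    return class_indices, samples

def encode(label, num_classes, class_mode):
    """'binary' -> a single 0/1 number; 'categorical' -> a one-hot vector."""
    if class_mode == 'binary':
        return float(label)
    if class_mode == 'categorical':
        return [1.0 if i == label else 0.0 for i in range(num_classes)]
    raise ValueError('unsupported class_mode: ' + class_mode)

# build a tiny data/train tree (empty placeholder files) in a temp directory
with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, 'data', 'train')
    for cls, fname in [('dog', 'dog1.png'), ('cat', 'cat1.png')]:
        os.makedirs(os.path.join(root, cls), exist_ok=True)
        open(os.path.join(root, cls, fname), 'wb').close()
    class_indices, samples = infer_labels(root)

labels = [label for _, label in samples]
```

With two folders, 'cat' gets index 0 and 'dog' gets index 1, so a single 0/1 target ('binary') is enough; with three or more classes each label becomes a one-hot vector ('categorical').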

You can read more in this awesome Keras Blog article - Building powerful image classification models using very little data.

That's all! If you like this post, please share on social media :)
