
Reduce the number of colours in an image using Python.

1 Python Packages

The code on this page uses the NumPy, Matplotlib and Pillow packages which can be installed from the terminal via the following:

# "python3.12" should correspond to the version of Python you are using
python3.12 -m pip install numpy
python3.12 -m pip install matplotlib
python3.12 -m pip install pillow

Once finished, import these packages along with the pre-installed os module into your Python script as follows:

# NumPy is the fundamental package for scientific computing with Python
import numpy as np
# Matplotlib is for creating static, animated and interactive visualizations
from matplotlib import pyplot as plt
# Pillow is a fork of the Python Imaging Library (PIL) for image processing
from PIL import Image, ImageOps
# os provides a portable way of using operating system dependent functionality
import os

If you’re on an Ubuntu machine or similar it’s possible that you will need to change some environment variables to be compatible with the Wayland system:

# Set the QT_QPA_PLATFORM environment variable to wayland
os.environ['QT_QPA_PLATFORM'] = 'wayland'
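# Switch Matplotlib to the non-interactive Agg backend, which does not need a
# display server. Note that Agg only renders to files: plt.show() will not
# open a window, so save figures with plt.savefig() instead
plt.switch_backend('Agg')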

2 Import an Image

Assuming that there is an image file called malaria.jpg in the same folder as your script, import it via the following:

# Import image
filename = 'malaria.jpg'
img = plt.imread(filename)
plt.xticks([])
plt.yticks([])
plt.imshow(img)

3 Plot a Histogram

Let’s take a look at how many pixels of each colour are in our image by plotting their colour values from 0 (black) to 255 (white) in a histogram:

# Plot histogram
plt.hist(img.flatten(), 256, range=(0.0, 256.0), fc='black', ec='black')
plt.title('Histogram')
plt.ylabel('Count')
plt.xlabel('Value')
plt.show()
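
The same information can be inspected numerically rather than visually. The following is a minimal sketch that counts how often each value occurs using NumPy's bincount function; the random array is a hypothetical stand-in for the image so that the snippet runs on its own:

```python
import numpy as np

# Hypothetical stand-in for the image: a small random 8-bit RGB array
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)

# Count how many times each value from 0 to 255 occurs across all channels
counts = np.bincount(img.flatten(), minlength=256)

# Report the most common value and how often it occurs
most_common = int(np.argmax(counts))
print(most_common, counts[most_common])
```

With a real image, replacing the random array with the output of plt.imread() should give the heights of the bars in the histogram above.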

The next step is to decide how many colours we want to decompose our image into and choose one ‘seed point’ for each of these colours.

4 Select Seed Points

Let’s decompose our image into three colours. We thus need to start with three seed points.

These seed points should ideally correspond to peaks in the above histogram, but the points we start with do not necessarily have to be accurate because we will optimise them later. For this example, I will purposefully choose quite bad seed points:

# Seed points (chosen by looking at the histogram)
seed_point_0 = 100
seed_point_1 = 180
seed_point_2 = 225

# Plot histogram
plt.hist(img.flatten(), 256, range=(0.0, 256.0), fc='black', ec='black')
# Draw vertical lines at the estimated seed points
plt.axvline(seed_point_0) 
plt.axvline(seed_point_1)
plt.axvline(seed_point_2)
plt.title('Histogram')
plt.ylabel('Count')
plt.xlabel('Value')
plt.show()

These blue lines - corresponding to the three seed points - are quite far away from the peaks in the data.
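
If you would rather not pick seed points by eye, they could be estimated from the histogram itself. The sketch below is one possible approach and not part of the method used on this page: it smooths the counts with a moving average and then greedily picks the tallest bins that are at least a minimum distance apart. The pick_peaks function, its min_separation parameter and the synthetic data are all assumptions made for illustration:

```python
import numpy as np

# Hypothetical data: three groups of pixel values of different sizes
rng = np.random.default_rng(0)
values = np.concatenate([
    rng.normal(60, 8, 2000),
    rng.normal(150, 8, 5000),
    rng.normal(220, 8, 5000),
]).clip(0, 255).astype(np.uint8)

# Count the occurrences of each value, then smooth with a 9-bin moving average
counts = np.bincount(values, minlength=256)
smoothed = np.convolve(counts, np.ones(9) / 9, mode='same')


def pick_peaks(smoothed, n_peaks, min_separation=30):
    """Greedily pick the tallest bins that are well separated (hypothetical)."""
    peaks = []
    # Visit the bins from tallest to shortest
    for value in np.argsort(smoothed)[::-1]:
        # Accept a bin only if it is far enough from every accepted peak
        if all(abs(int(value) - peak) >= min_separation for peak in peaks):
            peaks.append(int(value))
        if len(peaks) == n_peaks:
            break
    return sorted(peaks)


seeds = pick_peaks(smoothed, 3)
print(seeds)
```

The minimum separation needs to suit the image: too small and two seeds may land on the same peak, too large and a genuine peak may be skipped.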

5 Optimising

We will use a k-means clustering algorithm to group the pixels into three clusters and optimise the mean of each one:

def k_means_clustering(img, seed_points):
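    """
    Posterise a greyscale image using k-means clustering with three clusters.

    Parameters
    ----------
    img : numpy.ndarray
        3D array (rows, columns, channels) of pixel values. Only the first
        channel is used, so the image is assumed to be greyscale.
    seed_points : list
        Initial estimates of the three cluster means.

    Returns
    -------
    filtered : numpy.ndarray
        2D array in which each pixel is replaced by its cluster's mean.
    means : list
        The optimised means of the three clusters.
    """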
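    # Use one channel of the image (all rows and columns of the first layer)
    # and convert it to floats so that the subtractions below cannot wrap
    # around, as they can with unsigned 8-bit integers
    img = img[:, :, 0].astype(float)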
    image_size = np.shape(img)

    # Initial guesses at the clusters' means
    mean_0 = seed_points[0]
    mean_1 = seed_points[1]
    mean_2 = seed_points[2]

    # Initialise counts of elements in clusters before assignments take place
    old_0 = 1 
    old_1 = 1
    old_2 = 1

    # Initialise counts of elements in clusters after assignments take place
    new_0 = 2
    new_1 = 2
    new_2 = 2

    # Use an iterative method to get more accurate values for the clusters'
    # means: keep going while there are still differences between the numbers
    # of elements in each cluster before and after assignments
    while (new_0, new_1, new_2) != (old_0, old_1, old_2):
        # Initialise clusters
        cluster_0 = []
        cluster_1 = []
        cluster_2 = []
        # Update count of elements in clusters before assignments take place
        old_0 = new_0
        old_1 = new_1
        old_2 = new_2
        # For each row in the image
        for row in range(image_size[0]):
            # For each column in the image
            for column in range(image_size[1]):
                # Calculate the pixel's distances to the seed points
                dist_to_seed_0 = abs(img[row, column] - mean_0)
                dist_to_seed_1 = abs(img[row, column] - mean_1)
                dist_to_seed_2 = abs(img[row, column] - mean_2)
                # Calculate the smallest of the distances
                distances = [dist_to_seed_0, dist_to_seed_1, dist_to_seed_2]
                min_distance = np.min(distances)
                if dist_to_seed_0 == min_distance:
                    # If the pixel is closest to the first seed point place it
                    # in cluster 0
                    cluster_0.append(img[row, column])
                elif dist_to_seed_1 == min_distance:
                    # If the pixel is closest to the second seed point place it
                    # in cluster 1
                    cluster_1.append(img[row, column])
                elif dist_to_seed_2 == min_distance:
                    # If the pixel is closest to the third seed point place it
                    # in cluster 2
                    cluster_2.append(img[row, column])
                else:
                    raise ValueError
        # Calculate the means of the clusters
        mean_0 = np.mean(cluster_0, axis=None)
        mean_1 = np.mean(cluster_1, axis=None)
        mean_2 = np.mean(cluster_2, axis=None)
        # Count the elements in each cluster after assignments
        new_0 = len(cluster_0)
        new_1 = len(cluster_1)
        new_2 = len(cluster_2)

    # Now that accurate means have been obtained, change each pixel to the
    # value of its cluster's mean
    for row in range(image_size[0]):
        for col in range(image_size[1]):
            # Calculate the pixel's distances to the seed points
            dist_to_seed_0 = abs(img[row, col] - mean_0)
            dist_to_seed_1 = abs(img[row, col] - mean_1)
            dist_to_seed_2 = abs(img[row, col] - mean_2)
            # Calculate the smallest of the distances
            distances = [dist_to_seed_0, dist_to_seed_1, dist_to_seed_2]
            min_distance = np.min(distances)
            # If the pixel is closest to seed point x set it equal to cluster
            # x's mean
            if dist_to_seed_0 == min_distance:
                img[row, col] = mean_0
            elif dist_to_seed_1 == min_distance:
                img[row, col] = mean_1
            elif dist_to_seed_2 == min_distance:
                img[row, col] = mean_2
            else:
                raise ValueError

    return img, [mean_0, mean_1, mean_2]


seed_points = [seed_point_0, seed_point_1, seed_point_2]
filtered, means = k_means_clustering(img, seed_points)
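
The nested loops above are easy to follow but slow on large images. As a rough alternative, the same idea can be written with NumPy array operations for any number of clusters. This is a minimal sketch under stated assumptions, not a drop-in replacement: the function name k_means_1d, its max_iter parameter and the synthetic image are hypothetical, and it stops when the means stop changing rather than when the cluster counts stop changing:

```python
import numpy as np


def k_means_1d(values, seeds, max_iter=100):
    """Cluster greyscale values around len(seeds) means (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    means = np.asarray(seeds, dtype=float)
    for _ in range(max_iter):
        # Distance from every pixel to every mean, then the nearest mean
        distances = np.abs(values[..., np.newaxis] - means)
        labels = np.argmin(distances, axis=-1)
        # Recompute each mean, leaving it unchanged if its cluster is empty
        new_means = means.copy()
        for k in range(len(means)):
            members = values[labels == k]
            if members.size > 0:
                new_means[k] = members.mean()
        # Stop once the means no longer move
        converged = np.allclose(new_means, means)
        means = new_means
        if converged:
            break
    # Replace every pixel with the mean of its nearest cluster
    labels = np.argmin(np.abs(values[..., np.newaxis] - means), axis=-1)
    return means[labels], means


# Hypothetical greyscale image made of three noisy bands
rng = np.random.default_rng(0)
img = np.concatenate([
    rng.normal(50, 5, (20, 30)),
    rng.normal(150, 5, (20, 30)),
    rng.normal(220, 5, (20, 30)),
]).clip(0, 255)

filtered, means = k_means_1d(img, [100, 180, 225])
print(means)
```

Skipping empty clusters avoids taking the mean of an empty list, which would otherwise produce a NaN.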

Display the posterised image:

plt.xticks([])
plt.yticks([])
plt.imshow(filtered, cmap='gray')

6 Check the Optimisation

Re-plot the histograms to check how well the optimisation has worked:

# Plot histogram
plt.hist(img.flatten(), 256, range=(0.0, 256.0), fc='black', ec='black')
# Draw vertical lines at the estimated seed points
plt.axvline(seed_point_0) 
plt.axvline(seed_point_1)
plt.axvline(seed_point_2)
# Draw vertical lines at the optimised cluster means
plt.axvline(means[0], color='r')
plt.axvline(means[1], color='r')
plt.axvline(means[2], color='r')
plt.title('Histogram')
plt.ylabel('Count')
plt.xlabel('Value')
plt.show()

These optimised means, shown in red, lie in the middle of the three distinct clusters in the histogram. The initial seed points, shown in blue, are included for reference. The cluster that lies below 100 on the plot is much smaller than the other two, but it corresponds to the near-black pixels in the original image, which are important features. So while the cluster is small in height, what matters is that it is distinct from the other two.
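
The relative sizes of the clusters can also be checked numerically by assigning each value to its nearest mean and counting the members. The following sketch assumes a set of already-optimised means and uses synthetic values in place of the image, so both are hypothetical:

```python
import numpy as np

# Hypothetical pixel values: one small dark group and two large bright groups
rng = np.random.default_rng(1)
values = np.concatenate([
    rng.normal(60, 8, 500),
    rng.normal(150, 8, 5000),
    rng.normal(220, 8, 4500),
]).clip(0, 255)

# Hypothetical optimised means, e.g. as returned by a k-means function
means = np.array([60.0, 150.0, 220.0])

# Assign every value to its nearest mean and count the members of each cluster
labels = np.argmin(np.abs(values[:, np.newaxis] - means), axis=1)
sizes = np.bincount(labels, minlength=len(means))

# Express the cluster sizes as fractions of the total number of values
fractions = sizes / sizes.sum()
print(sizes, fractions)
```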

7 Posterising in Colour

Let’s try with a colour image:

# Show image
img = Image.open('arvee_marie.jpg')
plt.xticks([])
plt.yticks([])
plt.imshow(img)

Instead of using the function we created above, we will use the built-in functionality from Pillow:

# Create list with the names of the files to be posterised
filenames = ['arvee_marie.jpg']

for filename in filenames:
    # Import
    img = Image.open(filename)
    # Posterise image
    img = ImageOps.posterize(img, 1)
    # Show image
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img)

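This has reduced each channel to one bit, which allows at most 2^3 = 8 distinct colours (two levels in each of the red, green and blue channels). An image may contain fewer than eight if some combinations do not occur.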
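
To see what posterising to one bit does to the pixel values, the sketch below applies ImageOps.posterize to a small random RGB image (a hypothetical stand-in for a photo) and then counts the distinct channel levels and colours that remain:

```python
import numpy as np
from PIL import Image, ImageOps

# Hypothetical stand-in for a colour photo: a random 8-bit RGB image
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
img = Image.fromarray(pixels)

# Keep only the most significant bit of each channel
posterised = np.asarray(ImageOps.posterize(img, 1))

# Distinct values that remain within the channels
levels = np.unique(posterised)
# Distinct (red, green, blue) combinations that remain
colours = np.unique(posterised.reshape(-1, 3), axis=0)

print(levels)
print(len(colours))
```

Each channel keeps only its top bit, so it can take just two values, and three channels with two values each give at most eight combinations.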
