Reduce the number of colours in an image using Python.
The code on this page uses the NumPy, Matplotlib and Pillow packages which can be installed from the terminal via the following:
# "python3.12" should correspond to the version of Python you are using
python3.12 -m pip install numpy
python3.12 -m pip install matplotlib
python3.12 -m pip install pillow
Once finished, import these packages along with the built-in os module into your Python script as follows:
# NumPy is the fundamental package for scientific computing with Python
import numpy as np
# Matplotlib is for creating static, animated and interactive visualizations
from matplotlib import pyplot as plt
# Pillow is a fork of the Python Imaging Library (PIL) for image processing
from PIL import Image, ImageOps
# os provides a portable way of using operating system dependent functionality
import os
If you’re on an Ubuntu machine or similar it’s possible that you will need to change some environment variables to be compatible with the Wayland system:
# Set the QT_QPA_PLATFORM environment variable to wayland
os.environ['QT_QPA_PLATFORM'] = 'wayland'
# Switch Matplotlib to the non-interactive Agg backend to avoid display issues
# (with Agg, plt.show() opens no window; use plt.savefig() to write figures to file)
plt.switch_backend('Agg')
Assuming that there is an image file called malaria.jpg in the same folder as your script, import it via the following:
# Import image
filename = 'malaria.jpg'
img = plt.imread(filename)
plt.xticks([])
plt.yticks([])
plt.imshow(img)
Let’s take a look at how many pixels of each colour are in our image by plotting their colour values from 0 (black) to 255 (white) in a histogram:
# Plot histogram
plt.hist(img.flatten(), 256, range=(0.0, 256.0), fc='black', ec='black')
plt.title('Histogram')
plt.ylabel('Count')
plt.xlabel('Value')
plt.show()
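If you want the numbers behind the plot rather than the picture, the same counts can be computed without drawing anything. The following is a minimal sketch that assumes a tiny made-up array called toy in place of the real image; np.histogram is given the same number of bins and the same range as the plt.hist call above.

```python
import numpy as np

# A tiny made-up "image" standing in for the real one (an assumption)
toy = np.array([[0, 0, 255], [128, 128, 128]], dtype=np.uint8)

# One bin per possible value from 0 to 255, as in the plt.hist call
counts, edges = np.histogram(toy.flatten(), bins=256, range=(0.0, 256.0))

# Report only the values that actually occur in the image
for value in np.nonzero(counts)[0]:
    print(f'value {value}: {counts[value]} pixel(s)')
```

The counts array has one entry per value, so counts[128] is the number of pixels whose value is 128.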
The next step is to decide how many colours we want to decompose our image into and choose one ‘seed point’ for each of these colours.
Let’s decompose our image into three colours. We thus need to start with three seed points.
These seed points should ideally correspond to peaks in the above histogram, but the points we start with do not necessarily have to be accurate because we will optimise them later. For this example, I will purposefully choose quite bad seed points:
# Seed points (chosen by looking at the histogram)
seed_point_0 = 100
seed_point_1 = 180
seed_point_2 = 225
# Plot histogram
plt.hist(img.flatten(), 256, range=(0.0, 256.0), fc='black', ec='black')
# Draw vertical lines at the estimated seed points
plt.axvline(seed_point_0)
plt.axvline(seed_point_1)
plt.axvline(seed_point_2)
plt.title('Histogram')
plt.ylabel('Count')
plt.xlabel('Value')
plt.show()
These blue lines - corresponding to the three seed points - are quite far away from the peaks in the data.
We will use a k-means clustering algorithm to group the pixels into 3 clusters and optimise the mean of each one:
def k_means_clustering(img, seed_points):
    """
    Reduce a greyscale image to three grey levels using k-means clustering.

    Parameters
    ----------
    img : numpy.ndarray
        3D array (rows x columns x channels) of pixel values. Only the first
        channel is used, which is sufficient for a greyscale image.
    seed_points : list
        Estimates of the means of each of the three clusters.

    Returns
    -------
    img : numpy.ndarray
        The filtered (posterised) 2D image.
    means : list
        The optimised means of the three clusters.
    """
    # Select all of the image's rows and columns but only its first channel,
    # converting to float so that the subtractions below cannot wrap around
    img = img[:, :, 0].astype(float)
    image_size = np.shape(img)
    # Initial guesses at the clusters' means
    mean_0 = seed_points[0]
    mean_1 = seed_points[1]
    mean_2 = seed_points[2]
    # Initialise counts of elements in clusters before assignments take place
    old_0 = 1
    old_1 = 1
    old_2 = 1
    # Initialise counts of elements in clusters after assignments take place
    new_0 = 2
    new_1 = 2
    new_2 = 2
    # Use an iterative method to get more accurate values for the clusters'
    # means: keep going while the number of elements in any cluster still
    # differs before and after assignments
    while new_0 != old_0 or new_1 != old_1 or new_2 != old_2:
        # Initialise clusters
        cluster_0 = []
        cluster_1 = []
        cluster_2 = []
        # Update count of elements in clusters before assignments take place
        old_0 = new_0
        old_1 = new_1
        old_2 = new_2
        # For each row in the image
        for row in range(image_size[0]):
            # For each column in the image
            for column in range(image_size[1]):
                # Calculate the pixel's distances to the current means
                dist_to_seed_0 = abs(img[row, column] - mean_0)
                dist_to_seed_1 = abs(img[row, column] - mean_1)
                dist_to_seed_2 = abs(img[row, column] - mean_2)
                # Calculate the smallest of the distances
                distances = [dist_to_seed_0, dist_to_seed_1, dist_to_seed_2]
                min_distance = np.min(distances)
                if dist_to_seed_0 == min_distance:
                    # If the pixel is closest to the first mean place it
                    # in cluster 0
                    cluster_0.append(img[row, column])
                elif dist_to_seed_1 == min_distance:
                    # If the pixel is closest to the second mean place it
                    # in cluster 1
                    cluster_1.append(img[row, column])
                elif dist_to_seed_2 == min_distance:
                    # If the pixel is closest to the third mean place it
                    # in cluster 2
                    cluster_2.append(img[row, column])
                else:
                    raise ValueError
        # Calculate the means of the clusters
        mean_0 = np.mean(cluster_0, axis=None)
        mean_1 = np.mean(cluster_1, axis=None)
        mean_2 = np.mean(cluster_2, axis=None)
        # Count the elements in each cluster after assignments
        new_0 = len(cluster_0)
        new_1 = len(cluster_1)
        new_2 = len(cluster_2)
    # Now that accurate means have been obtained, change each pixel to
    # the value of its nearest mean
    for row in range(image_size[0]):
        for col in range(image_size[1]):
            # Calculate the pixel's distances to the optimised means
            dist_to_seed_0 = abs(img[row, col] - mean_0)
            dist_to_seed_1 = abs(img[row, col] - mean_1)
            dist_to_seed_2 = abs(img[row, col] - mean_2)
            # Calculate the smallest of the distances
            distances = [dist_to_seed_0, dist_to_seed_1, dist_to_seed_2]
            min_distance = np.min(distances)
            # If the pixel is closest to mean x set it equal to cluster
            # x's mean
            if dist_to_seed_0 == min_distance:
                img[row, col] = mean_0
            elif dist_to_seed_1 == min_distance:
                img[row, col] = mean_1
            elif dist_to_seed_2 == min_distance:
                img[row, col] = mean_2
            else:
                raise ValueError
    return img, [mean_0, mean_1, mean_2]
seed_points = [seed_point_0, seed_point_1, seed_point_2]
filtered, means = k_means_clustering(img, seed_points)
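The pixel-by-pixel loops above are easy to follow but slow on large images. As a rough sketch of the same idea in vectorised NumPy, the block below assigns every value to its nearest mean in one step and then recomputes the means. The function name k_means_1d, the max_iter parameter and the synthetic toy data are all assumptions made for illustration, and the stopping rule is different from the one above: it stops when the means stop moving rather than when the cluster sizes stop changing.

```python
import numpy as np


def k_means_1d(values, seeds, max_iter=100):
    """Cluster 1-D values around len(seeds) means (hypothetical helper)."""
    values = np.asarray(values, dtype=float).ravel()
    means = np.asarray(seeds, dtype=float)
    for _ in range(max_iter):
        # Distance from every value to every mean: shape (n_values, n_means)
        distances = np.abs(values[:, np.newaxis] - means[np.newaxis, :])
        # Index of the nearest mean for each value
        labels = distances.argmin(axis=1)
        # Recompute each mean, keeping the old one if its cluster is empty
        new_means = np.array([
            values[labels == k].mean() if np.any(labels == k) else means[k]
            for k in range(len(means))
        ])
        # Stop once the means no longer move
        if np.allclose(new_means, means):
            break
        means = new_means
    return labels, means


# Synthetic greyscale "image" with three well-separated groups of values
rng = np.random.default_rng(0)
toy = np.concatenate([
    rng.normal(40, 5, 200), rng.normal(150, 5, 500), rng.normal(230, 5, 300)
]).clip(0, 255).reshape(20, 50)

# Same deliberately poor seed points as in the example above
labels, means = k_means_1d(toy, [100, 180, 225])

# Replace every pixel with the mean of the cluster it belongs to
posterised = means[labels].reshape(toy.shape)
print(means)
```

Because the number of clusters is taken from the length of the seeds list, this version is not limited to three grey levels.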
Display the posterised image:
plt.xticks([])
plt.yticks([])
plt.imshow(filtered, cmap='gray')
Re-plot the histograms to check how well the optimisation has worked:
# Plot histogram
plt.hist(img.flatten(), 256, range=(0.0, 256.0), fc='black', ec='black')
# Draw vertical lines at the estimated seed points
plt.axvline(seed_point_0)
plt.axvline(seed_point_1)
plt.axvline(seed_point_2)
# Draw vertical lines at the optimised seed points (in red)
plt.axvline(means[0], color='r')
plt.axvline(means[1], color='r')
plt.axvline(means[2], color='r')
plt.title('Histogram')
plt.ylabel('Count')
plt.xlabel('Value')
plt.show()
These optimised seed points (in red) lie near the centres of the three distinct clusters in the histogram. Our initial seed points (in blue) are included for reference. The cluster that lies below 100 on the plot is very small compared to the other two, but it corresponds to the near-black pixels in the original image, which are indeed important features. So while the cluster is small in height, what matters is that it is distinct from the other two.
Let’s try with a colour image:
# Show image
img = Image.open('arvee_marie.jpg')
plt.xticks([])
plt.yticks([])
plt.imshow(img, cmap='gray')
Instead of using the function we created above, we will just use the built-in functionality from Pillow:
# Create list with the names of the files to be posterised
filenames = ['arvee_marie.jpg']
for filename in filenames:
    # Import
    img = Image.open(filename)
    # Posterise image
    img = ImageOps.posterize(img, 1)
    # Show image
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap='gray')
This has reduced the image to one bit per channel. For an RGB image with three channels, that allows at most 2^3 = 8 distinct colours (a particular image may contain fewer).
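To see what keeping one bit per channel means in practice, the sketch below compares ImageOps.posterize with a plain bit mask that keeps only the most significant bit of each channel value. It uses a randomly generated RGB image as a stand-in for a real photo, which is an assumption made so that the example runs without any image file.

```python
import numpy as np
from PIL import Image, ImageOps

# Random RGB pixels standing in for a real photo (an assumption)
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
img = Image.fromarray(pixels)

# Posterise to one bit per channel with Pillow
posterised = np.asarray(ImageOps.posterize(img, 1))

# Keeping one bit per channel is the same as zeroing the lowest seven bits,
# so every channel value becomes either 0 or 128
masked = pixels & 0b10000000
same = np.array_equal(posterised, masked)

# Count the distinct (R, G, B) combinations that remain
n_colours = len(np.unique(posterised.reshape(-1, 3), axis=0))
print(same, n_colours)
```

Each channel can only take two values after posterising, which is where the limit of 2^3 combinations for an RGB image comes from.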