
1 Python Packages

The code on this page uses the scikit-learn, NumPy and Matplotlib packages, plus pandas (which is needed for the as_frame=True option used when loading the data). These can be installed from the terminal with the following commands:

# "python3.12" corresponds to the version of Python you have installed and are using
$ python3.12 -m pip install scikit-learn
$ python3.12 -m pip install numpy
$ python3.12 -m pip install matplotlib
$ python3.12 -m pip install pandas

Once finished, import these packages into your Python script as follows:

from sklearn import datasets
from sklearn import neighbors
import numpy as np
from matplotlib import pyplot as plt

2 Example Data

For this example we will use the Iris toy dataset from scikit-learn.

  • See here for the user guide
  • See here for the documentation of the load_iris() function which imports this dataset
  • See here for an example on the scikit-learn site
  • See here for an example on my site
  • See here for the Wikipedia page about this dataset

This dataset can be loaded using the load_iris() function from scikit-learn’s datasets sub-module. Using the as_frame=True option means that the data will be formatted as a Pandas data frame:

# Load the dataset
dataset = datasets.load_iris(as_frame=True)

For this example we will only use one feature: petal length. Note that the feature needs to be reshaped (using the reshape() method) from a 1D array into a 2D array with one column, because scikit-learn’s models expect their input to have the shape (number of samples, number of features) even when there is only one feature.

# Separate out the feature
X = dataset['data']['petal length (cm)'].values.reshape(-1, 1)

# View the first five values
print(X[:5])
## [[1.4]
##  [1.4]
##  [1.3]
##  [1.5]
##  [1.4]]
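To make the effect of the reshape concrete, here is a minimal sketch using a small made-up array (the values and the names lengths and column are illustrative stand-ins, not the dataset itself):

```python
import numpy as np

# Hypothetical petal lengths standing in for one column of the dataset
lengths = np.array([1.4, 1.4, 1.3, 1.5, 1.4])
print(lengths.shape)  # 1D: five values, no column dimension

# -1 asks NumPy to infer the number of rows; 1 is the number of columns
column = lengths.reshape(-1, 1)
print(column.shape)  # 2D: five rows, one column
```

The values are unchanged; only the shape differs, which is what the later fit() and predict() calls care about.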

The target will be the species of Iris flower that corresponds to each specimen:

# Separate out the target
y = dataset['target']

By default the target values are the numbers 0, 1 and 2, but these can be ‘translated’ into the species names they stand for (‘setosa’, ‘versicolor’ and ‘virginica’ respectively), which are stored under the 'target_names' key:

# Translate the target
y = y.apply(lambda x: dataset['target_names'][x])

# View the first five values
print(y.head())
## 0    setosa
## 1    setosa
## 2    setosa
## 3    setosa
## 4    setosa
## Name: target, dtype: object

3 Visualise the Data

Now let’s have a look at the raw data:

# Plot
ax = plt.axes()
ax.scatter(X, y, alpha=0.2)
ax.set_yticks([0, 1, 2])
ax.set_yticklabels([s.title() for s in dataset['target_names']])
ax.set_title('Classifying Species from Petal Length')
ax.set_ylabel('Species')
ax.set_xlabel('Petal Length (cm)')
plt.tight_layout()
plt.show()

As you can see, the petals of Iris virginica flowers tend to be longer than those of Iris versicolor flowers, with Iris setosa petals being the shortest. What we want to do now is build a model that will predict whether a flower is a virginica, versicolor or setosa based on its petal length alone, and we will use a k-nearest neighbours (k-NN) algorithm to do so.

4 Overview of the Algorithm

There are three main steps in a k-NN algorithm:

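  1. The data is normalised so that all values lie between 0 and 1: (x - minimum) / (maximum - minimum). This is a preprocessing step that stops features with large ranges from dominating the distance calculation. scikit-learn’s k-NN classifier does not do it for you, and with a single feature (as on this page) it does not change which neighbours are nearest, so the model created in the next section skips it.
  2. When given a new point, the ‘distance’ to every point in the training data is calculated. This can be the Euclidean distance, but it doesn’t have to be. For a single feature i, the squared difference between two points x_1 and x_2 is (x_1[i] - x_2[i]) ** 2; summing these over all features and taking the square root gives the Euclidean distance.
  3. The new point is classified as being in the same category as the majority of its k nearest neighbours. For example, in a three-nearest neighbours algorithm, if two of a point’s three nearest neighbours are virginica flowers then the point will be classified as a virginica flower as well.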
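The three steps above can be sketched from scratch in a few lines of NumPy. This is only an illustration under simplifying assumptions (Euclidean distance, uniform votes, no special handling of ties) and is not how scikit-learn implements it; the function names min_max_scale and knn_predict and the toy data are made up for this example:

```python
from collections import Counter

import numpy as np


def min_max_scale(values, minimum, maximum):
    """Step 1: rescale values so the training data lies between 0 and 1."""
    return (values - minimum) / (maximum - minimum)


def knn_predict(X_train, y_train, x_new, k=3):
    """Classify one new point by majority vote of its k nearest neighbours."""
    lo, hi = X_train.min(axis=0), X_train.max(axis=0)
    X_scaled = min_max_scale(X_train, lo, hi)
    x_scaled = min_max_scale(x_new, lo, hi)
    # Step 2: Euclidean distance from the new point to every training point
    distances = np.sqrt(((X_scaled - x_scaled) ** 2).sum(axis=1))
    # Step 3: majority vote among the k closest training points
    nearest = np.argsort(distances, kind='stable')[:k]
    votes = Counter(y_train[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Made-up toy data: one feature, two categories
X_train = np.array([[1.0], [1.5], [2.0], [4.0], [4.5], [5.0]])
y_train = np.array(['short', 'short', 'short', 'long', 'long', 'long'])

print(knn_predict(X_train, y_train, np.array([1.8])))
print(knn_predict(X_train, y_train, np.array([4.2])))
```

In practice you would use a library implementation, which searches for neighbours more efficiently and handles edge cases such as tied distances.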

5 Create the Model

This will use the KNeighborsClassifier class from scikit-learn’s neighbors sub-module. See the documentation, the user guide and an example for more info. The value of k (i.e. the number of neighbours) will be 3:

# Create a model and fit it to the data
model = neighbors.KNeighborsClassifier(n_neighbors=3)
model.fit(X, y)

6 Make Predictions

If we are given three new Iris flowers with petal lengths of 2, 3 and 5 centimetres, what species do we think they are?

# Make a prediction
petal_length = [2, 3, 5]
y_pred = model.predict(np.array(petal_length).reshape(-1, 1))

print(y_pred)
## ['setosa' 'versicolor' 'virginica']
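To see why a particular prediction was made, it can help to look at the neighbours that voted for it. The sketch below uses the fitted model’s kneighbors() method for a flower with a 2 cm petal. To keep it self-contained it reloads the data and fits on the integer-coded target (0, 1, 2) instead of the translated names, which is a simplification relative to the code above:

```python
from sklearn import datasets, neighbors

# Reload the data (without as_frame) so this snippet runs on its own
iris = datasets.load_iris()
X = iris.data[:, 2].reshape(-1, 1)  # petal length (cm) is the third column
y = iris.target  # integer codes: 0, 1, 2

model = neighbors.KNeighborsClassifier(n_neighbors=3).fit(X, y)

# Distances to, and row indices of, the 3 nearest training points
distances, indices = model.kneighbors([[2]])
neighbour_lengths = X[indices[0]].ravel()
neighbour_species = iris.target_names[y[indices[0]]]

print(neighbour_lengths)
print(neighbour_species)
```

The species printed on the last line are the ‘votes’; the prediction is whichever species appears most often among them.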

We need to be careful that we use data of the right shape. In the above example we used the .reshape() method to achieve this but we could alternatively have used a list-of-lists as our input:

# Make a prediction
petal_length = [[2], [3], [5]]
y_pred = model.predict(petal_length)

print(y_pred)
## ['setosa' 'versicolor' 'virginica']
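For comparison, this is roughly what happens if the input is left as a flat list. The sketch is self-contained (it reloads the data and uses the integer-coded target for brevity), and the raised flag is only there for illustration:

```python
from sklearn import datasets, neighbors

iris = datasets.load_iris()
X = iris.data[:, 2].reshape(-1, 1)  # petal length (cm) as a single column
model = neighbors.KNeighborsClassifier(n_neighbors=3).fit(X, iris.target)

raised = False
try:
    # A flat list is 1D, so scikit-learn cannot tell samples from features
    model.predict([2, 3, 5])
except ValueError as error:
    raised = True
    print(f'ValueError: {error}')
```

The error message itself suggests using reshape(-1, 1) when the data has a single feature, which is the fix used above.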

7 Find the Decision Boundaries

As this is quite a simple example (there is only one feature) the decision boundaries are just numbers which represent the petal lengths above and below which a flower is classified differently:

# Find decision boundaries
x_fitted = np.linspace(X.min(), X.max(), 500).reshape(-1, 1)
y_fitted = model.predict(x_fitted)
for species in dataset['target_names']:
    x_species = x_fitted[y_fitted == species]
    minimum = round(x_species.min(), 2)
    maximum = round(x_species.max(), 2)
    print(f'{species.title()} flowers have petals between {minimum} and {maximum} cm in length')
## Setosa flowers have petals between 1.0 and 2.6 cm in length
## Versicolor flowers have petals between 2.61 and 4.95 cm in length
## Virginica flowers have petals between 4.96 and 6.9 cm in length
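The boundaries themselves can also be computed directly, by looking for the places on the grid where the prediction changes. The following is a minimal sketch of that idea; it is self-contained (reloading the data and using the integer-coded target), and the names changes and boundaries are illustrative:

```python
import numpy as np
from sklearn import datasets, neighbors

iris = datasets.load_iris()
X = iris.data[:, 2].reshape(-1, 1)  # petal length (cm) as a single column
model = neighbors.KNeighborsClassifier(n_neighbors=3).fit(X, iris.target)

# Predict on the same fine grid of petal lengths as above
x_fitted = np.linspace(X.min(), X.max(), 500).reshape(-1, 1)
y_fitted = model.predict(x_fitted)

# Indices where the predicted class differs from the next grid point
changes = np.flatnonzero(y_fitted[1:] != y_fitted[:-1])

# Estimate each boundary as the midpoint between the two grid points
boundaries = (x_fitted[changes, 0] + x_fitted[changes + 1, 0]) / 2
print(boundaries.round(3))
```

The estimates are only as precise as the grid spacing, so they should agree with the ranges printed above to within about 0.01 cm.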

8 Visualise the Model

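We can re-plot the data with the decision boundaries added in. The two lines are drawn at 2.605 and 4.955 cm, midway between the ranges printed in the previous section: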

# Plot
ax = plt.axes()
ax.axvline(2.605, c='grey', ls='--')
ax.axvline(4.955, c='grey', ls='--')
ax.scatter(X, y, alpha=0.2)
ax.set_yticks([0, 1, 2])
ax.set_yticklabels([s.title() for s in dataset['target_names']])
ax.set_title('Classifying Species from Petal Length')
ax.set_ylabel('Species')
ax.set_xlabel('Petal Length (cm)')
plt.tight_layout()
plt.show()

This model appears to be very good at identifying setosa flowers, but there is a bit of overlap between versicolor and virginica flowers that could result in inaccurate predictions.
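One way to put numbers on that overlap is a confusion matrix built from cross-validated predictions. This goes slightly beyond the steps above, so treat it as an optional sketch: the choice of five folds is an assumption, and the integer-coded target is used to keep the snippet self-contained:

```python
from sklearn import datasets, neighbors
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict

iris = datasets.load_iris()
X = iris.data[:, 2].reshape(-1, 1)  # petal length (cm) as a single column
y = iris.target  # 0 = setosa, 1 = versicolor, 2 = virginica

model = neighbors.KNeighborsClassifier(n_neighbors=3)

# Each flower is predicted by a model that did not see it during fitting
y_cv = cross_val_predict(model, X, y, cv=5)

# Rows are the true species, columns are the predicted species
matrix = confusion_matrix(y, y_cv)
print(matrix)
```

Any non-zero values off the diagonal are misclassifications, and they should only appear between versicolor and virginica because setosa petals are clearly separated from the rest.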
