The code on this page uses the scikit-learn, NumPy and Matplotlib packages. These can be installed from the terminal with the following commands:
# "python3.12" corresponds to the version of Python you have installed and are using
$ python3.12 -m pip install scikit-learn
$ python3.12 -m pip install numpy
$ python3.12 -m pip install matplotlib
Once finished, import these packages into your Python script as follows:
from sklearn import datasets
from sklearn import neighbors
import numpy as np
from matplotlib import pyplot as plt
For this example we will use the Iris toy dataset from scikit-learn.
This dataset can be loaded using the load_iris() function from scikit-learn’s datasets sub-module. Using the as_frame=True option means that the data will be formatted as a pandas data frame:
# Load the dataset
dataset = datasets.load_iris(as_frame=True)
For this example we will only use one feature: petal length. Note
that the feature needs to be reshaped - using the reshape()
method - from a 1D array into a 2D array (even though there is only one
column) in order for the later functions to work.
# Separate out the feature
X = dataset['data']['petal length (cm)'].values.reshape(-1, 1)
# View the first five values
print(X[:5])
## [[1.4]
## [1.4]
## [1.3]
## [1.5]
## [1.4]]
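To see what the reshape does, here is a minimal sketch using three made-up petal lengths (the values and the variable names `lengths` and `column` are illustrative and not part of the dataset code above):

```python
import numpy as np

# A 1D array of three hypothetical petal lengths
lengths = np.array([1.4, 4.7, 6.0])
print(lengths.shape)  # (3,)

# -1 asks NumPy to infer the number of rows; 1 fixes the number of columns
column = lengths.reshape(-1, 1)
print(column.shape)  # (3, 1)
```

The values are unchanged; only the shape differs, with each value now sitting in its own row of a single-column 2D array.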
The target will be the species of Iris flower that corresponds to each specimen:
# Separate out the target
y = dataset['target']
By default the target values are the numbers 0, 1 and 2, but these can be ‘translated’ into their actual values (the species names ‘setosa’, ‘versicolor’ and ‘virginica’, in that order), which are stored under the 'target_names' key:
# Translate the target
y = y.apply(lambda x: dataset['target_names'][x])
# View the first five values
print(y.head())
## 0 setosa
## 1 setosa
## 2 setosa
## 3 setosa
## 4 setosa
## Name: target, dtype: object
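As a quick check of which number corresponds to which species, the short sketch below pairs each code with the name stored at the same position in 'target_names' (the `mapping` dictionary is an illustrative helper, not something the rest of the page relies on):

```python
from sklearn import datasets

dataset = datasets.load_iris(as_frame=True)

# Pair each numeric code with the species name stored at that index
mapping = {code: str(name) for code, name in enumerate(dataset['target_names'])}
for code, name in mapping.items():
    print(code, name)
# 0 setosa
# 1 versicolor
# 2 virginica
```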
Now let’s have a look at the raw data:
# Plot
ax = plt.axes()
ax.scatter(X, y, alpha=0.2)
ax.set_yticks([0, 1, 2])
ax.set_yticklabels([s.title() for s in dataset['target_names']])
ax.set_title('Classifying Species from Petal Length')
ax.set_ylabel('Species')
ax.set_xlabel('Petal Length (cm)')
plt.tight_layout()
plt.show()
As you can see, the petals of Iris virginica flowers tend to be longer than those of Iris versicolor flowers, with Iris setosa petals being the shortest. What we want to do now is build a model that will predict whether a flower is a virginica, versicolor or setosa based on its petal length alone. We will use a k-nearest neighbors (k-NN) algorithm to do so.
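There are three main steps in a k-NN algorithm:
1. Normalise each feature so that its values lie between 0 and 1: (x - minimum) / (maximum - minimum)
2. Calculate the distance between the new point and every existing point, for example the Euclidean distance, which is the square root of the sum over all features i of: squared_difference = (x_1[i] - x_2[i]) ** 2
3. Find the k existing points that are nearest to the new point and assign it the most common class among them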
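To make the mechanics concrete, here is a minimal from-scratch sketch of a k-NN classifier. It assumes the steps are min-max normalisation, Euclidean distance and a majority vote among the k nearest points. The function name `knn_predict` and the toy data are hypothetical, and this is not how scikit-learn implements it internally (note too that scikit-learn's classifier does not normalise the features for you).

```python
from collections import Counter

import numpy as np


def knn_predict(X_train, y_train, x_new, k=3):
    """Classify one new point with a simple k-NN vote (illustrative only)."""
    X_train = np.asarray(X_train, dtype=float)
    x_new = np.asarray(x_new, dtype=float)

    # Step 1: min-max normalise each feature using the training data
    minimum = X_train.min(axis=0)
    maximum = X_train.max(axis=0)
    span = np.where(maximum > minimum, maximum - minimum, 1.0)  # avoid dividing by 0
    X_scaled = (X_train - minimum) / span
    x_scaled = (x_new - minimum) / span

    # Step 2: Euclidean distance from the new point to every training point
    distances = np.sqrt(((X_scaled - x_scaled) ** 2).sum(axis=1))

    # Step 3: majority vote among the k nearest training points
    nearest = np.argsort(distances, kind='stable')[:k]
    votes = Counter(y_train[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical toy data: nine petal lengths (cm) and their species
X_toy = [[1.4], [1.5], [1.3], [4.5], [4.7], [4.4], [5.9], [6.1], [6.0]]
y_toy = ['setosa'] * 3 + ['versicolor'] * 3 + ['virginica'] * 3

print(knn_predict(X_toy, y_toy, [1.6]))  # setosa
print(knn_predict(X_toy, y_toy, [4.6]))  # versicolor
print(knn_predict(X_toy, y_toy, [6.2]))  # virginica
```

With only one feature the normalisation does not change which points are nearest, but it matters once several features on different scales are involved.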
The model will be built using the KNeighborsClassifier class from scikit-learn. See the documentation, the user guide and an example for more info. The value of k (i.e. the number of neighbours) will be 3:
# Create a model and fit it to the data
model = neighbors.KNeighborsClassifier(n_neighbors=3)
model.fit(X, y)
If we are given three new Iris flowers with petal lengths of 2, 3 and 5 centimeters, which species does the model predict them to be?
# Make a prediction
petal_length = [2, 3, 5]
y_pred = model.predict(np.array(petal_length).reshape(-1, 1))
print(y_pred)
## ['setosa' 'versicolor' 'virginica']
We need to be careful that we use data of the right shape: the predict() method expects a 2D input with one row per flower and one column per feature. In the above example we used the .reshape() method to achieve this, but we could alternatively have used a list-of-lists as our input:
# Make a prediction
petal_length = [[2], [3], [5]]
y_pred = model.predict(petal_length)
print(y_pred)
## ['setosa' 'versicolor' 'virginica']
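To see why the model makes these predictions, it can help to look at the neighbours it consulted. The sketch below rebuilds the model so that it runs on its own, then uses the classifier's kneighbors() and predict_proba() methods. The variable name `new_flowers` is illustrative, and no output is shown because the exact neighbours chosen can depend on how ties between equally distant points are resolved.

```python
import numpy as np
from sklearn import datasets, neighbors

# Rebuild the data and model from the steps above
dataset = datasets.load_iris(as_frame=True)
X = dataset['data']['petal length (cm)'].values.reshape(-1, 1)
y = dataset['target'].apply(lambda x: dataset['target_names'][x])
model = neighbors.KNeighborsClassifier(n_neighbors=3).fit(X, y)

new_flowers = np.array([[2], [3], [5]])

# Distances to, and row indices of, the 3 nearest training points
distances, indices = model.kneighbors(new_flowers)
for length, dist, idx in zip(new_flowers.ravel(), distances, indices):
    print(length, dist.round(2), y.iloc[idx].tolist())

# Share of the neighbours' votes that went to each class
proba = model.predict_proba(new_flowers)
print(model.classes_)
print(proba)
```

Each row of `proba` sums to 1; a row that is not all zeros and a single 1 indicates that the three neighbours disagreed about the species.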
As this is quite a simple example (there is only one feature) the decision boundaries are just numbers which represent the petal lengths above and below which a flower is classified differently:
# Find decision boundaries
x_fitted = np.linspace(X.min(), X.max(), 500).reshape(-1, 1)
y_fitted = model.predict(x_fitted)
for species in dataset['target_names']:
    x_species = x_fitted[y_fitted == species]
    minimum = round(x_species.min(), 2)
    maximum = round(x_species.max(), 2)
    print(f'{species.title()} flowers have petals between {minimum} and {maximum} cm in length')
## Setosa flowers have petals between 1.0 and 2.6 cm in length
## Versicolor flowers have petals between 2.61 and 4.95 cm in length
## Virginica flowers have petals between 4.96 and 6.9 cm in length
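Rather than reading the boundaries off the printed ranges, they could also be computed directly by finding where the predicted species changes along the grid. This is a sketch under the same setup as above; the names `changes` and `boundaries` are illustrative, and the number of boundaries found is not assumed in advance.

```python
import numpy as np
from sklearn import datasets, neighbors

# Rebuild the data and model from the steps above
dataset = datasets.load_iris(as_frame=True)
X = dataset['data']['petal length (cm)'].values.reshape(-1, 1)
y = dataset['target'].apply(lambda x: dataset['target_names'][x])
model = neighbors.KNeighborsClassifier(n_neighbors=3).fit(X, y)

# Predict the species on a fine grid of petal lengths
x_fitted = np.linspace(X.min(), X.max(), 500).reshape(-1, 1)
y_fitted = model.predict(x_fitted)

# Grid positions where the prediction differs from the next grid point
changes = np.where(y_fitted[:-1] != y_fitted[1:])[0]

# Place each boundary halfway between the two grid points either side of a change
boundaries = (x_fitted[changes, 0] + x_fitted[changes + 1, 0]) / 2
print(boundaries.round(3))
```

The resulting values could then be passed to ax.axvline() in a loop when plotting, instead of typing the numbers in by hand.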
We can re-plot the data with the decision boundaries added in:
# Plot
ax = plt.axes()
ax.axvline(2.605, c='grey', ls='--')
ax.axvline(4.955, c='grey', ls='--')
ax.scatter(X, y, alpha=0.2)
ax.set_yticks([0, 1, 2])
ax.set_yticklabels([s.title() for s in dataset['target_names']])
ax.set_title('Classifying Species from Petal Length')
ax.set_ylabel('Species')
ax.set_xlabel('Petal Length (cm)')
plt.tight_layout()
plt.show()
This model appears to be very good at identifying setosa flowers, but there is a bit of overlap between versicolor and virginica flowers that could result in inaccurate predictions.
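One way to put numbers on this is a cross-validated confusion matrix, which counts how often each species is predicted as each other species on data the model was not fitted to. The sketch below is one possible check rather than part of the original workflow; it assumes 5-fold cross-validation and uses the numeric target (0, 1, 2) for simplicity.

```python
from sklearn import datasets, neighbors
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict

dataset = datasets.load_iris(as_frame=True)
X = dataset['data']['petal length (cm)'].values.reshape(-1, 1)
y = dataset['target']  # 0 = setosa, 1 = versicolor, 2 = virginica

model = neighbors.KNeighborsClassifier(n_neighbors=3)

# Each flower is predicted by a model fitted without that flower's fold
y_cv = cross_val_predict(model, X, y, cv=5)

# Rows are the true species, columns are the predicted species
cm = confusion_matrix(y, y_cv)
print(cm)
```

Off-diagonal counts in the versicolor and virginica rows show how many flowers in the overlapping range were misclassified, while the setosa row should contain no errors because its petal lengths are well separated from the other two species.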