
1 The Basics

A bar plot is created on a set of axes by using Matplotlib’s ax.bar() method. The first argument passed to it is the x-positions of the bars and the second is their heights:

import matplotlib.pyplot as plt

x_positions = [0, 1, 2]
heights = [6, 12, 9]

ax = plt.axes()
ax.bar(x_positions, heights)
plt.show()
plt.close()
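To see what ax.bar() did with those two arguments, you can inspect the BarContainer it returns, which holds one Rectangle patch per bar. The sketch below uses plt.subplots() (a common alternative to plt.axes() that also returns the Figure) and shows that, by default, each bar is 0.8 units wide and centred on its x-position.

```python
import matplotlib.pyplot as plt

x_positions = [0, 1, 2]
heights = [6, 12, 9]

# plt.subplots() returns a Figure and an Axes in one call
fig, ax = plt.subplots()
bars = ax.bar(x_positions, heights)  # a BarContainer holding one Rectangle per bar

for rect in bars:
    # get_x() is the left edge, so the centre is the left edge plus half the width
    centre = rect.get_x() + rect.get_width() / 2
    print(f'centre={centre:.1f} width={rect.get_width():.1f} height={rect.get_height():g}')

plt.close(fig)
```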

1.1 Plot Title, Axis Labels, Bar Widths

Titles and labels are added using set_title(), set_xlabel() and set_ylabel(), while the width of the bars is set by the third argument of ax.bar(). Use \n to insert a line break in a title:

x_positions = [0, 1, 2]
heights = [6, 12, 9]
width = 0.5

ax = plt.axes()
ax.bar(x_positions, heights, width)
ax.set_title('How bar plots\nwork in Matplotlib')
ax.set_xlabel('The first argument of the ax.bar() function')
ax.set_ylabel('The second argument of the ax.bar() function')
plt.show()
plt.close()
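Because width is the third positional argument, it can also be passed as a keyword, which makes the call easier to read. It also accepts one value per bar, so bars can have different widths. A small sketch (the widths are arbitrary illustrative values):

```python
import matplotlib.pyplot as plt

x_positions = [0, 1, 2]
heights = [6, 12, 9]
widths = [0.2, 0.5, 0.8]  # illustrative values: one width per bar

fig, ax = plt.subplots()
bars = ax.bar(x_positions, heights, width=widths)
ax.set_title('Bars with\ndifferent widths')  # \n starts a second title line
plt.show()
plt.close(fig)
```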

1.2 Automatically Get the x-Positions

This is the same example as the previous one, except that NumPy’s arange() function is now used to generate the [0, 1, 2] x-positions automatically from the length of the heights variable (i.e. the number of bars to be plotted):

import numpy as np

heights = [6, 12, 9]
x_positions = np.arange(len(heights))
width = 0.5

ax = plt.axes()
ax.bar(x_positions, heights, width)
ax.set_title('How bar plots work in Matplotlib')
ax.set_xlabel('The first argument of the ax.bar() function')
ax.set_ylabel('The second argument of the ax.bar() function')
plt.show()
plt.close()
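Since np.arange() returns a NumPy array rather than a list, the positions can be shifted or scaled with ordinary arithmetic, for example to number the bars from 1 or to leave a gap between them. A minimal sketch:

```python
import numpy as np
import matplotlib.pyplot as plt

heights = [6, 12, 9]

x_positions = np.arange(len(heights))       # [0 1 2]
one_based = np.arange(1, len(heights) + 1)  # [1 2 3], bars numbered from 1
spaced_out = 2 * np.arange(len(heights))    # [0 2 4], extra space between bars

fig, ax = plt.subplots()
ax.bar(spaced_out, heights, 0.5)
plt.show()
plt.close(fig)
```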

1.3 Tick Labels

If the data’s independent variable is qualitative, the tick labels on the x-axis can be changed using:

  • The ax.set_xticks() command to define the positions of the ticks
  • The ax.set_xticklabels() command to define the text for the tick labels

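Use both of these commands, calling ax.set_xticks() first: setting the labels without first fixing the tick positions can leave the labels out of line with the bars, and recent versions of Matplotlib warn about it. Both functions accept lists (or arrays) as input.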

Further customisation of the tick labels can be done with the ax.tick_params() function, which has keyword arguments such as:

  • axis to choose the axis whose ticks you want to edit (‘x’, ‘y’ or ‘both’)
  • length to set the length of the ticks in points
  • rotation to change the angle, in degrees, at which the tick labels are written
  • labelsize to set the size of the tick labels. This argument accepts either a number (interpreted as a font size in points) or a string description (e.g. “large” or “small”).

import numpy as np

heights = [6, 12, 9]
x_positions = np.arange(len(heights))
width = 0.5

ax = plt.axes()
ax.bar(x_positions, heights, width)
ax.set_title('How bar plots work in Matplotlib')
ax.set_xlabel('The first argument of the ax.bar() function')
ax.set_ylabel('The second argument of the ax.bar() function')
ax.set_xticks(x_positions)
ax.set_xticklabels(['Data Point 1', 'Data Point 2', 'Data Point 3'])
ax.tick_params(axis='x', length=3, rotation=10, labelsize='small')
plt.show()
plt.close()
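In Matplotlib 3.5 and later, ax.set_xticks() also takes the labels as an optional second argument, so positions and labels can be set in one call. The sketch below assumes such a version; it also uses labelrotation, the rotation keyword listed in the tick_params() documentation.

```python
import numpy as np
import matplotlib.pyplot as plt

heights = [6, 12, 9]
x_positions = np.arange(len(heights))
labels = ['Data Point 1', 'Data Point 2', 'Data Point 3']

fig, ax = plt.subplots()
ax.bar(x_positions, heights, 0.5)
# Tick positions and tick labels in a single call (requires Matplotlib >= 3.5)
ax.set_xticks(x_positions, labels)
ax.tick_params(axis='x', length=3, labelrotation=10, labelsize='small')

fig.canvas.draw()  # render once so the tick label objects are fully populated
tick_texts = [t.get_text() for t in ax.get_xticklabels()]
first_rotation = ax.get_xticklabels()[0].get_rotation()
plt.show()
plt.close(fig)
```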

2 Plotting a Dataset

Data rarely arrives in a perfect format, with the x-positions and heights already stored in their own variables; usually some manipulation is needed before you can plot. Here’s an example using the iris dataset from the scikit-learn package (see the scikit-learn documentation on its toy datasets for more info):

from sklearn.datasets import load_iris

# Load the dataset
iris = load_iris()

This dataset contains data from three different species of the iris flower:

print(iris['target_names'])
## ['setosa' 'versicolor' 'virginica']

For this example we only want one of these three groups, so extract the data for the ‘versicolor’ species only:

# Get the number corresponding to the versicolor species
# (np.where() returns a tuple of index arrays, so take the first element of the first array)
species = np.where(iris['target_names'] == 'versicolor')[0][0]
# Keep only the rows whose 'target' value equals this number
# (comparing the whole column at once gives a boolean mask of True/False values)
data = iris['data'][iris['target'] == species]
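Selecting rows of a 2D NumPy array works either with a list of row indices or with a boolean mask of the same length as the array; both give the same rows. A small sketch with made-up toy values (not the real iris numbers) standing in for iris['data'] and iris['target']:

```python
import numpy as np

# Made-up toy values standing in for iris['data'] and iris['target']
toy_data = np.array([[7.0, 3.2], [5.1, 3.5], [6.4, 3.2], [4.9, 3.0]])
toy_target = np.array([1, 0, 1, 0])
wanted = 1  # the class code to keep

# Option 1: build a list of row indices, then index with it
toy_idx = [i for i, v in enumerate(toy_target) if v == wanted]
rows_by_index = toy_data[toy_idx]

# Option 2: compare the whole array at once to get a boolean mask
mask = toy_target == wanted
rows_by_mask = toy_data[mask]

print(toy_idx)                                      # [0, 2]
print(np.array_equal(rows_by_index, rows_by_mask))  # True
```

The mask version avoids a Python loop, which matters more as the dataset grows.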

The resulting data array has four columns: sepal length, sepal width, petal length and petal width (all in cm). The first 10 of its 50 rows are below:

print(data[:10])
## [[7.  3.2 4.7 1.4]
##  [6.4 3.2 4.5 1.5]
##  [6.9 3.1 4.9 1.5]
##  [5.5 2.3 4.  1.3]
##  [6.5 2.8 4.6 1.5]
##  [5.7 2.8 4.5 1.3]
##  [6.3 3.3 4.7 1.6]
##  [4.9 2.4 3.3 1. ]
##  [6.6 2.9 4.6 1.3]
##  [5.2 2.7 3.9 1.4]]

Extract only the petal lengths (the third column, i.e. the one at index 2):

# Extract the 'petal length (cm)' data
petal_length = data[:, 2]
print(petal_length)
## [4.7 4.5 4.9 4.  4.6 4.5 4.7 3.3 4.6 3.9 3.5 4.2 4.  4.7 3.6 4.4 4.5 4.1
##  4.5 3.9 4.8 4.  4.9 4.7 4.3 4.4 4.8 5.  4.5 3.5 3.8 3.7 3.9 5.1 4.5 4.5
##  4.7 4.4 4.1 4.  4.4 4.6 4.  3.3 4.2 4.2 4.2 4.3 3.  4.1]
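Hard-coding index 2 works, but looking the index up from the column names makes the intent explicit. A sketch, assuming the feature_names list below matches iris['feature_names'] and using the first three versicolor rows copied from the printout above:

```python
import numpy as np

# Column names as listed in iris['feature_names']
feature_names = ['sepal length (cm)', 'sepal width (cm)',
                 'petal length (cm)', 'petal width (cm)']
# The first three versicolor rows, copied from the printout above
sample_rows = np.array([[7.0, 3.2, 4.7, 1.4],
                        [6.4, 3.2, 4.5, 1.5],
                        [6.9, 3.1, 4.9, 1.5]])

col = feature_names.index('petal length (cm)')  # position of the column: 2
sample_petal_length = sample_rows[:, col]       # ':' keeps every row, col picks one column
print(col, sample_petal_length)
```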

Now that we have the data we want, there are two ways to plot it:

2.1 Plotting Values

The first way is to do what we’ve done before: plot the numbers as they are:

# Extract the data
heights = petal_length
x_positions = np.arange(len(heights))

# Plot
ax = plt.axes()
ax.bar(x_positions, heights)
ax.set_title('The lengths of the petals of 50 iris versicolor flowers')
ax.set_xlabel('')
ax.set_ylabel('Length (cm)')
plt.show()
plt.close()

This code works as expected, but it’s probably not the graph we want. It would be more useful to have a bar plot of how many flowers have each petal length. For that we need to plot the counts, i.e. the number of occurrences of each petal length:

2.2 Plotting Counts (aka Histograms)

The number of times each value appears in the ‘petal length’ data can be counted with NumPy’s unique() function using the return_counts option:

# Find the unique values in the dataset and count their occurrences
uniq, cnts = np.unique(petal_length, return_counts=True)

# Plot the counts: a histogram whose bin width (0.1 cm) matches the precision of the data
x_positions = uniq
heights = cnts
width = 0.1

# Plot
ax = plt.axes()
ax.bar(x_positions, heights, width)
ax.set_title('The lengths of the petals of 50 iris versicolor flowers')
ax.set_xlabel('Length (cm)')
ax.set_ylabel('Count')
plt.show()
plt.close()
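To see what np.unique() returns, here it is applied to the first seven petal lengths from the output above: the first array holds the sorted distinct values and the second how often each one occurs, so the counts always add up to the number of input values.

```python
import numpy as np

# The first seven petal lengths from the output above
sample = np.array([4.7, 4.5, 4.9, 4.0, 4.6, 4.5, 4.7])
sample_uniq, sample_cnts = np.unique(sample, return_counts=True)

print(sample_uniq)  # [4.  4.5 4.6 4.7 4.9]
print(sample_cnts)  # [1 2 1 2 1]
```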

3 Data Types

The first examples used lists, while the later ones (those using the iris dataset) used NumPy arrays. It makes no difference: ax.bar() accepts either. Data stored in Pandas data frames and series can be used too, as shown in this example where the iris dataset is first converted into a data frame, then filtered and counted before being plotted:

import pandas as pd

# Load the data set
iris = load_iris()
# Convert the array to a data frame
iris_df = pd.DataFrame(iris['data'], columns=iris['feature_names'])
# Add the species data as a column to the data frame
iris_df['Species'] = iris['target']
# Get the value that corresponds to the versicolor species
species = np.where(iris['target_names'] == 'versicolor')
# Convert tuple to integer
species = species[0][0]
# Filter to only have data from the versicolor plant
data = iris_df[iris_df['Species'] == species]
# Count the number of occurrences of each petal length
uniq, cnts = np.unique(data['petal length (cm)'], return_counts=True)

# Consolidate the data you want to plot
x_positions = uniq
heights = cnts
width = 0.1

# Plot
ax = plt.axes()
ax.bar(x_positions, heights, width)
ax.set_title('Bar plot using a Pandas data frame')
ax.set_xlabel('Petal Lengths (cm)')
ax.set_ylabel('Count')
plt.show()
plt.close()
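Within Pandas itself, a series’ value_counts() method is an alternative to np.unique(..., return_counts=True). A sketch using the first seven versicolor petal lengths as a stand-in for the full column; sort_index() is needed because value_counts() orders by frequency, not by value:

```python
import pandas as pd
import matplotlib.pyplot as plt

# The first seven versicolor petal lengths, as a Pandas series
petal_series = pd.Series([4.7, 4.5, 4.9, 4.0, 4.6, 4.5, 4.7], name='petal length (cm)')

# Count each distinct length, then sort by length rather than by frequency
counts = petal_series.value_counts().sort_index()

fig, ax = plt.subplots()
# The index holds the petal lengths, the values hold the counts
bars = ax.bar(counts.index, counts.values, 0.1)
ax.set_xlabel('Petal Lengths (cm)')
ax.set_ylabel('Count')
plt.show()
plt.close(fig)
```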

4 Using LaTeX and Annotations

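It’s possible to use LaTeX for the text in the plot, which brings with it the ability to use Greek letters, equations, special symbols and the like. This requires a working LaTeX installation on your computer (plus the textgreek package loaded in the preamble below). See the Matplotlib documentation on text rendering with LaTeX for more info.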

# Settings
x = 6  # Want figure to be A6: each A-size step halves the area of A0, hence the .5**(.5 * x) scaling
plt.rc('figure', figsize=[46.82 * .5**(.5 * x), 33.11 * .5**(.5 * x)])
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('text.latex', preamble=r'\usepackage{textgreek}')

# Plot
ax = plt.axes()
ax.bar(x_positions, heights, width)
ax.set_title(r'Using \LaTeX')
ax.set_xlabel('Petal Lengths (cm)')
ax.set_ylabel(r'Count (\textSigma)')
ax.annotate(r'We can create equations if we want:', (3, 6))
ax.annotate(r'$\Sigma = \frac{x}{2}$', (3, 5.5))
plt.show()
plt.close()
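If LaTeX is not installed, Matplotlib’s built-in mathtext can still render a subset of TeX math syntax written between dollar signs, with no external programs. A sketch (the bar values are illustrative only); it switches usetex off explicitly in case it was switched on earlier:

```python
import matplotlib.pyplot as plt

# Use Matplotlib's own mathtext renderer instead of an external LaTeX installation
plt.rc('text', usetex=False)

fig, ax = plt.subplots()
ax.bar([3.5, 4.0, 4.5], [2, 5, 7], 0.1)  # illustrative values
ax.set_title(r'Using mathtext: $\Sigma = \frac{x}{2}$')
ax.set_ylabel(r'Count ($\Sigma$)')
ax.annotate(r'$\alpha + \beta$', (3.5, 6))  # text placed at data coordinates (3.5, 6)
fig.canvas.draw()  # rendering fails here if a math string cannot be parsed
plt.show()
plt.close(fig)
```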

5 Using Colours

Colours are specified using the color keyword argument. A single colour (e.g. color='red') is applied to every bar; to colour the bars individually, provide a list of colours. If there are more bars than colours, the list wraps around: for example, if there are five bars and only three colours, the fourth and fifth bars get the same colours as the first and second, respectively.
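The wrap-around behaviour can be checked by reading back each bar’s fill colour. A sketch with five bars (illustrative heights) and three named colours:

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba

fig, ax = plt.subplots()
# Five bars but only three colours, so the colour list is reused from the start
bars = ax.bar(range(5), [3, 5, 2, 4, 6], color=['red', 'green', 'blue'])
facecolours = [bar.get_facecolor() for bar in bars]

def same_colour(rgba, name):
    """Compare an RGBA tuple with a named colour, allowing for rounding."""
    return all(abs(a - b) < 1e-9 for a, b in zip(rgba, to_rgba(name)))

print(same_colour(facecolours[3], 'red'))    # 4th bar reuses the 1st colour
print(same_colour(facecolours[4], 'green'))  # 5th bar reuses the 2nd colour
plt.close(fig)
```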

5.1 Defined Colours

# Defined colours
colours = ['red', 'green']

# Plot
ax = plt.axes()
ax.bar(x_positions, heights, width, color=colours)
ax.set_title('Customising the Colours')
ax.set_xlabel('Petal Lengths (cm)')
ax.set_ylabel('Count')
plt.show()
plt.close()

5.2 Colour Palette

# Colour Palette
ax = plt.axes()
ax.bar(x_positions, heights, width, color=['C0', 'C1', 'C2'])
ax.set_title('Customising the Colours')
ax.set_xlabel('Petal Lengths (cm)')
ax.set_ylabel('Count')
plt.show()
plt.close()
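The strings 'C0', 'C1', 'C2' and so on refer to the colours of Matplotlib’s current property cycle (by default, the ten ‘tab10’ colours), so they change if a different style is used. This sketch looks up the hex codes behind them:

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

# The colours of the current property cycle, as hex strings
cycle_colours = plt.rcParams['axes.prop_cycle'].by_key()['color']

for name in ['C0', 'C1', 'C2']:
    # 'CN' resolves to the Nth entry of the property cycle
    print(name, to_hex(name))
```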

5.3 Custom Colours

# Custom Colours
pink = '#FB4188'
green = '#87C94A'
blue = '#39C2F3'
yellow = '#FADB39'
lgrey = '#798287'
dgrey = '#43454C'
colours = [pink, green, blue, yellow, lgrey, dgrey]

# Plot
ax = plt.axes()
ax.bar(x_positions, heights, width, color=colours)
ax.set_title('Customising the Colours')
ax.set_xlabel('Petal Lengths (cm)')
ax.set_ylabel('Count')
plt.show()
plt.close()

6 Outlines

Outlines can be added to the bars using keyword arguments in the bar() call:

  • edgecolor sets the colour of the outline
  • linewidth sets the width of the outline (in points)

import numpy as np
import matplotlib.pyplot as plt

# Settings
x = 5  # Want figures to be A5
plt.rc('figure', figsize=[46.82 * .5**(.5 * x), 33.11 * .5**(.5 * x)])
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('text.latex', preamble=r'\usepackage{textgreek}')

# Custom Colours
skyblue = '#87CEEB'
maroon = '#800000'

# Create data
london_rainfall = {
    'Jan': 55.2, 'Feb': 40.9, 'Mar': 41.6, 'Apr': 43.7, 'May': 49.4, 'Jun': 45.1,
    'Jul': 44.5, 'Aug': 49.5, 'Sep': 49.1, 'Oct': 68.5, 'Nov': 59.0, 'Dec': 55.2,
}

# Plot
ax = plt.axes()
for i, rainfall in enumerate(london_rainfall.values()):
    # One bar per month: maroon fill with a thick sky-blue outline
    ax.bar(i, rainfall, 0.8, color=maroon, edgecolor=skyblue, linewidth=4)
# Set labels
ax.set_title('Average Precipitation in London', fontsize=18)
ax.set_xlabel('Month', fontsize=14)
ax.set_ylabel('Rainfall [mm]', fontsize=14)
# x-Axis details: month names at the bar centres, with the major tick marks hidden
ax.set_xticks(np.arange(len(london_rainfall)))
ax.set_xticklabels(list(london_rainfall.keys()))
ax.tick_params(axis='x', length=0)
# Minor ticks halfway between the bars mark the boundaries between months
xlocs = np.arange(len(london_rainfall)) - 0.5
ax.set_xticks(xlocs, minor=True)
ax.set_xlim(-0.5, len(london_rainfall) - 0.5)
# Finished
plt.show()
plt.close()
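Because ax.bar() accepts sequences, the loop above can also be replaced by a single call, and the outline settings can be read back from any bar. A sketch without the LaTeX settings, assuming Matplotlib 3.5 or later for the two-argument set_xticks():

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex

london_rainfall = {
    'Jan': 55.2, 'Feb': 40.9, 'Mar': 41.6, 'Apr': 43.7, 'May': 49.4, 'Jun': 45.1,
    'Jul': 44.5, 'Aug': 49.5, 'Sep': 49.1, 'Oct': 68.5, 'Nov': 59.0, 'Dec': 55.2,
}
months = list(london_rainfall.keys())
rainfall = list(london_rainfall.values())
x_positions = np.arange(len(months))

fig, ax = plt.subplots()
# All twelve bars in one call, each with a maroon fill and a sky-blue outline
bars = ax.bar(x_positions, rainfall, 0.8, color='#800000', edgecolor='#87CEEB', linewidth=4)
ax.set_xticks(x_positions, months)  # requires Matplotlib >= 3.5
ax.set_ylabel('Rainfall [mm]')

outline = to_hex(bars[0].get_edgecolor())
print(outline, bars[0].get_linewidth())
plt.show()
plt.close(fig)
```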

7 Save Plot

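Finally, use plt.savefig('name_of_plot.png') to save the plot to a file on your computer. Call it before plt.show(): plt.savefig() saves the current figure, and once the plot window has been closed there may no longer be one, so saving afterwards can produce a blank image.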
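A sketch of saving through the figure object, which avoids relying on the “current” figure. The file is written to a temporary folder only so the example doesn’t clutter your working directory; dpi (resolution) and bbox_inches='tight' (trims surplus whitespace) are optional savefig() settings:

```python
import os
import tempfile
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.bar([0, 1, 2], [6, 12, 9])

# Save before showing; fig.savefig() targets this figure explicitly
out_path = os.path.join(tempfile.mkdtemp(), 'name_of_plot.png')
fig.savefig(out_path, dpi=150, bbox_inches='tight')
plt.show()
plt.close(fig)
```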
