
If you have a set of data points that look like they’re increasing steadily, it might be useful to fit a straight line to them in order to describe the general shape of the data.

The line that you fit will be described by a linear function, that is, any function of the form:

\(y = mx + c\)

The important thing to notice is that a linear function is fully defined by two constants: the gradient \(m\) and the y-intercept \(c\).
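To make the "two constants" point concrete, here is a minimal sketch showing that any two distinct points are enough to pin down both constants. The function name `line_through` and the two sample points are made up for illustration only.

```python
def line_through(p1, p2):
    """Return (m, c) for the straight line y = m*x + c through two points."""
    (x1, y1), (x2, y2) = p1, p2
    if x1 == x2:
        raise ValueError('A vertical line cannot be written as y = mx + c')
    m = (y2 - y1) / (x2 - x1)  # Gradient: rise over run
    c = y1 - m * x1  # Intercept: the value of y at x = 0
    return m, c


# Two illustrative points that both lie on the line y = 2x + 11
m, c = line_through((1, 13), (4, 19))
print(f'y = {m:.1f}x + {c:.1f}')
## y = 2.0x + 11.0
```

Real data rarely lie exactly on a line, which is why the methods below estimate the two constants from all the points at once rather than from just two of them.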

This tutorial will use three methods for fitting linear functions, in increasing order of complexity of the Python command involved:

  1. SciPy’s linregress()
  2. NumPy’s polyfit()
  3. SciPy’s curve_fit()

…but first, we need some data to fit the curves to:

Example Data

For this tutorial, let’s create some fake data to use as an example. This should be a set of points whose y-values increase linearly with their x-values (if they don’t, our attempts to fit a straight line to them won’t work well!), with some random noise thrown in to mimic real-world data:

import numpy as np

# Set a seed for the random number generator so we get the same random numbers each time
np.random.seed(20210713)

# Create fake x-data
x = np.arange(1, 11)
# Create fake y-data
m = 2
c = 11
y = m * x + c
y = y + np.random.normal(scale=np.sqrt(np.mean(y)), size=len(x))  # Add noise

# Print
for xi, yi in zip(x, y):
    print(f'({xi}, {yi:4.1f})')
## (1,  3.8)
## (2,  9.0)
## (3, 28.9)
## (4, 22.7)
## (5, 26.2)
## (6, 15.8)
## (7, 26.4)
## (8, 29.4)
## (9, 40.5)
## (10, 32.7)

The random noise is added with NumPy’s random.normal() function, which draws random samples from a normal (Gaussian) distribution; here its standard deviation (the scale argument) is set to the square root of the mean y-value. Let’s take a look at what this example data looks like on a scatter plot:
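As a side note, NumPy also offers a newer random-number interface, np.random.default_rng(). The sketch below draws the same kind of Gaussian noise with it. The seed is reused from above, but the two interfaces produce different number streams, so the values will not match the ones printed earlier. The standard deviation of 5 and the sample size of 100,000 are arbitrary choices made for illustration.

```python
import numpy as np

# Create a seeded generator (a newer alternative to np.random.seed)
rng = np.random.default_rng(20210713)

# Draw samples from a normal distribution with mean 0 and standard deviation 5
scale = 5
noise = rng.normal(loc=0, scale=scale, size=100_000)

# With this many samples the statistics should be close to the requested values
print(f'Sample mean: {noise.mean():.2f}')
print(f'Sample standard deviation: {noise.std():.2f}')
```

A generator object keeps its state to itself, so seeding it does not affect random numbers drawn elsewhere in a program.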

import matplotlib.pyplot as plt

# Formatting options for plots
A = 6  # Want figure to be A6
plt.rc('figure', figsize=[46.82 * .5**(.5 * A), 33.11 * .5**(.5 * A)])
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('text.latex', preamble=r'\usepackage{textgreek}')

# Create a plot
ax = plt.axes()
ax.scatter(x, y, c='gray', marker='o', edgecolors='k', s=18)
ax.set_title('Example Data')
ax.set_xlabel(r'Independent Variable, $x$')
ax.set_ylabel(r'Dependent Variable, $y$')
ax.set_ylim(0)
ax.set_xlim(0)
plt.show()

Method 1: linregress

From SciPy’s stats sub-package we can import the linregress() function which performs linear regression. This is purpose-built for fitting straight lines: it takes the x- and y-data as inputs and returns the gradient and intercept of the line of best fit along with the correlation coefficient (r-value), the p-value and the standard error of the gradient:

from scipy import stats

# Linear regression model
gradient, intercept, r_value, p_value, slope_std_error = stats.linregress(x, y)
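Instead of unpacking the five values, the result can also be kept as a single object and its fields read by name. The sketch below assumes the example data as printed above, typed in by hand and therefore rounded to one decimal place, so the later decimals of the fit may differ slightly from the ones obtained with the unrounded data.

```python
import numpy as np
from scipy import stats

# The example data as printed above (rounded to one decimal place)
x = np.arange(1, 11)
y = np.array([3.8, 9.0, 28.9, 22.7, 26.2, 15.8, 26.4, 29.4, 40.5, 32.7])

# Keep the whole result object instead of unpacking it
result = stats.linregress(x, y)

print(f'Gradient:  {result.slope:.2f}')
print(f'Intercept: {result.intercept:.2f}')
print(f'r-value:   {result.rvalue:.2f}')
print(f'p-value:   {result.pvalue:.4f}')
print(f'Standard error of the gradient: {result.stderr:.2f}')
```

Reading the fields by name avoids having to remember the order in which the values are returned.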

The results can be used as follows:

import numpy as np

# Line of best fit
predict_y = gradient * x + intercept
# Associated error
pred_error = y - predict_y
degrees_of_freedom = len(x) - 2
residual_std_error = np.sqrt(np.sum(pred_error**2) / degrees_of_freedom)
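Two further quantities are often derived from the same outputs: the coefficient of determination and a confidence interval for the gradient. The following is a sketch only. It assumes the rounded example data printed earlier and a 95% confidence level, which is a conventional but arbitrary choice.

```python
import numpy as np
from scipy import stats

# The example data as printed above (rounded to one decimal place)
x = np.arange(1, 11)
y = np.array([3.8, 9.0, 28.9, 22.7, 26.2, 15.8, 26.4, 29.4, 40.5, 32.7])

gradient, intercept, r_value, p_value, slope_std_error = stats.linregress(x, y)

# Coefficient of determination: the fraction of the variance in y explained by the line
r_squared = r_value**2

# 95% confidence interval for the gradient, using the t-distribution
degrees_of_freedom = len(x) - 2
t_crit = stats.t.ppf(0.975, degrees_of_freedom)
ci_lower = gradient - t_crit * slope_std_error
ci_upper = gradient + t_crit * slope_std_error

print(f'R-squared: {r_squared:.2f}')
print(f'95% confidence interval for the gradient: {ci_lower:.2f} to {ci_upper:.2f}')
```

Because the data were generated from a known line, the interval can be checked against the true gradient of 2.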

Let’s take a look at the line of best fit superimposed over our raw data, with the true relationship (i.e. the function that we used to create our example data) added in as well:

import matplotlib.pyplot as plt

# Formatting options for plots
A = 6  # Want figure to be A6
plt.rc('figure', figsize=[46.82 * .5**(.5 * A), 33.11 * .5**(.5 * A)])
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('text.latex', preamble=r'\usepackage{textgreek}')

# Create a plot
ax = plt.axes()
ax.scatter(x, y, c='gray', marker='o', edgecolors='k', s=18, label='Raw data')
xlim = np.array(ax.get_xlim())
xlim[0] = 0
ax.plot(xlim, 2 * xlim + 11, 'k--', label='True underlying relationship')
ax.plot(xlim, gradient * xlim + intercept, 'k', label='Line of best fit')
ax.set_title('Example Data')
ax.set_xlabel(r'Independent Variable, $x$')
ax.set_ylabel(r'Dependent Variable, $y$')
ax.set_xlim(xlim)
ax.set_ylim(0)
ax.legend(fontsize=8)
plt.show()

Method 2: polyfit

The polyfit() command from NumPy is used to fit a polynomial function to data. Of course, a linear function IS a polynomial function (it’s a polynomial of degree 1) so we can go right ahead and run polyfit() with the following three arguments:

  • x - the x-values of your data
  • y - the y-values of your data
  • 1 - the degree of the polynomial you want to fit

The function will return p, an array of the polynomial coefficients of the fitted line ordered from the highest power to the lowest. In our case, these are the values \(m\) and \(c\) from the equation \(y = mx + c\), as shown below:

import numpy as np

# Fit a polynomial of degree 1 (a linear function) to the data
p = np.polyfit(x, y, 1)

# Extract the parameters
m = p[0]
c = p[1]

print(f'The equation of the line of best fit is y = {m:4.2f}x + {c:4.2f}')
## The equation of the line of best fit is y = 2.93x + 7.41
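Once the coefficients are known, NumPy’s polyval() function can evaluate the fitted line at any x-values without writing out the equation by hand. This is a minimal sketch assuming the rounded example data printed earlier; the new x-values are arbitrary and chosen for illustration.

```python
import numpy as np

# The example data as printed above (rounded to one decimal place)
x = np.arange(1, 11)
y = np.array([3.8, 9.0, 28.9, 22.7, 26.2, 15.8, 26.4, 29.4, 40.5, 32.7])

# Fit a polynomial of degree 1 (a linear function) to the data
p = np.polyfit(x, y, 1)

# Evaluate the fitted polynomial at new x-values; polyval expects the
# coefficients in the same highest-power-first order that polyfit returns
x_new = np.array([0, 5.5, 12])
y_new = np.polyval(p, x_new)

for xi, yi in zip(x_new, y_new):
    print(f'x = {xi:4.1f} -> predicted y = {yi:5.2f}')
```

The middle value, 5.5, is the mean of the x-data; a least-squares line always passes through the point made up of the mean x-value and the mean y-value, which makes for a quick sanity check.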

Here it is drawn out:

import matplotlib.pyplot as plt

# Formatting options for plots
A = 6  # Want figure to be A6
plt.rc('figure', figsize=[46.82 * .5**(.5 * A), 33.11 * .5**(.5 * A)])
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('text.latex', preamble=r'\usepackage{textgreek}')

# Create a plot
ax = plt.axes()
ax.scatter(x, y, c='gray', marker='o', edgecolors='k', s=18, label='Raw data')
xlim = np.array(ax.get_xlim())
xlim[0] = 0
ax.plot(xlim, 2 * xlim + 11, 'k--', label='True underlying relationship')
ax.plot(xlim, m * xlim + c, 'k', label='Line of best fit')
ax.set_title('Example Data')
ax.set_xlabel(r'Independent Variable, $x$')
ax.set_ylabel(r'Dependent Variable, $y$')
ax.set_xlim(xlim)
ax.set_ylim(0)
ax.legend(fontsize=8)
plt.show()

Method 3: curve_fit

From the SciPy package we can get the curve_fit() function. This is more general than polyfit() in that we can fit any type of function we like - not just polynomials - but it’s more complicated in that we sometimes need to provide an initial guess as to what the constants could be in order for it to work.

In order to fit the function \(y = mx + c\) to our data we need to pass it to curve_fit() as a Python function whose first argument is the independent variable (here a dummy variable \(t\)) and whose remaining arguments are the parameters to be fitted. A lambda function is a compact way of writing this. We can then use the curve_fit() function to fit it to the x- and y-data. Note that the curve_fit() function needs to be imported from the scipy.optimize sub-package:

from scipy.optimize import curve_fit

# Fit the function m * t + c to x and y
popt, pcov = curve_fit(lambda t, m, c: m * t + c, x, y)

Note that, unlike fits that involve taking the logarithm of the y-data, a linear fit like this one does not require values equal to zero to be removed from the y-data (there aren’t any of these in this example data in any case).
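The same fit can be written with a named function in place of the lambda, and an initial guess can be supplied through the p0 argument. In the sketch below the function name `linear` and the guess of 1 and 10 are illustrative choices, and the data are the rounded values printed earlier. A straight-line fit usually converges without a guess, so p0 is shown here only to demonstrate the mechanism.

```python
import numpy as np
from scipy.optimize import curve_fit


def linear(t, m, c):
    """Straight-line model: t is the independent variable, m and c are the parameters."""
    return m * t + c


# The example data as printed above (rounded to one decimal place)
x = np.arange(1, 11)
y = np.array([3.8, 9.0, 28.9, 22.7, 26.2, 15.8, 26.4, 29.4, 40.5, 32.7])

# p0 is the initial guess for (m, c); if it is omitted, curve_fit starts
# every parameter at 1
popt, pcov = curve_fit(linear, x, y, p0=[1, 10])

print(f'm = {popt[0]:.2f}, c = {popt[1]:.2f}')
```

A named function is easier to reuse than a lambda, for example when plotting the fitted line with linear(x, *popt).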

The first output, popt, is a NumPy array of the optimised values for the parameters which, in our case, are the constants \(m\) and \(c\) (the second output, pcov, is their estimated covariance matrix):

m = popt[0]
c = popt[1]
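The second output, pcov, can be turned into an uncertainty for each parameter by taking the square root of its diagonal. The following sketch assumes the rounded example data printed earlier and the default settings of curve_fit() (no measurement uncertainties supplied via sigma).

```python
import numpy as np
from scipy import stats
from scipy.optimize import curve_fit

# The example data as printed above (rounded to one decimal place)
x = np.arange(1, 11)
y = np.array([3.8, 9.0, 28.9, 22.7, 26.2, 15.8, 26.4, 29.4, 40.5, 32.7])

popt, pcov = curve_fit(lambda t, m, c: m * t + c, x, y)

# The diagonal of the covariance matrix holds the variance of each parameter,
# so its square root gives one-standard-deviation uncertainties
m_err, c_err = np.sqrt(np.diag(pcov))
print(f'm = {popt[0]:.2f} +/- {m_err:.2f}')
print(f'c = {popt[1]:.2f} +/- {c_err:.2f}')

# For comparison: the standard error of the gradient reported by linregress
print(f'linregress standard error of the gradient: {stats.linregress(x, y).stderr:.2f}')
```

Under these default settings the uncertainty on \(m\) should agree with the standard error of the gradient from Method 1.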

Let’s see what this looks like:

import matplotlib.pyplot as plt

# Formatting options for plots
A = 6  # Want figure to be A6
plt.rc('figure', figsize=[46.82 * .5**(.5 * A), 33.11 * .5**(.5 * A)])
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rc('text.latex', preamble=r'\usepackage{textgreek}')

# Create a plot
ax = plt.axes()
ax.scatter(x, y, c='gray', marker='o', edgecolors='k', s=18, label='Raw data')
xlim = np.array(ax.get_xlim())
xlim[0] = 0
ax.plot(xlim, 2 * xlim + 11, 'k--', label='True underlying relationship')
ax.plot(xlim, m * xlim + c, 'k', label='Line of best fit')
ax.set_title('Example Data')
ax.set_xlabel(r'Independent Variable, $x$')
ax.set_ylabel(r'Dependent Variable, $y$')
ax.set_xlim(xlim)
ax.set_ylim(0)
ax.legend(fontsize=8)
plt.show()

Comparison of Methods

The lines of best fit created by the three methods all look the same on the graphs, but are they actually the same? Let’s compare their output parameters:

# Method 1
p = stats.linregress(x, y)
m1 = p[0]
c1 = p[1]

# Method 2
p = np.polyfit(x, y, 1)
m2 = p[0]
c2 = p[1]

# Method 3
popt, pcov = curve_fit(lambda t, m, c: m * t + c, x, y)
m3 = popt[0]
c3 = popt[1]

print(f'Values of m: {m1:5.3f}, {m2:5.3f}, {m3:5.3f}. Values of c: {c1:5.3f}, {c2:5.3f}, {c3:5.3f}')
## Values of m: 2.931, 2.931, 2.931. Values of c: 7.406, 7.406, 7.406

Yep, they all give the same results (to three decimal places, at least)! For fitting a straight line, you can use any one of them.
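Rather than comparing printed values by eye, the agreement can also be checked programmatically. This is a sketch assuming the rounded example data printed earlier; the relative tolerance of 1e-5 is an arbitrary choice.

```python
import numpy as np
from scipy import stats
from scipy.optimize import curve_fit

# The example data as printed above (rounded to one decimal place)
x = np.arange(1, 11)
y = np.array([3.8, 9.0, 28.9, 22.7, 26.2, 15.8, 26.4, 29.4, 40.5, 32.7])

# Collect (m, c) from each of the three methods
fit1 = stats.linregress(x, y)
params1 = np.array([fit1.slope, fit1.intercept])
params2 = np.polyfit(x, y, 1)
params3, _ = curve_fit(lambda t, m, c: m * t + c, x, y)

# Compare to within a tolerance rather than testing for exact equality,
# because the three routines use different numerical algorithms
same = np.allclose(params1, params2, rtol=1e-5) and np.allclose(params1, params3, rtol=1e-5)
print(f'All three methods agree: {same}')
```

Comparing with a tolerance matters here because the three routines reach the answer by different numerical routes, so their results need not be identical down to the last floating-point digit.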
