
Basic plotting with Matplotlib

Matplotlib is a powerful library for making plots in Python. It can be used to create basic plots, but also has the ability to create very complex plots. A variety of useful tutorials can be found on the Matplotlib webpage, but here we will cover some of the basics.

The plotting functions within Matplotlib are found within the pyplot submodule, which is often imported using the alias:

from matplotlib import pyplot as plt

Here we will assume pyplot has been imported in this way. We will also be using NumPy arrays to hold the data being plotted, although lists can also be used.

Basic line plot

A basic line plot can be produced using the plot function of pyplot. At its most basic, plot takes a set of y-axis values:

from matplotlib import pyplot as plt
import numpy as np

# create a quadratic curve
data = np.arange(0, 10) ** 2
plt.plot(data)

# plt.show() will open the plot in its own window
plt.show()

Basic line plot

Note

The show function opens an interactive window showing the figure. This window allows you to zoom in on the figure and save it. By default, after calling show, your code (or the terminal) will not continue until you close the figure window.

As seen above, when only y-values are given, the x-axis uses the integer index values of the data, starting at 0.
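You can check which x-values Matplotlib used by looking at the line object that plot returns (plot returns a list of Line2D objects, one per data set). This short sketch reads the generated x-values back with the get_xdata method:

```python
from matplotlib import pyplot as plt
import numpy as np

data = np.arange(0, 10) ** 2

plt.figure()
# plot returns a list of Line2D objects; unpack the single line created here
(line,) = plt.plot(data)

# the x-values Matplotlib generated are the indices 0, 1, ..., 9
print(line.get_xdata())
```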

You can control the x-axis values to use by passing them as the first argument to plot, e.g.,:

x = np.linspace(-10, 10, 100)
y = 3.5 - 2.3 * x + 0.5 * x ** 2  # a more complex quadratic

# plot the data
plt.plot(x, y)
plt.show()

Basic line plot 2

Line styles and colours

The above plots default to a solid line in a blue colour. However, both the line style and colour can be controlled. You can also control whether to show markers at each of the data points.

The line style can be set with the linestyle keyword argument (a shorthand of ls can also be used), with the following values:

  • "-" or "solid" (this is the default)
  • "--" or "dashed"
  • "-." or "dashdot"
  • ":" or "dotted"
  • "None", " " or "" for no line
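Each short form in the list above is interchangeable with its long name. A quick sketch that checks this by reading the style back from each line with get_linestyle, which reports the short form ("None" for no line):

```python
from matplotlib import pyplot as plt
import numpy as np

x = np.linspace(0, 1, 10)
names = ["solid", "dashed", "dashdot", "dotted", ""]

plt.figure()
styles = {}
for name in names:
    # "ls" is the shorthand for the linestyle keyword argument
    (line,) = plt.plot(x, x, ls=name)
    styles[name] = line.get_linestyle()

print(styles)
```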
# show the different line styles
linestyles = ["-", "--", "-.", ":"]

x = np.linspace(-10, 10, 100)
for i, ls in enumerate(linestyles):
    y = 3.5 - 2.3 * (x + i) + 0.5 * (x + i) ** 2
    plt.plot(x, y, linestyle=ls, label=ls)

plt.legend()
plt.show()

Show line styles

The above example has also shown how to plot multiple data sets on top of each other: just call plot once for each data set before calling show. By default, each line is drawn in a different colour.
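The default colours come from Matplotlib's colour "property cycle", stored in rcParams["axes.prop_cycle"]: each new line on the same axes takes the next colour in the cycle. A sketch comparing the colours of three plotted lines with the first three colours of the cycle:

```python
from matplotlib import pyplot as plt
import numpy as np

# the list of colours Matplotlib steps through for successive lines
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

x = np.linspace(-10, 10, 100)
plt.figure()
used = []
for i in range(3):
    (line,) = plt.plot(x, (x + i) ** 2)
    used.append(line.get_color())  # colour assigned to this line

print(used)
print(cycle_colors[:3])
```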

The line colour can be set using the color keyword argument (note the US spelling). There is a wide range of named colours that can be used, as well as a set of base colours that can be given by a single-letter abbreviation:

import matplotlib.colors as mcolors

print(list(mcolors.BASE_COLORS))
['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']

Base colors

An example using a few of these colours is:

colors = ["r", "g", "b", "k"]

x = np.linspace(-10, 10, 100)
for i, c in enumerate(colors):
    y = 3.5 - 2.3 * (x + i) + 0.5 * (x + i) ** 2
    plt.plot(x, y, color=c)

plt.show()

Demonstration of using colors
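Beyond the single-letter base colours, the color keyword argument also accepts full colour names (e.g., "orange"), names from the Tableau palette (e.g., "tab:blue"), hex strings and RGB tuples. This sketch uses matplotlib.colors.to_hex to show what each specification resolves to:

```python
import matplotlib.colors as mcolors

# different ways of writing a colour, all accepted by the color keyword argument
specs = ["r", "k", "orange", "tab:blue", "#00ff00", (0.0, 0.0, 1.0)]

for spec in specs:
    # to_hex converts any valid colour specification to a hex string
    print(spec, "->", mcolors.to_hex(spec))
```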

Note

There are colourblind friendly colour palettes available, for example within the seaborn package, e.g.,:

from matplotlib import pyplot as plt
import seaborn

# set the default colour cycle to 6 colours from the colorblind palette
seaborn.set_palette("colorblind", 6)

# these colours can now be referred to by the names "C0" through to "C5"
for i in range(6):
    c = f"C{i}"  # set colour name
    plt.axvline(i, color=c, label=c)  # create vertical lines to show off colours
plt.xlim([-1, 6])
plt.legend(loc="upper right")
plt.show()

Demo of colourblind friendly palette

Seaborn is provided within the base Anaconda environment.

Marker styles

In the above examples only lines have been plotted, but markers can also be added at each data point. The marker style can be set with the marker keyword argument. The full range of marker styles is listed on the marker page, but here are a few:

  • "." - a point
  • "o" - a circle
  • "v" - a downwards pointing triangle
  • "*" - a star
  • "+" - a plus
  • "x" - a cross
  • "s" - a square

The marker size can be set with the markersize keyword argument, and the colour used to fill the marker can be set with the markerfacecolor keyword argument (use "None" for an unfilled marker):

markers = ["o", "v", "*", "s"]
markersizes = [4, 6, 8, 10]
markerfacecolors = ["None", "b", "r", "g"]
colors = ["k", "b", "r", "g"]
linestyles = ["-", "None", "-", "None"]  # set to show lines for alternate cases

x = np.linspace(-10, 10, 25)
for i in range(len(markers)):
    y = 3.5 - 2.3 * (x + 2 * i) + 0.5 * (x + 2 * i) ** 2
    plt.plot(
        x,
        y,
        marker=markers[i],
        markersize=markersizes[i],
        markerfacecolor=markerfacecolors[i],
        color=colors[i],
        linestyle=linestyles[i]
    )

plt.show()

Demonstration of using markers
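As well as the fill colour, the outline of each marker can be styled separately with the markeredgecolor and markeredgewidth keyword arguments. A sketch comparing hollow circles with outlined, filled squares:

```python
from matplotlib import pyplot as plt
import numpy as np

x = np.linspace(0, 10, 11)

plt.figure()
# hollow circles: no face colour, so only the black marker edge is drawn
(hollow,) = plt.plot(x, x, linestyle="None", marker="o", markersize=8,
                     markerfacecolor="none", markeredgecolor="k")
# red squares with a thicker black outline
(filled,) = plt.plot(x, x + 2, linestyle="None", marker="s", markersize=8,
                     markerfacecolor="r", markeredgecolor="k",
                     markeredgewidth=2)
plt.show()
```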

Axis labels

The above plots were missing important information. Plots should always have axis labels!

Labels can be added to the x- and y-axes using the xlabel and ylabel functions, e.g.,:

# some example data
position = [1.2, 5.6, 9.8, 17.9, 21.3, 24.3]
height = [4.5, 7.8, 10.3, 14.5, 12.2, 11.1]

plt.plot(position, height)
plt.xlabel("Position (m)")
plt.ylabel("Height (m)")
plt.show()

Demonstration of axes labels

The font and font size of the axis labels can be controlled with the fontfamily and fontsize keyword arguments (other font properties, such as fontweight and fontstyle, can be set in the same way), e.g.,

plt.plot(position, height)
plt.xlabel("Position (m)", fontfamily="monospace", fontsize=14)
# use a different font size for y-axis as an example
plt.ylabel("Height (m)", fontfamily="Times New Roman", fontsize=20)
plt.show()

For some reason the default font size in almost all plotting programs and packages (including Matplotlib) is far too small. When creating figures to include in a report or other document, you should always make sure that the text and axis labels, etc., are a similar size to the text in the main body of your report. It shouldn't be necessary to use a magnifying glass to read the axis labels!

Demonstration of axes label fonts
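The axis labels are not the only text to check: the numbers on the axis ticks also default to a small size. A sketch that matches the tick label size to the axis labels using the tick_params function in pyplot (the output file name is just an example):

```python
from matplotlib import pyplot as plt

position = [1.2, 5.6, 9.8, 17.9, 21.3, 24.3]
height = [4.5, 7.8, 10.3, 14.5, 12.2, 11.1]

plt.figure()
plt.plot(position, height)
plt.xlabel("Position (m)", fontsize=14)
plt.ylabel("Height (m)", fontsize=14)

# use the same size for the tick labels on both axes
plt.tick_params(axis="both", labelsize=14)

plt.savefig("labels_and_ticks.png", dpi=150)  # example file name
```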

If you want to use mathematical text, or Greek lettering, in axis labels you can use LaTeX-like commands enclosed in dollar signs, e.g.,

x = np.linspace(-10, 10, 100)
y = x ** 2  # a quadratic

plt.plot(x, y)
# use LaTeX math in labels
plt.xlabel(r"$\eta$", fontsize=16)
plt.ylabel(r"$f(\eta) = \eta^2$", fontsize=16)
plt.show()

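It is best to use raw strings (prefixed with r) for axis labels that contain LaTeX; otherwise each backslash needs to be escaped, and sequences such as \t (in \theta) or \n (in \nu) would be interpreted as special characters.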

Demonstration of axes label latex
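The difference between raw and normal strings can be seen in plain Python, without plotting anything. In this sketch the normal string "$\theta$" silently turns \t into a tab character, which would break the label:

```python
raw = r"$\theta$"       # raw string: the backslash is kept as-is
escaped = "$\\theta$"   # normal string: the backslash escaped by hand
broken = "$\theta$"     # normal string: "\t" becomes a tab character

print(raw == escaped)   # the raw and hand-escaped strings are identical
print("\t" in broken)   # the unescaped string contains a tab
```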

Legends

If you have multiple data sets on a single plot it is useful to differentiate them with different line colours, line styles, and/or marker styles. By default Matplotlib will use different colours for multiple data sets, but as shown above you can control what line colours are used.

You can add labels to each data set that you plot using the label keyword argument to plot. These "labels" can then be used in a legend using the legend function.

rng = np.random.default_rng()

x = np.arange(10)

# create three data sets
y1 = rng.normal(size=len(x))  # noise
y2 = rng.normal(size=len(x)) + 3 * x  # noise and line
y3 = rng.normal(size=len(x)) + 1.5 * x ** 2  # noise and quadratic

# plot data with labels
plt.plot(x, y1, color="b", label="Data 1")
plt.plot(x, y2, color="r", label="Data 2")
plt.plot(x, y3, color="g", label="Data 3")

plt.xlabel(r"$x$")
plt.ylabel(r"$y$")

# add legend
plt.legend()
plt.show()

Demonstration of legend

By default the location of the legend will be set to try and avoid overlapping with most of the data. The legend location can be set explicitly to a particular place using the loc keyword argument and a location string, e.g.,:

  • "upper right"
  • "upper left"
  • "lower right"
  • "lower left"
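A short sketch of placing the legend explicitly with loc (the legend text size can be set with fontsize in the same way as for axis labels). The legend entries come from the label given to each plot call, which can be checked from the returned Legend object:

```python
from matplotlib import pyplot as plt
import numpy as np

x = np.arange(10)

plt.figure()
plt.plot(x, x, color="b", label="Data 1")
plt.plot(x, 2 * x, color="r", label="Data 2")

# place the legend in the upper left rather than letting Matplotlib choose
leg = plt.legend(loc="upper left", fontsize=12)

# the legend text is taken from the label of each plotted line
print([t.get_text() for t in leg.get_texts()])
```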

Axis limits

Matplotlib will automatically try and determine the range of values shown in the x- and y-axes. However, you can manually set the axis ranges to whatever you require using the xlim and ylim functions in pyplot. These functions take in a list, or tuple, containing two values: the lower and upper ends of the range. For example:

x = np.linspace(-10, 10, 100)
y = 3.5 - 2.3 * x + 0.5 * x ** 2  # a more complex quadratic

# plot the data
plt.plot(x, y)

# zoom in on an x-range from -5 to 5
plt.xlim([-5, 5])

# zoom in on a y-range from -5 to 25
plt.ylim([-5, 25])

plt.xlabel(r"$x$")
plt.ylabel(r"$y$")

plt.show()

Demonstration of axis limits
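The xlim and ylim functions accept a few other forms too: the two limits can be passed as separate values, a single end can be set with the left/right (x) or bottom/top (y) keyword arguments, and calling either function with no arguments returns the current (lower, upper) range. A sketch:

```python
from matplotlib import pyplot as plt
import numpy as np

x = np.linspace(-10, 10, 100)
y = 3.5 - 2.3 * x + 0.5 * x ** 2

plt.figure()
plt.plot(x, y)

plt.xlim(-5, 5)     # two separate values work as well as a list
plt.ylim(bottom=0)  # only set the lower y limit; the upper one is unchanged

# with no arguments, xlim and ylim return the current limits
print(plt.xlim())
lower, upper = plt.ylim()
print(lower, upper)
```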

Grid lines

Sometimes it is useful to add a background grid to a figure to aid visualisation. This can be added using the grid function in pyplot. By default the grid lines are applied to both the x- and y-axes, but this can be specified with the axis keyword argument, e.g., axis="x" to just add a grid on the x-axis.

x = np.linspace(-10, 10, 100)
y = 3.5 - 2.3 * x + 0.5 * x ** 2  # a more complex quadratic

# plot the data
plt.plot(x, y)

plt.xlabel(r"$x$")
plt.ylabel(r"$y$")

# turn on the grid
plt.grid()

plt.show()

Demonstration of grid
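A sketch of a grid on only one axis: with axis="y" the grid lines are drawn at the y-axis tick positions (i.e., horizontal lines). Extra keyword arguments such as linestyle and alpha control how the grid lines look (the output file name is just an example):

```python
from matplotlib import pyplot as plt
import numpy as np

x = np.linspace(-10, 10, 100)
y = 3.5 - 2.3 * x + 0.5 * x ** 2

plt.figure()
plt.plot(x, y)

# horizontal, dashed, semi-transparent grid lines at the y-axis ticks only
plt.grid(axis="y", linestyle="--", alpha=0.5)

plt.savefig("grid_y.png")  # example file name
```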

Logarithmic axes

If you have (positive) data that spans over many orders of magnitude it is often useful to plot the logarithm of the data. There are three pyplot functions that enable you to do this:

  • loglog - plot both the x- and y-axes on a logarithmic (base-10 by default) scale;
  • semilogx - plot the x-axis on a logarithmic scale, but the y-axis on a linear scale;
  • semilogy - plot the y-axis on a logarithmic scale, but the x-axis on a linear scale.

For example:

x = np.logspace(-5, 5, 100)  # linearly spaced in base-10 log space
y = 2.5 * x ** 4.5

# plot the data in log-log space
plt.loglog(x, y)

plt.xlabel(r"$x$")
plt.ylabel(r"$y$")

# turn on the grid
plt.grid()

plt.show()

Demonstration of loglog plot
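For data that grows exponentially, semilogy is often the better choice, since an exponential appears as a straight line. The same effect can be had from an ordinary plot by switching the axis scale with the yscale function (xscale does the same for the x-axis). A sketch showing both:

```python
from matplotlib import pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
y = np.exp(2.0 * x)  # grows by a factor of about 5e8 over this range

# logarithmic y-axis, linear x-axis
plt.figure()
plt.semilogy(x, y)
ax1 = plt.gca()

# equivalent: an ordinary line plot with the y-axis switched to log scale
plt.figure()
plt.plot(x, y)
plt.yscale("log")
ax2 = plt.gca()

print(ax1.get_xscale(), ax1.get_yscale(), ax2.get_yscale())
```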

Basic scatter plot

The plot function will produce a line plot, but by setting the line style to be "None" and explicitly giving a marker style it can be used to produce a scatter plot, i.e., a plot of individual points, e.g.,

# produce some random points
x = np.random.randn(200)
y = np.random.randn(200)

# plot data points using circle markers
plt.plot(x, y, linestyle="None", marker="o")
plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.show()

Demonstration of scatter plot

An alternative is to use the scatter function in pyplot. This lets you show additional information through the size and/or colour of each point, set using the s and c keyword arguments, respectively. For example,

# produce some random points
x = np.random.randn(200)
y = np.random.randn(200)

# third dimension
z = 10 / np.sqrt(x ** 2 + y ** 2)

# plot data using scatter, with size and colour representing "z" data
plt.scatter(x, y, s=z, c=z)
plt.colorbar()
plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.show()

Demonstration of scatter plot

In the above example the colorbar function has been used to add a colour bar on the right-hand side showing how the colours map to the z values.
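The colour bar can be labelled like an axis using the set_label method of the Colorbar object that colorbar returns. In this sketch s sets the marker area (in points squared), c is mapped to colours through the chosen colormap ("viridis" here is just an example), and a seeded random generator is used so the points are reproducible:

```python
from matplotlib import pyplot as plt
import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=200)
y = rng.normal(size=200)
z = 10 / np.sqrt(x ** 2 + y ** 2)

plt.figure()
plt.scatter(x, y, s=z, c=z, cmap="viridis")

# keep the Colorbar object so it can be labelled
cbar = plt.colorbar()
cbar.set_label(r"$z$")

plt.xlabel(r"$x$")
plt.ylabel(r"$y$")
plt.show()
```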

Note

You can use different plotting functions on the same plot, e.g., the plot function and the scatter function:

x = np.random.randn(200)
y = np.random.randn(200)
z = 10 / np.sqrt(x ** 2 + y ** 2)

# plot data using scatter, with colour representing "z" data
plt.scatter(x, y, c=z)

# overplot a line plot
linex = np.linspace(-4, 4, 10)
liney = 1.5 + linex + 0.5
plt.plot(linex, liney, color="r")

plt.show()

Demonstration of scatter plot

Basic histogram

Sometimes you need to count the number of data points that fall within a set of value ranges. This is called "binning", i.e., a count of the data in each "bin" or interval. A plot of the binned data is called a histogram, and this can be made using the hist function in pyplot.

For example, if we had measured the speed of a set of particles we could look at the distribution of speeds using a histogram:

# create a set of hydrogen atom speeds at ~room temperature
from scipy.stats import maxwell
m = 1.67e-27  # proton mass (kg)
kb = 1.38e-23  # Boltzmann constant (m^2 kg s^-2 K^-1)
T = 300  # temperature (K)

natoms = 100000  # number of atoms
speeds = maxwell.rvs(scale=np.sqrt(kb * T / m), size=natoms)

# plot distribution of atom speeds
plt.hist(speeds, bins=100)
plt.xlabel("Speed (m/s)")
plt.ylabel("Counts")
plt.show()

Demonstration of histogram

In the above example, the number of bins has been set to 100 using bins=100, and the hist function has created 100 equal-width bins between the smallest and largest values (if bins is not given it defaults to 10 bins). The bins keyword argument can also be set to an array of bin edges, which do not have to be equally spaced.
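A small sketch with bin edges of unequal width. It also uses the three values that hist returns: the counts in each bin, the bin edges, and the drawn bars. Each bin includes its left edge, and the last bin also includes its right edge:

```python
from matplotlib import pyplot as plt

data = [0.5, 1.5, 1.7, 2.5, 7.0]

# bins of unequal width: [0, 1), [1, 2), [2, 5) and [5, 10]
edges = [0, 1, 2, 5, 10]

plt.figure()
# hist returns the counts, the bin edges and the bar objects
counts, bin_edges, patches = plt.hist(data, bins=edges)

print(counts)  # [1. 2. 1. 1.]
```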

To create a normalised histogram, i.e., one where the y-axis shows a probability density rather than counts, so that the total area of the bars is equal to one, you can set the density keyword argument to True.

An example showing two normalised histograms, one filled with colour, and the other unfilled, is shown below:

# create a set of hydrogen atom speeds at ~room temperature
m = 1.67e-27  # proton mass (kg)
kb = 1.38e-23  # Boltzmann constant (m^2 kg s^-2 K^-1)
T = 300  # temperature (K)

natoms = 100000  # number of atoms
Hspeeds = maxwell.rvs(scale=np.sqrt(kb * T / m), size=natoms)

# create a set of helium atom speeds at room temperature (mass of about 4 protons)
Hespeeds = maxwell.rvs(scale=np.sqrt(kb * T / (4 * m)), size=natoms)

# plot probability density of speed distributions
plt.hist(Hspeeds, bins=50, density=True, color="b",
         histtype="stepfilled", alpha=0.5, label="Hydrogen")
plt.hist(Hespeeds, bins=50, density=True, color="r",
         histtype="step", label="Helium")
plt.xlabel("Speed (m/s)")
plt.ylabel("Probability density")
plt.legend()
plt.show()

Demonstration of histogram
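The normalisation can be checked directly from the values hist returns: with density=True, the bar heights multiplied by the bar widths sum to one. A sketch using normally distributed samples from a seeded random generator:

```python
from matplotlib import pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
samples = rng.normal(size=10000)

plt.figure()
counts, edges, _ = plt.hist(samples, bins=30, density=True)

# total area of the bars: height x width, summed over all bins
area = np.sum(counts * np.diff(edges))
print(area)
```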

The hist function can also be used to plot the cumulative distribution, e.g.,

# set bins edges
bins = np.linspace(
    min([Hspeeds.min(), Hespeeds.min()]),
    max([Hspeeds.max(), Hespeeds.max()]),
    100
)

plt.hist(Hspeeds, bins=bins, cumulative=True, density=True, color="b",
         histtype="step", label="Hydrogen")
plt.hist(Hespeeds, bins=bins, cumulative=True, density=True, color="r",
         histtype="step", label="Helium")
plt.xlabel("Speed (m/s)")
plt.ylabel("Cumulative probability")
plt.legend(loc="lower right")

# set some limits
plt.xlim([bins[0], bins[-1]])
plt.ylim([0, 1])

plt.show()

Demonstration of histogram

Saving figures

If using the show function to display figures the easiest way to save them is using the disk icon 💾 in the display window as shown below.

Display window

This will open up a file browser where you can type in the output file name. The file extension that you give determines what image format the plot is saved as. Some standard image formats are:

  • PNG - use the .png extension, e.g., "myimage.png"
  • JPEG - use the .jpg extension, e.g., "myimage.jpg"
  • PDF - use the .pdf extension, e.g., "myimage.pdf"
  • EPS - use the .eps extension, e.g., "myimage.eps"

For publication quality plots it is often good to save images as PDF files.

Displaying plots using show is useful when using a Python terminal or testing your code. However, within a script you generally do not want the plot to be displayed; you just want it created and saved. As mentioned above, by default the code after a call to show will not be executed until the figure window has been closed, which means you have to intervene manually while the script is running.

Plots can be saved within the code using the savefig function in pyplot. For example, to save a plot to a PDF file, you could use:

position = [1.2, 5.6, 9.8, 17.9, 21.3, 24.3]
height = [4.5, 7.8, 10.3, 14.5, 12.2, 11.1]

plt.plot(position, height) 
plt.xlabel("Position (m)") 
plt.ylabel("Height (m)") 

# save the figure
plt.savefig("myplot.pdf")
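As with the save button, savefig chooses the image format from the file extension, so the same figure can be written in several formats with a simple loop. A sketch saving PNG, PDF and SVG versions (SVG is another vector format, like PDF):

```python
from matplotlib import pyplot as plt

position = [1.2, 5.6, 9.8, 17.9, 21.3, 24.3]
height = [4.5, 7.8, 10.3, 14.5, 12.2, 11.1]

plt.figure()
plt.plot(position, height)
plt.xlabel("Position (m)")
plt.ylabel("Height (m)")

# the format is taken from each file's extension
for ext in ["png", "pdf", "svg"]:
    plt.savefig(f"myplot.{ext}")
```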

If saving a plot to a PNG or JPEG format, you can set the resolution of the output image using the dpi ("dots per inch") keyword argument. Generally, dpi=150 is good enough for most purposes.
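The pixel size of a saved PNG is the figure size (in inches) multiplied by the dpi. This sketch sets the figure size with the figsize keyword argument of the figure function, saves the image, and reads it back with imread to check its dimensions (the file name is just an example):

```python
from matplotlib import pyplot as plt

plt.figure(figsize=(6, 4))  # width, height in inches
plt.plot([0, 1], [0, 1])
plt.savefig("dpi_demo.png", dpi=150)  # example file name

# read the image back: (height, width, RGBA channels) in pixels
img = plt.imread("dpi_demo.png")
print(img.shape)  # (600, 900, 4), i.e., 4 x 150 by 6 x 150 pixels
```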

Tight layout

Often figures will be produced with excessive amounts of whitespace around the edges. Matplotlib can try and optimise the spacing to remove some of this excess by using the tight_layout function in pyplot:

position = [1.2, 5.6, 9.8, 17.9, 21.3, 24.3]
height = [4.5, 7.8, 10.3, 14.5, 12.2, 11.1]

plt.plot(position, height) 
plt.xlabel("Position (m)") 
plt.ylabel("Height (m)")

# use tight layout
plt.tight_layout()
plt.savefig("myimage.png", dpi=150)

The image below shows the plot produced with the above code both with and without tight_layout on the left and right, respectively.

Tight layout No tight layout

Customisation

The above examples have only scratched the surface of what Matplotlib can do. Here we will look at a few options that let you have a bit more control over the look of the figures.

Using axes and figures objects

In all the above examples we have just used pyplot functions to make a single figure. This figure has a standard default size and shape and we can only control one figure at a time.

A more useful way to create a figure is using the subplots function in pyplot.

Note

There is also the simpler figure function. subplots can be used in the same way as figure, but it also creates the Axes and allows multiple plots, as discussed below.

You can create a single figure with subplots using:

fig, ax = plt.subplots()

When created the figure will not contain anything, but the function returns a Figure object (in the variable fig here) and an Axes object (in the variable ax here).

Data can be added to the figure through the Axes object which has methods for all of the pyplot plotting functions discussed above (e.g., plot, scatter and hist):

# add some data to the axes
x = np.linspace(-10, 10, 100)
y = 3.5 - 2.3 * x + 0.5 * x ** 2  # a more complex quadratic

# plot the data
ax.plot(x, y)

# add axis labels
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")

# set limits
ax.set_xlim([-7, 7])
ax.set_ylim([-5, 30])

# save the plot
fig.tight_layout()
fig.savefig("myplot.png", dpi=150)

Above it can be seen that some things have been done differently. To set the axis labels the set_xlabel and set_ylabel methods of the Axes class have had to be used rather than the pyplot.xlabel and pyplot.ylabel functions. The axes limits have also been set using set_xlim and set_ylim methods of the Axes class. Finally, the tight_layout and savefig methods of the Figure class have been used.

By having the figures as variables we can create multiple figures within the same code.

We can also make use of the figsize keyword argument to set the size of the figures, which takes a tuple containing the width and height of the figure (in inches). For example we could create two plots, with different aspect ratios, using:

# a narrow plot
fig1, ax1 = plt.subplots(figsize=(4, 10))

# plot something on this axes
ax1.hist(np.random.rand(1000))

# a wide plot
fig2, ax2 = plt.subplots(figsize=(10, 4))

ax2.hist(np.random.randn(1000), bins=20)

# save the figures
fig1.savefig("narrow.png")
fig2.savefig("wide.png")
Tall and narrow Short and wide
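Because each Figure and Axes is a separate object, plotting on one does not affect the other. This sketch checks the size of each figure with get_size_inches and then closes both figures with plt.close, which frees their memory once they have been saved (useful when a script creates many figures):

```python
from matplotlib import pyplot as plt
import numpy as np

fig1, ax1 = plt.subplots(figsize=(4, 10))
fig2, ax2 = plt.subplots(figsize=(10, 4))

# each figure keeps its own size (width, height in inches)
print(fig1.get_size_inches(), fig2.get_size_inches())

# a histogram drawn on ax1 adds its bars to ax1 only
ax1.hist(np.random.rand(1000))
n_bars = len(ax1.patches)

# close the figures when they are no longer needed
plt.close(fig1)
plt.close(fig2)
```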

Multiple plots in a figure

Sometimes it is useful to plot related data on the same figure, but in a separate plot. Using subplots multiple plots can be added to the same figure. For example, we can create one plot above another, and have them share the same x-axis (i.e., the same range and positioning of x-axis label):

fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True)

# create some data
x = np.linspace(0.0, 2.0 * np.pi, 100)

y1 = np.sin(x)
y2 = np.cos(x)

# axs is now an array with the shape nrows x ncols
axs[0].plot(x, y1, color="blue")
axs[1].plot(x, y2, color="red")

# set axes labels
axs[1].set_xlabel(r"$\phi$")  # only set x-axis label on bottom plot
axs[0].set_ylabel(r"$\sin(\phi)$")
axs[1].set_ylabel(r"$\cos(\phi)$")

fig.tight_layout()
fig.savefig("trigfuncs.png")

Multiple plots

Similarly, side-by-side plots can share the same y-axis if required.
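A sketch of two side-by-side plots sharing a y-axis. With sharey=True both plots use the same y-range, chosen to fit the data in either plot, so only the left plot needs a y-axis label:

```python
from matplotlib import pyplot as plt
import numpy as np

fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(8, 4))

x = np.linspace(0.0, 2.0 * np.pi, 100)
axs[0].plot(x, np.sin(x), color="blue")
axs[1].plot(x, 2.0 * np.cos(x), color="red")  # larger amplitude

axs[0].set_ylabel(r"$y$")  # only the left plot needs a y-axis label
for ax in axs:
    ax.set_xlabel(r"$\phi$")

fig.tight_layout()
```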

Note

More complex plot grids can be defined using the gridspec module.
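As a sketch of one such layout, a Figure's add_gridspec method creates a grid, and slices of that grid can be passed to add_subplot so that a plot spans several cells. Here the top plot spans both columns of a 2 x 2 grid:

```python
from matplotlib import pyplot as plt
import numpy as np

fig = plt.figure(figsize=(8, 6))
gs = fig.add_gridspec(nrows=2, ncols=2)

ax_top = fig.add_subplot(gs[0, :])    # first row, spanning both columns
ax_left = fig.add_subplot(gs[1, 0])   # second row, first column
ax_right = fig.add_subplot(gs[1, 1])  # second row, second column

x = np.linspace(0.0, 2.0 * np.pi, 100)
ax_top.plot(x, np.sin(x))
ax_left.plot(x, np.cos(x))
ax_right.hist(np.random.randn(1000), bins=20)

fig.tight_layout()
```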

Setting up default parameters

There is a wide range of default settings used within Matplotlib when creating and saving a plot. These can be found through the rcParams object:

from matplotlib import rcParams
# show rcParams (truncated here)
print(rcParams)
_internal.classic_mode: False
agg.path.chunksize: 0
animation.avconv_args: []
animation.avconv_path: avconv
animation.bitrate: -1
animation.codec: h264
animation.convert_args: []
animation.convert_path: convert
animation.embed_limit: 20.0
animation.ffmpeg_args: []
animation.ffmpeg_path: ffmpeg
animation.frame_format: png
animation.html: none
animation.html_args: []
animation.writer: ffmpeg
axes.autolimit_mode: data
axes.axisbelow: line
axes.edgecolor: black
axes.facecolor: white
axes.formatter.limits: [-5, 6]
...

There are multiple methods to manually adjust these defaults, but the one we will show is to directly edit the rcParams object. For example, we can change the default font and font size, and the default figure size, with:

from matplotlib import rcParams

rcParams["font.family"] = "serif"  # change the default to a serif font
rcParams["font.serif"] = "Times New Roman"  # change the default serif font to Times New Roman
rcParams["font.size"] = 14  # change the default font size
rcParams["figure.figsize"] = [9.7, 6]  # change the default figure size
rcParams["figure.autolayout"] = True  # automatically apply tight_layout

from matplotlib import pyplot as plt

fig, ax = plt.subplots()

x = np.linspace(-10, 10, 100)
y = 3.5 - 2.3 * x + 0.5 * x ** 2  # a more complex quadratic

# plot the data
ax.plot(x, y)

# add axis labels
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$y$")

fig.savefig("myfigure.png")

Configuration parameters

Note

Any changes you make to rcParams in a script or Python terminal session will only be applied to that script/session. If you start a new session, or run a new script, the original defaults will be restored. To keep the defaults across multiple runs/scripts you can define your own custom configuration (matplotlibrc) file or style file.
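If you only want different defaults for some of the figures within a script, the rc_context function in pyplot applies rcParams changes temporarily, inside a with-block, and restores the previous values afterwards. A sketch (the file name is just an example):

```python
from matplotlib import pyplot as plt

default_size = plt.rcParams["font.size"]

# these settings only apply to figures created inside the with-block
with plt.rc_context({"font.size": 20, "figure.figsize": [9.7, 6]}):
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    ax.set_xlabel(r"$x$")
    fig.savefig("big_text.png")  # example file name
    inside_size = plt.rcParams["font.size"]

# outside the block the previous defaults are back in place
after_size = plt.rcParams["font.size"]
print(inside_size, after_size)
```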