5 Visualization

The main reference for this Chapter is [1].

5.1 matplotlib.pyplot

matplotlib is a mature and widely used plotting library whose pyplot interface is modeled on MATLAB. In this book we mostly use the pyplot module of matplotlib, imported with the following convention:

import matplotlib.pyplot as plt

5.1.1 matplotlib interface

matplotlib has two major application interfaces, or styles of using the library:

  • An explicit Axes interface that uses methods on a Figure or Axes object to create other Artists and builds a visualization step by step. You may think of a Figure as a canvas and of each Axes as one plot on that canvas; a canvas can hold one or more plots. This is also called the object-oriented interface.
  • An implicit pyplot interface that keeps track of the last Figure and Axes created, and adds Artists to the object it thinks the user wants.

Here is an example of an explicit interface.

fig = plt.figure()
ax = fig.subplots()
ax.plot([1, 2, 3, 4], [0, 0.5, 1, 0.2])

Here is an example of an implicit interface.

plt.plot([1, 2, 3, 4], [0, 0.5, 1, 0.2])

Note

If the plot is not shown, you may need to call plt.show() to force it to render. Whether plt.show() works depends on the matplotlib backend in use, and configuring backends can sometimes be complicated.
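If no window appears in a plain Python script (notebooks usually render figures inline instead), one common workaround is to avoid plt.show() entirely: select the non-interactive Agg backend before importing pyplot and save the figure to a file. This is a minimal sketch; the file name plot.png is a placeholder.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: renders to files only, never opens a window
import matplotlib.pyplot as plt

plt.plot([1, 2, 3, 4], [0, 0.5, 1, 0.2])
plt.savefig("plot.png")  # placeholder file name, written to the current working directory
print(matplotlib.get_backend())  # the backend actually in use
```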

The purpose of explicitly using fig and ax is to have more control over the configuration. The first important configuration is the layout of subplots, which can be created with the following methods:

  • plt.subplot()
  • plt.subplots() (or fig.subplots())
  • fig.add_subplot()

Please see the following examples.

Example 5.1  

plt.subplot(1, 2, 1)
plt.plot([1, 2, 3], [0, 0.5, 0.2])

Example 5.2  

plt.subplot(1, 2, 1)
plt.plot([1, 2, 3], [0, 0.5, 0.2])
plt.subplot(1, 2, 2)
plt.plot([3, 2, 1], [0, 0.5, 0.2])

Example 5.3  

fig, axs = plt.subplots(1, 2)
axs[0].plot([1, 2, 3], [0, 0.5, 0.2])
axs[1].plot([3, 2, 1], [0, 0.5, 0.2])

Example 5.4  

import numpy as np
fig = plt.figure()
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 3)
ax3 = fig.add_subplot(1, 2, 2)

ax3.plot([1, 2, 3], [0, 0.5, 0.2])

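The arguments 2, 2, 1 mean that the figure is split into a 2x2 grid and ax1 occupies the 1st cell (cells are counted row by row from the top left). Likewise, ax2 occupies the 3rd cell of the same grid (bottom left), while 1, 2, 2 splits the figure into a 1x2 grid and places ax3 in its 2nd cell, which spans the whole right half.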
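To check which grid cell each Axes occupies, you can ask it for its SubplotSpec. A sketch reusing the layout of Example 5.4; get_geometry() reports (nrows, ncols, first_cell, last_cell) with 0-based cell indices.

```python
import matplotlib.pyplot as plt

fig = plt.figure()
ax1 = fig.add_subplot(2, 2, 1)  # top-left cell of a 2x2 grid
ax2 = fig.add_subplot(2, 2, 3)  # bottom-left cell of the same grid
ax3 = fig.add_subplot(1, 2, 2)  # right half: cell 2 of a 1x2 grid

for name, ax in [("ax1", ax1), ("ax2", ax2), ("ax3", ax3)]:
    # (nrows, ncols, first_cell, last_cell), cells counted from 0
    print(name, ax.get_subplotspec().get_geometry())
```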

Example 5.5 If you don’t explicitly create fig and ax, you may use plt.gcf() (get current figure) and plt.gca() (get current axes) to get handles to them for further operations.

plt.subplot(1, 2, 1)
ax = plt.gca()
ax.plot([1, 2, 3], [0, 0.5, 0.2])

plt.subplot(1, 2, 2)
ax = plt.gca()
ax.plot([3, 2, 1], [0, 0.5, 0.2])

Explicit figure objects also let you control the figure size and resolution. When creating a figure, figsize=(2, 2) sets the size to 2 x 2 inches, and dpi (dots per inch) sets the resolution; both are commonly modified options.

fig = plt.figure(figsize=(2, 2), dpi=50)
plt.plot([1, 2, 3], [0, 0.5, 0.2])

To change these settings after the figure has been created, use the following methods before plotting.

fig.set_size_inches(10, 10)
fig.set_dpi(300)
plt.plot([1, 2, 3], [0, 0.5, 0.2])

You may use fig.savefig('filename.png') to save the image into a file.
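The pixel size of a figure is its size in inches times its dpi. A minimal sketch for figsize=(2, 2) and dpi=50, which saves the figure (to the placeholder file figure_demo.png) and reads the image back to check its size:

```python
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(2, 2), dpi=50)  # 2 x 2 inches at 50 dots per inch
plt.plot([1, 2, 3], [0, 0.5, 0.2])

# Pixel size = size in inches * dpi = (2 * 50, 2 * 50)
width_px, height_px = fig.get_size_inches() * fig.dpi
print(width_px, height_px)

# By default savefig uses the figure's own dpi; the file name is a placeholder
fig.savefig("figure_demo.png")
img = plt.imread("figure_demo.png")
print(img.shape)  # rows, columns, RGBA channels
```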

5.1.2 Downstream packages

Several packages build on matplotlib to provide plotting. For example, you can plot directly from a pandas DataFrame or a pandas Series.

Example 5.6  

import pandas as pd
import numpy as np
s = pd.Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))
s.plot()

df = pd.DataFrame(np.random.randn(10, 4).cumsum(0),
                  columns=['A', 'B', 'C', 'D'],
                  index=np.arange(0, 100, 10))
df.plot()
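Plotting methods in pandas return the matplotlib Axes they drew on and accept an ax= argument to draw into an existing one, so the pandas and matplotlib styles mix freely. A minimal sketch checking this, and that the Series index becomes the x values:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes

s = pd.Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))

fig, ax = plt.subplots()
returned = s.plot(ax=ax)            # draw into an existing Axes instead of a new one
print(isinstance(returned, Axes))   # pandas hands back the matplotlib Axes
line = ax.get_lines()[0]
print(line.get_xdata())             # the Series index 0, 10, ..., 90 is used as x
```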

5.1.3 Plotting

5.1.3.1 plt.plot()

plt.plot() draws line plots. The options linestyle='--' and color='g' control the line style and color, and the two can be combined into the short format string 'g--' (a green dashed line).

Here is a list of commonly used linestyles and colors.

  • line styles
    • solid or -
    • dashed or --
    • dashdot or -.
    • dotted or :
  • marker styles
    • o as circle markers
    • + as plusses
    • ^ as triangles
    • s as squares
  • colors
    • b as blue
    • g as green
    • r as red
    • k as black
    • w as white
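A format string and the equivalent keyword arguments produce the same line style. The sketch below draws two lines, one with 'ro--' (red, circle markers, dashed) and one with color, marker and linestyle spelled out, and reads the settings back from the returned Line2D objects.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb

fig, ax = plt.subplots()
# 'ro--' = red (r) + circle markers (o) + dashed line (--)
(styled,) = ax.plot([1, 2, 3], [0, 0.5, 0.2], 'ro--')
# The same style written out with keyword arguments
(spelled_out,) = ax.plot([1, 2, 3], [0.1, 0.6, 0.3],
                         color='r', marker='o', linestyle='--')

for line in (styled, spelled_out):
    # color as an RGB tuple, marker symbol, line style
    print(to_rgb(line.get_color()), line.get_marker(), line.get_linestyle())
```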

plt.plot() takes two sequences, x and y. If only one sequence is passed, it is used as y, and the indices 0, 1, ..., len(y) - 1 are used as the default x.
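A quick way to confirm the default x values is to read them back from the Line2D object that plot() returns. A minimal sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

y = np.random.randn(30).cumsum()
fig, ax = plt.subplots()
(line,) = ax.plot(y)         # only y is given
# matplotlib supplied x = 0, 1, ..., len(y) - 1
print(line.get_xdata()[:5])
```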

Example 5.7  

plt.plot(np.random.randn(30).cumsum(), color='r', linestyle='--', marker='o')

Keep this output in mind: in the next section the same code is run with the seaborn theme applied, so the two results can be compared.

5.1.3.2 plt.bar() and plt.barh()

The two commands make vertical and horizontal bar plots, respectively.

Example 5.8  

import pandas as pd
data = pd.Series(np.random.rand(16), index=list('abcdefghijklmnop'))

fig, axes = plt.subplots(2, 1)
axes[0].bar(x=data.index, height=data, color='k', alpha=0.7)
axes[1].barh(y=data.index, width=data, color='b', alpha=0.7)

We may also directly plot the bar plot from the Series.

fig, axes = plt.subplots(2, 1)
data.plot.bar(ax=axes[0], color='k', alpha=0.7)
data.plot.barh(ax=axes[1], color='b', alpha=0.7)

For a DataFrame, a bar plot groups the values of each row together, drawing one bar per column side by side within each group. This is easiest to do by plotting directly from the DataFrame.

Example 5.9  

df = pd.DataFrame(np.random.rand(6, 4),
                  index=['one', 'two', 'three', 'four', 'five', 'six'],
                  columns=pd.Index(['A', 'B', 'C', 'D'], name='Genus'))
df
Genus          A         B         C         D
one     0.916812  0.762926  0.141764  0.943364
two     0.213971  0.259643  0.745698  0.725222
three   0.832894  0.840570  0.084964  0.473455
four    0.244322  0.841298  0.042361  0.103337
five    0.638863  0.263276  0.497620  0.010764
six     0.129810  0.119783  0.900054  0.316613
df.plot.bar()

df.plot.barh(stacked=True, alpha=0.5)
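The grouping can be checked on the Axes that pandas draws into: with 6 rows and 4 columns, df.plot.bar() should draw 24 bars, one per (row, column) pair, with the row labels on the x-axis. A minimal sketch with fresh random values:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.rand(6, 4),
                  index=['one', 'two', 'three', 'four', 'five', 'six'],
                  columns=pd.Index(['A', 'B', 'C', 'D'], name='Genus'))

fig, ax = plt.subplots()
df.plot.bar(ax=ax)
fig.canvas.draw()  # render once so the tick labels are populated

# One bar per (row, column) pair: 6 groups on the x-axis, 4 bars per group
print(len(ax.patches))
print([t.get_text() for t in ax.get_xticklabels()])  # the row labels
```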

5.1.3.3 plt.scatter()

Example 5.10  

import numpy as np

N = 100
data = 0.9 * np.random.rand(N, 2)
area = (20 * np.random.rand(N))**2 
c = np.sqrt(area)
plt.scatter(data[:, 0], data[:, 1], s=area, marker='^', c=c)
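In plt.scatter(), s sets each marker's area in points squared (so doubling a marker's diameter quadruples s), and a numeric c is mapped to colors through a colormap. The sketch below repeats Example 5.10 with the explicit interface, adds a colorbar to show the color mapping, and reads the sizes back from the returned collection.

```python
import matplotlib.pyplot as plt
import numpy as np

N = 100
data = 0.9 * np.random.rand(N, 2)
area = (20 * np.random.rand(N))**2   # marker areas, in points squared
c = np.sqrt(area)                    # numeric values mapped to colors by the colormap

fig, ax = plt.subplots()
sc = ax.scatter(data[:, 0], data[:, 1], s=area, marker='^', c=c)
fig.colorbar(sc, ax=ax)  # legend for the color scale

print(np.allclose(sc.get_sizes(), area))  # one size per point
print(sc.get_array().shape)               # the values that drive the colors
```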

5.1.3.4 plt.hist()

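Here are two kinds of plots with built-in statistics: histograms and box plots. Their plotting commands return the computed results together with the drawn artists; for example, plt.hist() returns the tuple (n, bins, patches) of bin counts, bin edges and bar patches. In a notebook this return value would be printed below the plot; to suppress it, assign it to the throwaway variable _.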

Example 5.11  

mu, sigma = 100, 15
x = mu + sigma * np.random.randn(10000)
y = mu-30 + sigma*2 * np.random.randn(10000)
_ = plt.hist(x, 50, density=True, facecolor='g', alpha=0.75)
_ = plt.hist(y, 50, density=True, facecolor='r', alpha=0.75)
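The returned tuple can be used directly. A minimal sketch reusing the x data of Example 5.11: with 50 bins there are 51 bin edges, and with density=True the bar areas (heights times bin widths) sum to 1.

```python
import matplotlib.pyplot as plt
import numpy as np

mu, sigma = 100, 15
x = mu + sigma * np.random.randn(10000)

fig, ax = plt.subplots()
# hist returns (values per bin, bin edges, drawn bar patches)
n, bins, patches = ax.hist(x, 50, density=True, facecolor='g', alpha=0.75)

print(len(n), len(bins))            # 50 bins need 51 edges
print(np.sum(n * np.diff(bins)))    # total area, 1 up to floating-point error
```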

5.1.4 plt.boxplot()

Example 5.12  

spread = np.random.rand(50) * 100
center = np.ones(30) * 50
flier_high = np.random.rand(10) * 100 + 100
flier_low = np.random.rand(10) * -100
data = np.concatenate((spread, center, flier_high, flier_low)).reshape(50, 2)

_ = plt.boxplot(data, flierprops={'markerfacecolor': 'g', 'marker': 'D'})
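plt.boxplot() draws one box per column of a 2-D array and returns a dictionary of the artists it drew (boxes, whiskers, caps, medians, fliers, means). By default the box spans the first to third quartile, the whiskers reach the most extreme points within 1.5 times the interquartile range, and points beyond that are drawn as fliers. A minimal sketch checking the median lines against NumPy:

```python
import matplotlib.pyplot as plt
import numpy as np

spread = np.random.rand(50) * 100
center = np.ones(30) * 50
flier_high = np.random.rand(10) * 100 + 100
flier_low = np.random.rand(10) * -100
# reshape(50, 2) fills row by row, so consecutive values alternate between the two columns
data = np.concatenate((spread, center, flier_high, flier_low)).reshape(50, 2)

fig, ax = plt.subplots()
parts = ax.boxplot(data)      # one box per column, so 2 boxes
print(sorted(parts.keys()))   # the drawn artists, grouped by role
for i, median_line in enumerate(parts['medians']):
    # each median line is drawn at the median of its column
    print(median_line.get_ydata()[0], np.median(data[:, i]))
```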

5.1.5 Titles, labels and legends

  • Titles and axis labels
    • plt.title(label), plt.xlabel(label), plt.ylabel(label) set the title, the x-axis label and the y-axis label.
    • ax.set_title(label), ax.set_xlabel(label), ax.set_ylabel(label) do the same on a specific Axes.
  • Limits and ticks
    • plt functions
      • xlim(), ylim(), xticks(), yticks()
      • called without arguments they return the current values; called with arguments they set new ones, e.g. plt.xticks(ticks, labels) sets both the tick positions and the tick labels.
    • ax methods
      • get_xlim(), get_ylim(), get_xticks(), etc. return the current values.
      • set_xlim(), set_ylim(), set_xticks(), set_xticklabels(), etc. set new ones.
  • Legends
    • First pass a label option to each plotting call, then call ax.legend() or plt.legend() at the end to display the legend.
    • handles, labels = ax.get_legend_handles_labels() returns the handles and labels of the legend entries, which you can modify and pass back with ax.legend(handles, labels).
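The getter/setter pattern and get_legend_handles_labels() can be combined, for example to restrict the x-range and to reverse the order of the legend entries. A minimal sketch with arbitrary data:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(np.arange(10), label='one')
ax.plot(np.arange(10) ** 0.5, label='two')

ax.set_xlim(0, 5)               # setter
print(ax.get_xlim())            # getter: the current limits as a tuple
ax.set_xticks([0, 2.5, 5])
ax.set_xticklabels(['start', 'middle', 'end'])

handles, labels = ax.get_legend_handles_labels()
print(labels)                   # ['one', 'two']
# Pass modified handles and labels back, here in reverse order
ax.legend(handles[::-1], labels[::-1], loc='best')
```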

Example 5.13  

import numpy as np
fig, ax = plt.subplots(1, 1)
ax.plot(np.random.randn(1000).cumsum(), 'k', label='one')
ax.plot(np.random.randn(1000).cumsum(), 'r--', label='two')
ax.plot(np.random.randn(1000).cumsum(), 'b.', label='three')

ax.set_title('Example')
ax.set_xlabel('x')
ax.set_ylabel('y')

ax.set_yticks([-40, 0, 40])
ax.set_yticklabels(['good', 'bad', 'ugly'])

ax.legend(loc='best')

5.1.6 Annotations

  • The command to add a simple annotation is ax.text(). The required arguments are the coordinates of the text and the text itself; further options such as fontsize, family and color modify the style.
  • If an arrow is needed, use ax.annotate(). It draws an arrow from the text position xytext to the annotated point xy, and the style of the arrow is controlled by the option arrowprops.

Example 5.14  

fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(np.random.randn(1000).cumsum(), 'k', label='one')
ax.text(500, 0, 'Hello world!', family='monospace', fontsize=15, c='r')
ax.annotate('test', xy=(400, 0), xytext=(400, -10), c='r',
            arrowprops={'facecolor': 'black',
                        'shrink': 0.05})
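The difference between the two commands shows up in the objects they return. In this sketch (on random data), the text and arrow positions are read back from the returned Text and Annotation objects:

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(np.random.randn(1000).cumsum(), 'k')

txt = ax.text(500, 0, 'Hello world!', color='r')    # text only, placed at (500, 0)
ann = ax.annotate('test', xy=(400, 0), xytext=(400, -10),
                  arrowprops={'arrowstyle': '->'})   # arrow from xytext to xy

print(txt.get_position())         # where the plain text sits
print(ann.xy)                     # the point the arrow points to
print(ann.get_position())         # where the annotation text sits
print(ann.arrow_patch is not None)  # arrowprops created an arrow
```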

5.1.7 Example

Example 5.15 The stock data can be downloaded from here.

from datetime import datetime
fig, ax = plt.subplots()
data = pd.read_csv('assests/datasets/spx.csv', index_col=0, parse_dates=True)
spx = data['SPX']
spx.plot(ax=ax, style='k-')
crisis_data = [(datetime(2007, 10, 11), 'Peak of bull market'),
               (datetime(2008, 3, 12), 'Bear Stearns Fails'),
               (datetime(2008, 9, 15), 'Lehman Bankruptcy')]
for date, label in crisis_data:
    ax.annotate(label, xy=(date, spx.asof(date) + 75),
                xytext=(date, spx.asof(date) + 225),
                arrowprops=dict(facecolor='black', headwidth=4, width=2,
                                headlength=4),
                horizontalalignment='left', verticalalignment='top')
ax.set_xlim(['1/1/2007', '1/1/2011'])
ax.set_ylim([600, 1800])
_ = ax.set_title('Important dates in the 2008-2009 financial crisis')

Example 5.16 Here is an example of arrows with different shapes. For more details, please read the official documentation.

fig, ax = plt.subplots()

x = np.linspace(0, 20, 1000)
ax.plot(x, np.cos(x))
ax.axis('equal')

ax.annotate('local maximum', xy=(6.28, 1), xytext=(10, 4),
            arrowprops=dict(facecolor='black', shrink=0.05))

ax.annotate('local minimum', xy=(5 * np.pi, -1), xytext=(2, -6),
            arrowprops=dict(arrowstyle="->",
                            connectionstyle="angle3,angleA=0,angleB=-90",
                            color='r'))

5.2 seaborn

Several libraries are built on top of matplotlib, and seaborn is one of them. seaborn focuses on statistical graphics.

seaborn is conventionally imported as follows.

import seaborn as sns

seaborn also modifies the default matplotlib color schemes and plot styles to improve readability and aesthetics. Even if you do not use the seaborn API, you may prefer to import seaborn as a simple way to improve the visual aesthetics of general matplotlib plots.

To apply the seaborn theme, run the following code.

sns.set_theme()
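sns.set_theme() works by changing matplotlib's global rc parameters, which is why plain matplotlib code is affected too. A minimal sketch showing one parameter before and after the call (with no arguments, set_theme() applies the darkgrid style):

```python
import matplotlib.pyplot as plt
import seaborn as sns

before = plt.rcParams['axes.facecolor']
sns.set_theme()   # defaults: style="darkgrid", palette="deep", context="notebook"
after = plt.rcParams['axes.facecolor']
print(before, '->', after)   # the axes background changes from white to light gray
```

To return to matplotlib's defaults in the same session, plt.rcdefaults() resets the rc parameters.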

Let us rerun some code from the previous section and compare the results.

Example 5.17  

plt.plot(np.random.randn(30).cumsum(), color='r', linestyle='--', marker='o')

Compare this output with Example 5.7, which runs the same code without the seaborn theme.

5.2.1 Scatter plots with relplot()

The basic scatter plot function is scatterplot(). relplot() wraps it as its default kind of plot, so here we mainly discuss relplot(). The name refers to statistical relationships: relplot() is designed to visualize many different relationships between variables.

relplot() places points according to the variables you choose for x and y, and can map other columns to properties that alter the appearance of the points:

  • col creates one subplot per value of the chosen column.
  • hue encodes the chosen column by color.
  • size encodes the chosen column by marker area.
  • style encodes the chosen column by marker symbol.
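To see how col splits the data, here is a minimal sketch on a small hand-made table with the same column names as tips (the values are invented for illustration). relplot() returns a FacetGrid, and its axes attribute holds one subplot per value of the col variable.

```python
import pandas as pd
import seaborn as sns

# Hypothetical toy table shaped like tips; the numbers are made up
toy = pd.DataFrame({
    'total_bill': [10.0, 20.0, 30.0, 15.0, 25.0, 35.0],
    'tip':        [1.5, 3.0, 4.5, 2.0, 3.5, 5.0],
    'time':       ['Lunch', 'Lunch', 'Lunch', 'Dinner', 'Dinner', 'Dinner'],
    'smoker':     ['No', 'Yes', 'No', 'Yes', 'No', 'Yes'],
})

g = sns.relplot(data=toy, x='total_bill', y='tip', col='time', hue='smoker')
print(type(g).__name__)   # relplot returns a grid of plots, not a single Axes
print(g.axes.shape)       # one row, one column of subplots per value of time
```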

Example 5.18 Consider the following example. tips is a DataFrame, which is shown below.

import seaborn as sns
tips = sns.load_dataset("tips")
tips
    total_bill   tip     sex smoker   day    time  size
0        16.99  1.01  Female     No   Sun  Dinner     2
1        10.34  1.66    Male     No   Sun  Dinner     3
2        21.01  3.50    Male     No   Sun  Dinner     3
3        23.68  3.31    Male     No   Sun  Dinner     2
4        24.59  3.61  Female     No   Sun  Dinner     4
...        ...   ...     ...    ...   ...     ...   ...
239      29.03  5.92    Male     No   Sat  Dinner     3
240      27.18  2.00  Female    Yes   Sat  Dinner     2
241      22.67  2.00    Male    Yes   Sat  Dinner     2
242      17.82  1.75    Male     No   Sat  Dinner     2
243      18.78  3.00  Female     No  Thur  Dinner     2

244 rows × 7 columns

sns.relplot(data=tips,
            x="total_bill", y="tip", col="time",
            hue="smoker", style="smoker", size="size")

The default kind of plot for relplot() is a scatter plot. You may switch to line plots by setting kind='line'.

Example 5.19  

dots = sns.load_dataset("dots")
sns.relplot(data=dots, kind="line",
            x="time", y="firing_rate", col="align",
            hue="choice", size="coherence", style="choice",
            facet_kws=dict(sharex=False))

5.2.2 regplot()

regplot() combines a scatter plot with a fitted linear regression line, drawn together with a shaded confidence band for the fit.
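As a rough check of what regplot() draws, the sketch below fits synthetic data (not the tips data) and compares the drawn line with an ordinary least-squares fit from numpy.polyfit. ci=None turns off the bootstrapped confidence band so that only the points and the line are drawn.

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

rng = np.random.default_rng(0)
x = rng.uniform(5, 50, size=100)
y = 0.15 * x + rng.normal(0, 1, size=100)   # synthetic data with slope about 0.15

fig, ax = plt.subplots()
sns.regplot(x=x, y=y, ax=ax, ci=None)       # scatter plus fitted line, no band

# Ordinary least-squares fit of a degree-1 polynomial
slope, intercept = np.polyfit(x, y, 1)
line = ax.get_lines()[0]                    # the regression line drawn by regplot
xs, ys = line.get_xdata(), line.get_ydata()
print(np.allclose(ys, slope * xs + intercept))  # do the two fits agree?
```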

Example 5.20 We still use tips as an example.

sns.regplot(x='total_bill', y='tip', data=tips)

5.2.3 pairplot()

This is a way to display the pairwise relations among several variables.

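Example 5.21 The following code shows pairwise plots of all numeric columns in tips. diag_kind='kde' draws a kernel density estimate of each variable on the diagonal, and plot_kws={'alpha': 0.2} makes the scatter points translucent.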

sns.pairplot(tips, diag_kind='kde', plot_kws={'alpha': 0.2})

5.2.4 barplot()

Example 5.22  

sns.barplot(x='total_bill', y='day', data=tips, orient='h')

In the plot, each day has many values of total_bill. The length of each bar is the average total_bill on that day, and the black line shows a 95% confidence interval for that average (by default seaborn computes it by bootstrapping).
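The claim that each bar shows a group mean can be checked on a tiny hand-made table (values invented for illustration): the bar lengths drawn by sns.barplot() should match a pandas groupby mean.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical toy data: three bills on each of two days
toy = pd.DataFrame({'day': ['Thur'] * 3 + ['Fri'] * 3,
                    'total_bill': [10.0, 20.0, 30.0, 12.0, 14.0, 16.0]})

fig, ax = plt.subplots()
sns.barplot(x='total_bill', y='day', data=toy, orient='h', ax=ax)

widths = sorted(p.get_width() for p in ax.patches)        # horizontal bar lengths
means = sorted(toy.groupby('day')['total_bill'].mean())   # per-day averages
print(np.allclose(widths, means))                          # bars show the group means
```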

sns.barplot(x='total_bill', y='day', hue='time', data=tips, orient='h')

In this plot, lunch and dinner are distinguished by colors.

5.2.5 Histogram

Example 5.23  

mu, sigma = 100, 15
x = mu + sigma * np.random.randn(10000)
y = mu-30 + sigma*2 * np.random.randn(10000)
df = pd.DataFrame(np.array([x,y]).T)
sns.histplot(df, bins=100, kde=True)

Compare this plot with Example 5.11, which draws the same two distributions with plt.hist().

5.3 Exercises

Exercise 5.1 Please download the mtcars file from here and read it as a DataFrame. Then create a scatter plot of the drat and wt variables from mtcars and color the dots by the carb variable.

Exercise 5.2 Please read the file as a DataFrame from here. This is the Dining satisfaction with quick service restaurants questionnaire data provided by Dr. Siri McDowall, supported by DART SEED grant.

  1. Please pick out all rating columns. Excluding last.visit, visit.again and recommend, compute the mean of the remaining rating columns for each row and add it to the DataFrame as a new column named mean.
  2. Use a plot to show the relations among these four columns: last.visit, visit.again, recommend and mean.
  3. Look at the column Profession. Keep Student, and change everything else to be Professional, and add it as a new column Status to the DataFrame.
  4. Draw the histogram of mean for each Status.
  5. Find the counts of each recommend rating for each Status and draw the barplot. Do the same to last.visit/Status and visit.again/Status.
  6. Explore the dataset and draw one plot.