5 Visualization
The main reference for this Chapter is [1].
5.1 matplotlib.pyplot
matplotlib is a modern and classic plotting library whose main features are inspired by MATLAB. In this book we mostly use the pyplot package from matplotlib, with the following import convention:
import matplotlib.pyplot as plt
5.1.1 matplotlib interface
matplotlib has two major application interfaces, or styles of using the library:
- An explicit Axes interface that uses methods on a Figure or Axes object to create other Artists and build a visualization step by step. You may treat the Figure object as a canvas and each Axes as a plot on that canvas; one canvas may hold one or more plots. This is also called the object-oriented interface.
- An implicit pyplot interface that keeps track of the last Figure and Axes created, and adds Artists to the object it thinks the user wants.
Here is an example of an explicit interface.
fig = plt.figure()
ax = fig.subplots()
ax.plot([1, 2, 3, 4], [0, 0.5, 1, 0.2])
Here is an example of an implicit interface.
plt.plot([1, 2, 3, 4], [0, 0.5, 1, 0.2])
If the plot is not shown, you may call plt.show() to force it to be rendered. Whether plt.show() works, however, depends on which matplotlib backend is active, and switching backends is sometimes complicated.
The purpose of explicitly using fig and ax is to have more control over the configuration. The first important configuration is subplots, which can be created in three ways:
- plt.subplot()
- plt.subplots()
- fig.add_subplot()
Please see the following examples.
Example 5.1
plt.subplot(1, 2, 1)
plt.plot([1, 2, 3], [0, 0.5, 0.2])
Example 5.2
plt.subplot(1, 2, 1)
plt.plot([1, 2, 3], [0, 0.5, 0.2])
plt.subplot(1, 2, 2)
plt.plot([3, 2, 1], [0, 0.5, 0.2])
Example 5.3
fig, axs = plt.subplots(1, 2)
axs[0].plot([1, 2, 3], [0, 0.5, 0.2])
axs[1].plot([3, 2, 1], [0, 0.5, 0.2])
Example 5.4
import numpy as np
fig = plt.figure()
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 3)
ax3 = fig.add_subplot(1, 2, 2)
ax3.plot([1, 2, 3], [0, 0.5, 0.2])
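The arguments 2, 2, 1 mean that we split the figure into a 2x2 grid and the axes ax1 sits in the 1st position; positions are counted left to right, then top to bottom, starting from 1. The rest is understood in the same way: ax2 takes position 3 (bottom left) of the same 2x2 grid, while ax3 uses a 1x2 grid, so its position 2 is the whole right half.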
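The same layout as Example 5.4 can also be sketched with an explicit grid created by fig.add_gridspec(), where slicing the grid makes the spanning cell visible. This is a minimal sketch; the Agg backend is an assumption so that it runs outside a notebook, and get_geometry() is used only to print which grid cells each Axes occupies.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig = plt.figure()
gs = fig.add_gridspec(2, 2)          # a 2x2 grid of cells
ax1 = fig.add_subplot(gs[0, 0])      # same cell as add_subplot(2, 2, 1)
ax2 = fig.add_subplot(gs[1, 0])      # same cell as add_subplot(2, 2, 3)
ax3 = fig.add_subplot(gs[:, 1])      # whole right column, like add_subplot(1, 2, 2)
ax3.plot([1, 2, 3], [0, 0.5, 0.2])

# get_geometry() reports (nrows, ncols, start, stop) with 0-based cell indices
for ax in (ax1, ax2, ax3):
    print(ax.get_subplotspec().get_geometry())
```

Slicing with gs[:, 1] reads naturally when a plot spans several cells; the three-number form is shorter when every plot fits in a single cell.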
Example 5.5 If you don’t explicitly initialize fig and ax, you may use plt.gcf() (get current figure) and plt.gca() (get current axes) to get the handles for further operations.
plt.subplot(1, 2, 1)
ax = plt.gca()
ax.plot([1, 2, 3], [0, 0.5, 0.2])
plt.subplot(1, 2, 2)
ax = plt.gca()
ax.plot([3, 2, 1], [0, 0.5, 0.2])
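A small sketch of what plt.gca() and plt.gcf() hand back, assuming a fresh pyplot session with the non-interactive Agg backend: the current Axes is simply the one most recently created by plt.subplot(), and the current Figure is the one holding it.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running outside a notebook
import matplotlib.pyplot as plt

plt.figure()
left = plt.subplot(1, 2, 1)      # plt.subplot() returns the new Axes,
                                 # which also becomes the "current" Axes
right = plt.subplot(1, 2, 2)     # now the right-hand Axes is current
current_ax = plt.gca()           # handle to the right-hand Axes
current_fig = plt.gcf()          # the Figure that holds both Axes
current_ax.plot([3, 2, 1], [0, 0.5, 0.2])
print(len(current_fig.axes))     # both subplots live on the same Figure
```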
Explicit handles also give control over the figure itself. For example, when generating a figure object, we may pass figsize=(3, 3) to set the figure size to 3x3 inches. dpi (dots per inch) is another commonly modified option.
fig = plt.figure(figsize=(2, 2), dpi=50)
plt.plot([1, 2, 3], [0, 0.5, 0.2])
If you would like to change these settings later, you may use the following commands before plotting.
fig.set_size_inches(10, 10)
fig.set_dpi(300)
plt.plot([1, 2, 3], [0, 0.5, 0.2])
You may use fig.savefig('filename.png') to save the image to a file.
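A minimal sketch tying figsize, dpi and savefig() together, assuming the Agg backend and a temporary output directory (the file name mirrors the text above): the saved image's pixel size should equal the figure size in inches times the dpi.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend, an assumption for script use
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

fig = plt.figure(figsize=(2, 3), dpi=50)   # 2 inches wide, 3 inches tall
plt.plot([1, 2, 3], [0, 0.5, 0.2])
print(fig.get_size_inches())

out_dir = tempfile.mkdtemp()               # throwaway folder for the example
path = os.path.join(out_dir, "filename.png")
fig.savefig(path, dpi=50)                  # pixel size = inches * dpi

img = mpimg.imread(path)                   # PNG -> array (height, width, channels)
print(img.shape[:2])                       # (150, 100)
```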
5.1.2 Downstream packages
Several packages build on matplotlib to provide plotting. For example, you may plot directly from a pandas DataFrame or a pandas Series.
Example 5.6
import pandas as pd
import numpy as np
s = pd.Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))
s.plot()
df = pd.DataFrame(np.random.randn(10, 4).cumsum(0),
                  columns=['A', 'B', 'C', 'D'],
                  index=np.arange(0, 100, 10))
df.plot()
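A sketch showing that pandas draws into an ordinary matplotlib Axes and returns it, so the usual matplotlib methods still apply afterwards. The Agg backend and the extra set_title() call are assumptions added for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

s = pd.Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))
ax = s.plot()                      # pandas draws with matplotlib and returns the Axes
line = ax.lines[0]                 # the Line2D that pandas created
print(line.get_xdata())            # x values come from the Series index

df = pd.DataFrame(np.random.randn(10, 4).cumsum(0),
                  columns=['A', 'B', 'C', 'D'],
                  index=np.arange(0, 100, 10))
fig, ax2 = plt.subplots()
df.plot(ax=ax2)                    # one line per column, labelled by column name
ax2.set_title('Drawn by pandas, customized with matplotlib')
labels = [l.get_label() for l in ax2.lines]
print(labels)
```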
5.1.3 Plotting
5.1.3.1 plt.plot()
This is the command for line plots. You may use linestyle='--' and color='g' to control the line style and color. The two options can be combined into the shorthand format string 'g--'.
Here is a list of commonly used line styles, markers and colors.
- Line styles: 'solid' or '-', 'dashed' or '--', 'dashdot' or '-.', 'dotted' or ':'
- Marker styles: 'o' for circles, '+' for plus signs, '^' for triangles, 's' for squares
- Colors: 'b' for blue, 'g' for green, 'r' for red, 'k' for black, 'w' for white
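A quick sketch, assuming the Agg backend, comparing the keyword form with the shorthand format string: both lines should end up with the same color, line style and marker.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb

fig, ax = plt.subplots()
# keyword form: spell out each property
long_form, = ax.plot([1, 2, 3], [1, 4, 9], color='g', linestyle='dashed', marker='o')
# format-string form: color, marker and line style packed into one string
short_form, = ax.plot([1, 2, 3], [2, 5, 10], 'go--')

for line in (long_form, short_form):
    # to_rgb() normalizes the color so both spellings can be compared
    print(to_rgb(line.get_color()), line.get_linestyle(), line.get_marker())
```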
The inputs of plt.plot() are two lists, x and y. If only one list is given, it is treated as y, and the indices of its elements (0, 1, 2, ...) are used as the default x.
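A minimal check of the one-list case, assuming the Agg backend: reading back the line's x data shows the indices that matplotlib filled in.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
y = [3, 1, 4, 1, 5]
line, = ax.plot(y)            # only y is given
print(line.get_xdata())       # x defaults to the positions 0, 1, 2, ...
```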
Example 5.7
plt.plot(np.random.randn(30).cumsum(), color='r', linestyle='--', marker='o')
You may compare this plot with Example 5.17 in the next section to see the effect of seaborn.
5.1.3.2 plt.bar() and plt.barh()
The two commands make vertical and horizontal bar plots, respectively.
Example 5.8
import numpy as np
import pandas as pd
data = pd.Series(np.random.rand(16), index=list('abcdefghijklmnop'))
fig, axes = plt.subplots(2, 1)
axes[0].bar(x=data.index, height=data, color='k', alpha=0.7)
axes[1].barh(y=data.index, width=data, color='b', alpha=0.7)
We may also plot the bar plots directly from the Series.
fig, axes = plt.subplots(2, 1)
data.plot.bar(ax=axes[0], color='k', alpha=0.7)
data.plot.barh(ax=axes[1], color='b', alpha=0.7)
With a DataFrame, a bar plot groups the values of each row together, drawing one bar per column side by side within each group. This is easiest when plotting directly from the DataFrame.
Example 5.9
df = pd.DataFrame(np.random.rand(6, 4),
                  index=['one', 'two', 'three', 'four', 'five', 'six'],
                  columns=pd.Index(['A', 'B', 'C', 'D'], name='Genus'))
df
Genus | A | B | C | D |
---|---|---|---|---|
one | 0.916812 | 0.762926 | 0.141764 | 0.943364 |
two | 0.213971 | 0.259643 | 0.745698 | 0.725222 |
three | 0.832894 | 0.840570 | 0.084964 | 0.473455 |
four | 0.244322 | 0.841298 | 0.042361 | 0.103337 |
five | 0.638863 | 0.263276 | 0.497620 | 0.010764 |
six | 0.129810 | 0.119783 | 0.900054 | 0.316613 |
df.plot.bar()
df.plot.barh(stacked=True, alpha=0.5)
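A sketch, assuming the Agg backend, that counts the rectangles pandas draws for the 6x4 DataFrame above and checks where the stacked segments end. The column-by-column drawing order mentioned in the comment is an observation about how pandas builds these plots, not a documented guarantee.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.rand(6, 4),
                  index=['one', 'two', 'three', 'four', 'five', 'six'],
                  columns=pd.Index(['A', 'B', 'C', 'D'], name='Genus'))

fig, (ax_grouped, ax_stacked) = plt.subplots(1, 2, figsize=(10, 4))
df.plot.bar(ax=ax_grouped)                        # 6 groups (rows) x 4 bars (columns)
df.plot.barh(ax=ax_stacked, stacked=True, alpha=0.5)

print(len(ax_grouped.patches))                    # 24 rectangles in total
# pandas draws one column at a time, so the last 6 patches belong to column D;
# in the stacked plot their right edges sit at the row totals
right_edges = [p.get_x() + p.get_width() for p in ax_stacked.patches[-6:]]
print(np.allclose(right_edges, df.sum(axis=1)))
```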
5.1.3.3 plt.scatter()
Example 5.10
import numpy as np
N = 100
data = 0.9 * np.random.rand(N, 2)
area = (20 * np.random.rand(N))**2
c = np.sqrt(area)
plt.scatter(data[:, 0], data[:, 1], s=area, marker='^', c=c)
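Example 5.10 passes two extra arrays to plt.scatter(): s sets each marker's area (in points squared) and c supplies numbers that a colormap turns into colors. The sketch below, assuming the Agg backend, a fixed random seed and the viridis colormap (none of which appear in the original), adds a colorbar and reads the arrays back from the returned PathCollection.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)     # fixed seed so reruns give the same picture
N = 100
data = 0.9 * rng.random((N, 2))
area = (20 * rng.random(N))**2     # s is the marker area in points**2
c = np.sqrt(area)                  # numbers that the colormap turns into colors

fig, ax = plt.subplots()
sc = ax.scatter(data[:, 0], data[:, 1], s=area, c=c, marker='^', cmap='viridis')
fig.colorbar(sc, ax=ax, label='sqrt(area)')   # legend for the color encoding

print(sc.get_offsets().shape)      # one (x, y) pair per point
```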
5.1.3.4 plt.hist()
The plotting commands in this part have built-in statistics and return them as outputs; for example, plt.hist() returns the bin counts, the bin edges and the drawn patches. To keep these outputs from being displayed, we assign them to a throwaway variable _.
Example 5.11
mu, sigma = 100, 15
x = mu + sigma * np.random.randn(10000)
y = mu - 30 + sigma * 2 * np.random.randn(10000)
_ = plt.hist(x, 50, density=True, facecolor='g', alpha=0.75)
_ = plt.hist(y, 50, density=True, facecolor='r', alpha=0.75)
5.1.4 plt.boxplot()
Example 5.12
spread = np.random.rand(50) * 100
center = np.ones(30) * 50
flier_high = np.random.rand(10) * 100 + 100
flier_low = np.random.rand(10) * -100
data = np.concatenate((spread, center, flier_high, flier_low)).reshape(50, 2)
_ = plt.boxplot(data, flierprops={'markerfacecolor': 'g', 'marker': 'D'})
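The boxplot in Example 5.12 follows matplotlib's default rule: the box spans the first to third quartiles, and points lying more than 1.5 times the interquartile range (IQR) beyond the box are drawn as fliers. A minimal sketch with a small hand-made array (an assumption chosen so the numbers are easy to check by hand) recomputes that rule with NumPy and compares it with what plt.boxplot() returns.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
import numpy as np

values = np.concatenate([np.arange(1, 21), [60, -40]])   # 1..20 plus two extreme points

fig, ax = plt.subplots()
parts = ax.boxplot(values, flierprops={'markerfacecolor': 'g', 'marker': 'D'})

q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1
low, high = q1 - 1.5 * iqr, q3 + 1.5 * iqr              # default whis=1.5
manual_fliers = values[(values < low) | (values > high)]

print(parts['medians'][0].get_ydata())                  # median line at 10.5
print(sorted(parts['fliers'][0].get_ydata()), sorted(manual_fliers))
```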
5.1.5 Titles, labels and legends
- Titles
  - plt.title(label), plt.xlabel(label) and plt.ylabel(label) set the title, the x-axis label and the y-axis label.
  - ax.set_title(label), ax.set_xlabel(label) and ax.set_ylabel(label) do the same thing on a given Axes.
- Labels
  - plt functions: xlim(), ylim(), xticks(), yticks(). Called without arguments they return the current setting; called with arguments they change it. Tick labels are passed to xticks() and yticks() as a second argument.
  - ax methods: getters such as get_xlim(), get_ylim(), get_xticks() and get_xticklabels(), and the matching setters set_xlim(), set_ylim(), set_xticks(), set_xticklabels(), etc.
- Legends
  - First add the label option to each piece when plotting, then call ax.legend() or plt.legend() at the end to display the legend.
  - You may use handles, labels = ax.get_legend_handles_labels() to get the handles and labels of the legend, and modify them if necessary.
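A short sketch of the getters, setters and legend handles described above, assuming the Agg backend and made-up sin/cos curves. A line without a label is left out of the legend, and the labels are upper-cased before the legend is built to show where such modifications can happen.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(0, 10, 50)
ax.plot(x, np.sin(x), label='sin')
ax.plot(x, np.cos(x), label='cos')
ax.plot(x, 0 * x)                        # no label, so it stays out of the legend

ax.set_xlim(0, 10)                       # setter ...
print(ax.get_xlim())                     # ... and the matching getter
ax.set_xticks([0, 5, 10])
ax.set_xticklabels(['start', 'middle', 'end'])

handles, labels = ax.get_legend_handles_labels()
print(labels)                            # ['sin', 'cos']
labels = [name.upper() for name in labels]   # modify before building the legend
ax.legend(handles, labels, loc='best')
```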
Example 5.13
import numpy as np
fig, ax = plt.subplots(1, 1)
ax.plot(np.random.randn(1000).cumsum(), 'k', label='one')
ax.plot(np.random.randn(1000).cumsum(), 'r--', label='two')
ax.plot(np.random.randn(1000).cumsum(), 'b.', label='three')
ax.set_title('Example')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_yticks([-40, 0, 40])
ax.set_yticklabels(['good', 'bad', 'ugly'])
ax.legend(loc='best')
5.1.6 Annotations
- The command to add simple annotations is ax.text(). The required arguments are the coordinates of the text and the text itself; several options can be added to modify the style.
- If arrows are needed, we may use ax.annotate(). Here an arrow is drawn from xytext to xy, and the style of the arrow is controlled by the option arrowprops.
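A minimal sketch contrasting ax.text() and ax.annotate(), assuming the Agg backend; the sine curve, the label texts and the '->' arrow style are illustrative choices, not part of the original examples.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(0, 2 * np.pi, 100)
ax.plot(x, np.sin(x))

# ax.text(x, y, s): plain text placed at data coordinates (x, y), no arrow
label = ax.text(np.pi, 0.2, 'zero crossing', fontsize=12, color='r')

# ax.annotate(text, xy=..., xytext=...): the text sits at xytext and the arrow
# points from there to xy; arrowprops controls how the arrow looks
note = ax.annotate('maximum', xy=(np.pi / 2, 1), xytext=(3.5, 0.8),
                   arrowprops=dict(arrowstyle='->', color='k'))
```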
Example 5.14
fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(np.random.randn(1000).cumsum(), 'k', label='one')
ax.text(500, 0, 'Hello world!', family='monospace', fontsize=15, c='r')
ax.annotate('test', xy=(400, 0), xytext=(400, -10), c='r',
            arrowprops={'facecolor': 'black',
                        'shrink': 0.05})
5.1.7 Example
Example 5.15 The stock data can be downloaded from here.
from datetime import datetime
import pandas as pd
fig, ax = plt.subplots()
data = pd.read_csv('assests/datasets/spx.csv', index_col=0, parse_dates=True)
spx = data['SPX']
spx.plot(ax=ax, style='k-')
crisis_data = [(datetime(2007, 10, 11), 'Peak of bull market'),
               (datetime(2008, 3, 12), 'Bear Stearns Fails'),
               (datetime(2008, 9, 15), 'Lehman Bankruptcy')]
# label each date with an arrow pointing just above the index level on that day
for date, label in crisis_data:
    ax.annotate(label, xy=(date, spx.asof(date) + 75),
                xytext=(date, spx.asof(date) + 225),
                arrowprops=dict(facecolor='black', headwidth=4, width=2,
                                headlength=4),
                horizontalalignment='left', verticalalignment='top')
# zoom in on 2007-2010
ax.set_xlim(['1/1/2007', '1/1/2011'])
ax.set_ylim([600, 1800])
_ = ax.set_title('Important dates in the 2008-2009 financial crisis')
Example 5.16 Here is an example of arrows with different shapes. For more details please read the official documentation.
fig, ax = plt.subplots()
x = np.linspace(0, 20, 1000)
ax.plot(x, np.cos(x))
ax.axis('equal')
ax.annotate('local maximum', xy=(6.28, 1), xytext=(10, 4),
            arrowprops=dict(facecolor='black', shrink=0.05))
ax.annotate('local minimum', xy=(5 * np.pi, -1), xytext=(2, -6),
            arrowprops=dict(arrowstyle="->",
                            connectionstyle="angle3,angleA=0,angleB=-90",
                            color='r'))
5.2 seaborn
There are several newer libraries built upon matplotlib, and seaborn is one of them. seaborn focuses on statistical graphics.
seaborn is imported in the following way.
import seaborn as sns
seaborn can also modify the default matplotlib color schemes and plot styles to improve readability and aesthetics. Even if you do not use the seaborn API, you may prefer to apply its theme as a simple way to improve the look of general matplotlib plots. In recent versions of seaborn, importing the library alone does not change the matplotlib defaults; the theme has to be applied explicitly.
To apply the seaborn theme, run the following code.
sns.set_theme()
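A sketch that reads one matplotlib setting before and after sns.set_theme(), assuming seaborn is installed and the Agg backend is used; sns.reset_orig() at the end restores the settings that were active when seaborn was imported.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import matplotlib.pyplot as plt
import seaborn as sns

before = plt.rcParams['axes.facecolor']    # value right after importing seaborn
sns.set_theme()                           # apply seaborn's default theme
after = plt.rcParams['axes.facecolor']    # the default theme uses a gray background
print(before, after)

sns.reset_orig()                          # back to the rcParams seen at import time
restored = plt.rcParams['axes.facecolor']
```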
Let us rerun some code from the last section and compare the results.
Example 5.17
plt.plot(np.random.randn(30).cumsum(), color='r', linestyle='--', marker='o')
Please compare this output with that of the same code in Example 5.7.
5.2.1 Scatter plots with relplot()
The basic scatter plot method is scatterplot(). relplot() wraps it and uses it as its default kind of plot, so here we will mainly talk about relplot(). The name reflects that it is designed to visualize many different statistical relationships.
The idea of relplot() is to display points based on the variables x and y you choose, and to map other columns to properties that alter the appearance of the points:
- col creates multiple plots side by side, one for each value of the column you choose.
- hue sets the color of the points based on the column you choose.
- size changes the marker area based on the column you choose.
- style changes the marker symbol based on the column you choose.
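A sketch of how the four options above are mapped, using a small hypothetical DataFrame (the names toy, group, kind and weight are made up for illustration and do not come from a seaborn dataset) and the Agg backend: col produces one subplot per group value, while hue, style and size change the points inside each subplot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import numpy as np
import pandas as pd
import seaborn as sns

# hypothetical toy data standing in for a real dataset
rng = np.random.default_rng(0)
n = 40
toy = pd.DataFrame({
    'x': rng.random(n),
    'y': rng.random(n),
    'group': np.repeat(['left', 'right'], n // 2),   # mapped to col
    'kind': np.tile(['a', 'b'], n // 2),              # mapped to hue and style
    'weight': rng.integers(1, 5, n),                  # mapped to size
})

g = sns.relplot(data=toy, x='x', y='y', col='group',
                hue='kind', style='kind', size='weight')
print(g.axes.shape)                                   # one row, one column per group
print([ax.get_title() for ax in g.axes.flat])
```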
Example 5.18 Consider the tips DataFrame, which is shown below.
import seaborn as sns
tips = sns.load_dataset("tips")
tips
| | total_bill | tip | sex | smoker | day | time | size |
|---|---|---|---|---|---|---|---|
| 0 | 16.99 | 1.01 | Female | No | Sun | Dinner | 2 |
| 1 | 10.34 | 1.66 | Male | No | Sun | Dinner | 3 |
| 2 | 21.01 | 3.50 | Male | No | Sun | Dinner | 3 |
| 3 | 23.68 | 3.31 | Male | No | Sun | Dinner | 2 |
| 4 | 24.59 | 3.61 | Female | No | Sun | Dinner | 4 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 239 | 29.03 | 5.92 | Male | No | Sat | Dinner | 3 |
| 240 | 27.18 | 2.00 | Female | Yes | Sat | Dinner | 2 |
| 241 | 22.67 | 2.00 | Male | Yes | Sat | Dinner | 2 |
| 242 | 17.82 | 1.75 | Male | No | Sat | Dinner | 2 |
| 243 | 18.78 | 3.00 | Female | No | Thur | Dinner | 2 |
244 rows × 7 columns
sns.relplot(data=tips,
            x="total_bill", y="tip", col="time",
            hue="smoker", style="smoker", size="size")
The default type of plot for relplot() is the scatter plot. However, you may change it to a line plot by setting kind='line'.
Example 5.19
dots = sns.load_dataset("dots")
sns.relplot(data=dots, kind="line",
            x="time", y="firing_rate", col="align",
            hue="choice", size="coherence", style="choice",
            facet_kws=dict(sharex=False))
5.2.2 regplot()
This method combines a scatter plot with a linear regression fit.
Example 5.20 We again use tips as an example.
sns.regplot(x='total_bill', y='tip', data=tips)
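The fitted line that regplot() draws is an ordinary least-squares line. A sketch with hypothetical toy data (not the tips dataset, so no download is needed) compares the slope of the drawn line with np.polyfit(); ci=None is an assumption that skips the confidence band, and the Agg backend is assumed for headless runs.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import numpy as np
import pandas as pd
import seaborn as sns

# hypothetical toy data with a roughly linear trend
rng = np.random.default_rng(0)
x = rng.uniform(5, 50, 80)
toy = pd.DataFrame({'total_bill': x, 'tip': 0.15 * x + rng.normal(0, 1, 80)})

ax = sns.regplot(x='total_bill', y='tip', data=toy, ci=None)  # no confidence band
line = ax.lines[0]                                            # the fitted line
xs, ys = line.get_xdata(), line.get_ydata()
slope_plot = (ys[-1] - ys[0]) / (xs[-1] - xs[0])

slope_fit, intercept_fit = np.polyfit(toy['total_bill'], toy['tip'], deg=1)
print(slope_plot, slope_fit)
```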
5.2.3 pairplot()
This is a way to display the pairwise relations among several variables.
Example 5.21 The following code shows the pair plots among all numeric columns in tips.
sns.pairplot(tips, diag_kind='kde', plot_kws={'alpha': 0.2})
5.2.4 barplot()
Example 5.22
sns.barplot(x='total_bill', y='day', data=tips, orient='h')
Each day contains several total_bill values. The length of each bar is the average total_bill for that day, and the black line marks the 95% confidence interval of that average, which seaborn estimates by bootstrapping.
sns.barplot(x='total_bill', y='day', hue='time', data=tips, orient='h')
In this plot, lunch and dinner are distinguished by colors.
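To check the description of the bars above, here is a sketch with a small hypothetical table (the values are made up) and the Agg backend: the length of each horizontal bar should equal the mean total_bill of its day.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed for running as a script
import numpy as np
import pandas as pd
import seaborn as sns

# hypothetical toy data: several bills per day
toy = pd.DataFrame({
    'day': ['Thur', 'Thur', 'Fri', 'Fri', 'Fri', 'Sat', 'Sat'],
    'total_bill': [10.0, 20.0, 12.0, 18.0, 30.0, 25.0, 35.0],
})

ax = sns.barplot(x='total_bill', y='day', data=toy, orient='h')
bar_lengths = sorted(p.get_width() for p in ax.patches)   # horizontal bars: width = value
group_means = sorted(toy.groupby('day')['total_bill'].mean())
print(bar_lengths, group_means)   # both [15.0, 20.0, 30.0]
```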
5.2.5 Histogram
Example 5.23
mu, sigma = 100, 15
x = mu + sigma * np.random.randn(10000)
y = mu - 30 + sigma * 2 * np.random.randn(10000)
df = pd.DataFrame(np.array([x, y]).T)
sns.histplot(df, bins=100, kde=True)
Please compare this plot with Example 5.11.
5.3 Exercises
Exercise 5.1 Please download the mtcars file from here and read it as a DataFrame. Then create a scatter plot of the drat and wt variables from mtcars and color the dots by the carb variable.
Exercise 5.2 Please read the file from here as a DataFrame. This is the Dining satisfaction with quick service restaurants questionnaire data provided by Dr. Siri McDowall, supported by a DART SEED grant.
- Please pick out all rating columns. Excluding last.visit, visit.again and recommend, compute the mean of the rest and add it to the DataFrame as a new column.
- Use a plot to show the relations among these four columns: last.visit, visit.again, recommend and mean.
- Look at the column Profession. Keep Student, change everything else to Professional, and add the result to the DataFrame as a new column Status.
- Draw the histogram of mean with respect to Status.
- Find the counts of each recommend rating for each Status and draw the barplot. Do the same for last.visit/Status and visit.again/Status.
- Explore the dataset and draw one plot.