3.4. Decision Tree Project 2: make_moons dataset

sklearn includes various random sample generators that can be used to build artificial datasets of controlled size and complexity. We are going to use make_moons in this section. More details can be found in the scikit-learn documentation.

make_moons generates a 2D binary classification dataset that is challenging for certain algorithms (e.g. centroid-based clustering or linear classification), with optional Gaussian noise. It produces two interleaving half circles, which makes it useful for visualization.

Let us explore the dataset first.

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt

X, y = make_moons(n_samples=10000, noise=0.4, random_state=42)
plt.scatter(x=X[:, 0], y=X[:, 1], c=y)
[Figure: scatter plot of the 10,000 make_moons samples, colored by class]
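As a quick sanity check (not part of the original notebook), we can confirm the shape of the data and that the two classes are balanced.

import numpy as np

print(X.shape, y.shape)   # (10000, 2) and (10000,)
print(np.bincount(y))     # roughly 5000 samples per class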

Now we apply sklearn.tree.DecisionTreeClassifier to construct the decision tree. The steps are as follows.

  1. Split the dataset into training data and test data.

  2. Construct the pipeline. Since we won’t apply any transformers for this problem, we can just use the classifier sklearn.tree.DecisionTreeClassifier directly without actually constructing a pipeline object (a sketch of what such a pipeline would look like is given after this list).

  3. Set up the hyperparameter space for the grid search. For this problem we choose min_samples_split and max_leaf_nodes as the hyperparameters to tune. We let min_samples_split run through 2 to 4 and max_leaf_nodes run through 2 to 49, and use GridSearchCV to find the best hyperparameters for our model. For cross-validation, the number of splits is set to 3, which means that training runs 3 times for each pair of hyperparameters.

  4. Run the grid search. Find the best hyperparameters and the best estimator. Test it on the test set to get the accuracy score.
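For reference, here is a minimal sketch of what the pipeline mentioned in step 2 could look like if we wanted one (for example, to prepend a scaler later); the step name 'clf' and the variable names are illustrative choices, not part of the original steps.

from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier

# A one-step pipeline wrapping the classifier; a transformer could be
# prepended later without changing the rest of the workflow.
pipe = Pipeline([('clf', DecisionTreeClassifier(random_state=42))])

# With a pipeline, grid-search parameter names are prefixed by the step name.
pipe_params = {'clf__min_samples_split': list(range(2, 5)),
               'clf__max_leaf_nodes': list(range(2, 50))}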

# Step 1
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Step 3 (Step 2 is skipped since no pipeline is needed)
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier
import numpy as np

params = {'min_samples_split': list(range(2, 5)),
          'max_leaf_nodes': list(range(2, 50))}
grid_search_cv = GridSearchCV(DecisionTreeClassifier(random_state=42), 
                              params, verbose=1, cv=3)
grid_search_cv.fit(X_train, y_train)
Fitting 3 folds for each of 144 candidates, totalling 432 fits
GridSearchCV(cv=3, estimator=DecisionTreeClassifier(random_state=42),
             param_grid={'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                                            13, 14, 15, 16, 17, 18, 19, 20, 21,
                                            22, 23, 24, 25, 26, 27, 28, 29, 30,
                                            31, ...],
                         'min_samples_split': [2, 3, 4]},
             verbose=1)
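Before moving on to step 4, we can peek at the cross-validation results. best_score_ and cv_results_ are standard GridSearchCV attributes; the use of pandas below is just for convenient display and assumes it is installed.

import pandas as pd

# Mean cross-validated accuracy of the best hyperparameter pair
print(grid_search_cv.best_score_)

# Top five hyperparameter combinations ranked by mean test score
cv_results = pd.DataFrame(grid_search_cv.cv_results_)
cols = ['param_max_leaf_nodes', 'param_min_samples_split', 'mean_test_score']
print(cv_results[cols].sort_values('mean_test_score', ascending=False).head())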
# Step 4
from sklearn.metrics import accuracy_score

clf = grid_search_cv.best_estimator_
print(grid_search_cv.best_params_)
y_pred = clf.predict(X_test)
accuracy_score(y_test, y_pred)
{'max_leaf_nodes': 17, 'min_samples_split': 2}
0.8695

Now you can see that for this make_moons dataset, the best decision tree has at most 17 leaf nodes and requires at least 2 samples to split an internal node (min_samples_split=2). The fitted decision tree achieves 86.95% accuracy on the test set.
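To see why limiting max_leaf_nodes helps on this noisy dataset, we can additionally train an unrestricted tree for comparison (this is an extra check, not one of the original steps); it fits the training set almost perfectly but should score noticeably lower on the test set than the tuned tree.

# An unrestricted tree keeps splitting until its leaves are (nearly) pure,
# so it memorizes the noise in the training data.
deep_tree = DecisionTreeClassifier(random_state=42)
deep_tree.fit(X_train, y_train)

print(deep_tree.score(X_train, y_train))  # close to 1.0
print(deep_tree.score(X_test, y_test))    # expected to be below the tuned tree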

Now we can plot the decision tree and the decision surface.

from sklearn import tree
plt.figure(figsize=(15, 15), dpi=300)
tree.plot_tree(clf, filled=True)
[Figure: the fitted decision tree drawn with tree.plot_tree; each of its 17 leaves and internal nodes is annotated with its split rule, Gini impurity, sample count, and class counts]
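If the rendered figure is hard to read, sklearn.tree.export_text gives a plain-text view of the same tree; the feature names x0 and x1 are just illustrative labels for the two coordinates, and max_depth here only limits how much of the tree is printed.

from sklearn.tree import export_text

# Print the first few levels of the fitted tree as indented if/else rules.
print(export_text(clf, feature_names=['x0', 'x1'], max_depth=3))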
from sklearn.inspection import DecisionBoundaryDisplay

DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    cmap=plt.cm.RdYlBu,
    response_method="predict"
)
plt.scatter(
    X[:, 0],
    X[:, 1],
    c=y,
    cmap='gray',
    edgecolor="black",
    s=15,
    alpha=.15)
[Figure: decision surface of the fitted tree with the training samples overlaid]

Since the boundary is hard to see with the data points overlaid, we draw the decision surface by itself below.

DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    cmap=plt.cm.RdYlBu,
    response_method="predict"
)
[Figure: decision surface of the fitted tree alone]
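As a final variation (not in the original notebook), DecisionBoundaryDisplay also accepts response_method="predict_proba" for binary classifiers, which shades the plane by the predicted probability of the positive class instead of the hard 0/1 prediction and makes the confidence of each region visible.

# Shade each region by the predicted probability of class 1; for a decision
# tree the probabilities come from the class proportions in each leaf.
DecisionBoundaryDisplay.from_estimator(
    clf,
    X,
    cmap=plt.cm.RdYlBu,
    response_method="predict_proba"
)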