3.2. CART Algorithms
3.2.1. Ideas
Consider a labeled dataset \(S\) with \(m\) elements in total. We use a feature \(k\) and a threshold \(t_k\) to split it into two subsets: \(S_l\) with \(m_l\) elements and \(S_r\) with \(m_r\) elements. Then the cost function of this split is
\[J(k, t_k)=\frac{m_l}{m}Gini(S_l)+\frac{m_r}{m}Gini(S_r).\]
It is not hard to see that the purer the two subsets are, the lower the cost function is. Therefore our goal is to find a split that minimizes the cost function.
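As a quick numerical illustration, here is a worked example with made-up counts. It assumes the cost is the size-weighted average of the Gini impurities of the two subsets, with \(Gini = 1-\sum_i p_i^2\), where \(p_i\) is the fraction of elements in class \(i\). Suppose \(S\) has \(m=10\) elements, 5 of class A and 5 of class B, and a split produces \(S_l\) with 4 elements (all of class A) and \(S_r\) with 6 elements (1 of class A, 5 of class B).

```latex
\begin{aligned}
Gini(S)   &= 1-\left(\tfrac{5}{10}\right)^2-\left(\tfrac{5}{10}\right)^2 = \tfrac{1}{2},\\
Gini(S_l) &= 1-\left(\tfrac{4}{4}\right)^2 = 0,\\
Gini(S_r) &= 1-\left(\tfrac{1}{6}\right)^2-\left(\tfrac{5}{6}\right)^2 = \tfrac{5}{18},\\
J         &= \tfrac{4}{10}\cdot 0+\tfrac{6}{10}\cdot\tfrac{5}{18} = \tfrac{1}{6}\approx 0.167.
\end{aligned}
```

The cost \(1/6\) is well below the impurity \(1/2\) of the unsplit set, which reflects that one of the two subsets is already pure.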
(Split the Dataset)
Inputs Given a labeled dataset \(S=\{[features, label]\}\).
Outputs A best split \((k, t_k)\).
1. For each feature \(k\):
    1. For each value \(t\) of the feature:
        1. Split the dataset \(S\) into two subsets, one where the value of feature \(k\) is \(\leq t\) and one where it is \(>t\).
        2. Compute the cost function \(J(k,t)\).
        3. Compare it with the current smallest cost. If this split has a smaller cost, update the current smallest cost and record the pair \((k, t)\).
2. Return the pair \((k,t_k)\) that has the smallest cost function.
We then use this split algorithm recursively to get the decision tree.
(Classification and Regression Tree, CART)
Inputs Given a labeled dataset \(S=\{[features, label]\}\) and a maximal depth max_depth.
Outputs A decision tree.
1. Start from the original dataset \(S\) and set the working dataset \(G=S\).
2. Consider the working dataset \(G\). If \(Gini(G)\neq0\), split \(G\) into \(G_l\) and \(G_r\) to minimize the cost function. Record the split pair \((k, t_k)\).
3. Now set the working dataset to \(G_l\) and to \(G_r\) in turn, and apply the above two steps to each of them.
4. Repeat the above steps until every subset is pure or max_depth is reached.
Here is the sample code.
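def split(G):
    # G is a 2D NumPy array: feature columns followed by one label column.
    # gini() is assumed to return the Gini impurity of such an array.
    m = G.shape[0]
    gmini = gini(G)
    pair = None
    Glm, Grm = None, None
    if gmini != 0:
        numOffeatures = G.shape[1] - 1
        for k in range(numOffeatures):
            for t in range(m):
                # Use the value of feature k in row t as the threshold.
                Gl = G[G[:, k] <= G[t, k]]
                Gr = G[G[:, k] > G[t, k]]
                ml = Gl.shape[0]
                mr = Gr.shape[0]
                if mr == 0:
                    # Nothing is separated by this threshold.
                    continue
                gl = gini(Gl)
                gr = gini(Gr)
                # Weighted cost of the split.
                g = gl*ml/m + gr*mr/m
                if g < gmini:
                    gmini = g
                    pair = (k, G[t, k])
                    Glm = Gl
                    Grm = Gr
    if pair is not None:
        res = {'split': True,
               'pair': pair,
               'sets': (Glm, Grm)}
    else:
        # Either G is already pure or no split lowers the cost.
        res = {'split': False,
               'pair': pair,
               'sets': G}
    return res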
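The `split` function calls a `gini` helper that is not shown in this section. The following is a minimal sketch of what such a helper could look like, assuming the label sits in the last column of a NumPy array and using the standard definition of Gini impurity. Returning 0 for an empty subset is a convention chosen here, and the toy array is made up for illustration. The last few lines reproduce one iteration of the inner loop of `split` by hand.

```python
import numpy as np

def gini(G):
    """Gini impurity of the labels stored in the last column of G."""
    m = G.shape[0]
    if m == 0:
        return 0.0  # convention chosen here for an empty subset
    # Count how often each distinct label occurs.
    _, counts = np.unique(G[:, -1], return_counts=True)
    p = counts / m
    return 1.0 - np.sum(p ** 2)

# Toy dataset: one feature column followed by one label column.
G = np.array([[1.0, 0],
              [2.0, 0],
              [3.0, 1],
              [4.0, 1]])

# One iteration of the inner loop: feature k = 0, threshold from row t = 1.
k, t = 0, 1
Gl = G[G[:, k] <= G[t, k]]
Gr = G[G[:, k] > G[t, k]]
m = G.shape[0]
cost = gini(Gl) * Gl.shape[0] / m + gini(Gr) * Gr.shape[0] / m

# The whole set has impurity 0.5; this split separates the labels, so its cost is 0.0.
print(gini(G), cost)
```

With two balanced classes the unsplit impurity is 0.5, and the threshold 2.0 on feature 0 yields two pure subsets.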
To count the labels in a dataset, we also write a small helper function.
import pandas as pd

def countlabels(S):
    # The label is stored in the last column of S.
    y = S[:, -1].reshape(S.shape[0])
    # Map each label to the number of times it occurs.
    labelCount = dict(pd.Series(y).value_counts())
    return labelCount
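To show how the pieces fit together, here is a rough sketch of the recursive procedure described in the CART algorithm. It is not the implementation used in these notes: `best_split`, `build_tree`, `predict` and the dictionary layout of a node are hypothetical names chosen for this example, and the small `gini` and split helpers are compact stand-ins so the snippet runs on its own. Label counting is done with `np.unique`, which plays the same role as `countlabels`, and each node stores its majority label so that a leaf can make a prediction.

```python
import numpy as np

def gini(G):
    """Gini impurity of the labels in the last column of G."""
    m = G.shape[0]
    if m == 0:
        return 0.0
    _, counts = np.unique(G[:, -1], return_counts=True)
    return 1.0 - np.sum((counts / m) ** 2)

def best_split(G):
    """Compact stand-in for split: return (k, t, Gl, Gr) or None."""
    m = G.shape[0]
    best_cost, best = gini(G), None
    for k in range(G.shape[1] - 1):
        for t in np.unique(G[:, k]):
            Gl, Gr = G[G[:, k] <= t], G[G[:, k] > t]
            if Gr.shape[0] == 0:
                continue  # this threshold separates nothing
            cost = gini(Gl) * Gl.shape[0] / m + gini(Gr) * Gr.shape[0] / m
            if cost < best_cost:
                best_cost, best = cost, (k, t, Gl, Gr)
    return best

def build_tree(G, max_depth, depth=0):
    """Recursively split G until it is pure or max_depth is reached."""
    labels, counts = np.unique(G[:, -1], return_counts=True)
    # Every node remembers its majority label; leaves use it to predict.
    node = {'label': labels[np.argmax(counts)], 'depth': depth}
    if depth >= max_depth or gini(G) == 0:
        return node
    found = best_split(G)
    if found is None:
        return node  # no split lowers the cost
    k, t, Gl, Gr = found
    node['pair'] = (k, t)
    node['left'] = build_tree(Gl, max_depth, depth + 1)
    node['right'] = build_tree(Gr, max_depth, depth + 1)
    return node

def predict(node, x):
    """Follow the recorded split pairs down to a leaf."""
    while 'pair' in node:
        k, t = node['pair']
        node = node['left'] if x[k] <= t else node['right']
    return node['label']

# Made-up dataset: two features followed by a label.
S = np.array([[1.0, 5.0, 0],
              [2.0, 4.0, 0],
              [3.0, 1.0, 1],
              [4.0, 2.0, 1],
              [5.0, 6.0, 2]])
tree = build_tree(S, max_depth=3)
print(tree['pair'], predict(tree, [3.5, 0.0]))
```

Storing the tree as nested dictionaries keeps the sketch short; a small node class would work equally well. Setting `max_depth=0` returns a single leaf that predicts the majority label of the whole dataset.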