
Commit 8fb1eb7

Implementation of a regression tree in Python
I've implemented a basic decision tree in Python as an example of how they work. Although the class I've created only works on one-dimensional data sets, the reader should be able to generalize it to higher dimensions should they need to.
1 parent 3ecb193 commit 8fb1eb7
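A minimal usage sketch (illustrative, not part of the commit itself; it assumes the Decision_Tree class from the file below and NumPy):

import numpy as np

X = np.arange(0.0, 1.0, 0.01)  # 1-dimensional, sorted inputs (as in the demo in main())
y = X ** 2                     # continuous labels
tree = Decision_Tree(depth=5, min_leaf_size=5)
tree.train(X, y)
print(tree.predict(0.5))       # maps a real number input to a real number output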

1 file changed: +136 -0 lines changed
machine_learning/decision_tree.py

"""
Implementation of a basic regression decision tree.
Input data set: The input data set must be 1-dimensional with continuous labels.
Output: The decision tree maps a real number input to a real number output.
"""

import numpy as np


class Decision_Tree:
    def __init__(self, depth=5, min_leaf_size=5):
        self.depth = depth
        self.decision_boundary = 0
        self.left = None
        self.right = None
        self.min_leaf_size = min_leaf_size
        self.prediction = None

    def mean_squared_error(self, labels, prediction):
        """
        mean_squared_error:
        @param labels: a one-dimensional numpy array
        @param prediction: a floating point value
        return value: mean_squared_error calculates the error if prediction is used to estimate the labels
        """
        if labels.ndim != 1:
            print("Error: Input labels must be one dimensional")

        return np.mean((labels - prediction) ** 2)

    def train(self, X, y):
        """
        train:
        @param X: a one-dimensional numpy array
        @param y: a one-dimensional numpy array.
        The contents of y are the labels for the corresponding X values

        train does not have a return value
        """

        # Check that the inputs conform to our dimensionality constraints.
        if X.ndim != 1:
            print("Error: Input data set must be one dimensional")
            return
        if len(X) != len(y):
            print("Error: X and y have different lengths")
            return
        if y.ndim != 1:
            print("Error: Data set labels must be one dimensional")
            return

        if len(X) < 2 * self.min_leaf_size:
            self.prediction = np.mean(y)
            return

        if self.depth == 1:
            self.prediction = np.mean(y)
            return

        best_split = 0
        min_error = self.mean_squared_error(y, np.mean(y)) * 2

        # Loop over all possible splits for the decision tree and find the best one.
        # If no split produces an error less than 2 * the error for the entire array,
        # the data set is not split and the mean of the entire array is used as the predictor.
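        # Note (added for clarity): the index-based split below is only meaningful
        # when X is sorted in ascending order, as it is in the demo in main().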
        for i in range(len(X)):
            if len(X[:i]) < self.min_leaf_size:
                continue
            elif len(X[i:]) < self.min_leaf_size:
                continue
            else:
                error_left = self.mean_squared_error(y[:i], np.mean(y[:i]))
                error_right = self.mean_squared_error(y[i:], np.mean(y[i:]))
                error = error_left + error_right
                if error < min_error:
                    best_split = i
                    min_error = error

        if best_split != 0:
            left_X = X[:best_split]
            left_y = y[:best_split]
            right_X = X[best_split:]
            right_y = y[best_split:]

            self.decision_boundary = X[best_split]
            self.left = Decision_Tree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
            self.right = Decision_Tree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
            self.left.train(left_X, left_y)
            self.right.train(right_X, right_y)
        else:
            self.prediction = np.mean(y)

        return

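    # Illustrative walk-through (not part of the original commit): with sorted
    # X = [1, 2, 3, 4], y = [1.0, 1.1, 3.0, 3.1] and min_leaf_size=2, only i=2
    # satisfies both leaf-size checks; the combined error of the two halves
    # (0.005) beats min_error (2.005), so best_split=2 and the decision
    # boundary becomes X[2] = 3.
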
    def predict(self, x):
        """
        predict:
        @param x: a floating point value to predict the label of
        the prediction function works by recursively calling the predict function
        of the appropriate subtrees based on the tree's decision boundary
        """
        if self.prediction is not None:
            return self.prediction
        elif self.left is not None and self.right is not None:
            if x >= self.decision_boundary:
                return self.right.predict(x)
            else:
                return self.left.predict(x)
        else:
            print("Error: Decision tree not yet trained")
            return None


def main():
    """
    In this demonstration we're generating a sample data set from the sin function in numpy.
    We then train a decision tree on the data set and use the decision tree to predict the
    label of 10 different test values. Then the mean squared error over this test is displayed.
    """
    X = np.arange(-1.0, 1.0, 0.005)
    y = np.sin(X)

    tree = Decision_Tree(depth=10, min_leaf_size=10)
    tree.train(X, y)

    test_cases = (np.random.rand(10) * 2) - 1
    predictions = np.array([tree.predict(x) for x in test_cases])
    avg_error = np.mean((predictions - np.sin(test_cases)) ** 2)

    print("Test values: " + str(test_cases))
    print("Predictions: " + str(predictions))
    print("Average error: " + str(avg_error))


if __name__ == '__main__':
    main()
