"""
Implementation of a basic regression decision tree.
Input data set: the input data set must be one-dimensional with continuous labels.
Output: the decision tree maps a real number input to a real number output.
"""

import numpy as np


class Decision_Tree:
    def __init__(self, depth=5, min_leaf_size=5):
        self.depth = depth
        self.decision_boundary = 0
        self.left = None
        self.right = None
        self.min_leaf_size = min_leaf_size
        self.prediction = None

    def mean_squared_error(self, labels, prediction):
        """
        mean_squared_error:
        @param labels: a one-dimensional numpy array
        @param prediction: a floating point value
        return value: the mean squared error incurred when prediction is used to estimate the labels
        """
        if labels.ndim != 1:
            print("Error: Input labels must be one dimensional")

        return np.mean((labels - prediction) ** 2)
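
    # As a quick hand-worked sanity check (a sketch, not part of the class):
    # for labels np.array([1.0, 2.0, 3.0]) and prediction 2.0, the error is
    # ((1 - 2)**2 + (2 - 2)**2 + (3 - 2)**2) / 3 = 2 / 3 ≈ 0.667.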

    def train(self, X, y):
        """
        train:
        @param X: a one-dimensional numpy array
        @param y: a one-dimensional numpy array.
        The contents of y are the labels for the corresponding X values

        train does not have a return value
        """

        # Check that the inputs conform to our dimensionality constraints.
        if X.ndim != 1:
            print("Error: Input data set must be one dimensional")
            return
        if len(X) != len(y):
            print("Error: X and y have different lengths")
            return
        if y.ndim != 1:
            print("Error: Data set labels must be one dimensional")
            return

        # A split must leave at least min_leaf_size samples on each side, so a
        # node with fewer than 2 * min_leaf_size samples becomes a leaf.
        if len(X) < 2 * self.min_leaf_size:
            self.prediction = np.mean(y)
            return

        # At maximum depth, stop splitting and store the mean label.
        if self.depth == 1:
            self.prediction = np.mean(y)
            return

        best_split = 0
        # Baseline: twice the error of predicting the mean label for the whole node.
        min_error = self.mean_squared_error(y, np.mean(y)) * 2

        # Loop over all possible splits and find the best one. If no split gives
        # an error below 2 * the error for the entire array, the data set is not
        # split and the mean of the entire array is used as the predictor.
        for i in range(len(X)):
            if len(X[:i]) < self.min_leaf_size:
                continue
            elif len(X[i:]) < self.min_leaf_size:
                continue
            else:
                # Score each candidate split by the sum of the MSEs of the
                # left and right label slices against their respective means.
                error_left = self.mean_squared_error(y[:i], np.mean(y[:i]))
                error_right = self.mean_squared_error(y[i:], np.mean(y[i:]))
                error = error_left + error_right
                if error < min_error:
                    best_split = i
                    min_error = error

        if best_split != 0:
            left_X = X[:best_split]
            left_y = y[:best_split]
            right_X = X[best_split:]
            right_y = y[best_split:]

            self.decision_boundary = X[best_split]
            self.left = Decision_Tree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
            self.right = Decision_Tree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
            self.left.train(left_X, left_y)
            self.right.train(right_X, right_y)
        else:
            self.prediction = np.mean(y)

        return

    def predict(self, x):
        """
        predict:
        @param x: a floating point value to predict the label of
        the prediction function works by recursively calling the predict function
        of the appropriate subtree based on the tree's decision boundary
        """
        if self.prediction is not None:
            return self.prediction
        elif self.left is not None and self.right is not None:
            if x >= self.decision_boundary:
                return self.right.predict(x)
            else:
                return self.left.predict(x)
        else:
            print("Error: Decision tree not yet trained")
            return None
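
    # A minimal traversal sketch (the boundary values below are hypothetical):
    # with a root boundary of 0.0 and a right-child boundary of 0.5, predict(0.7)
    # descends right twice until it reaches a node whose self.prediction is set
    # (a leaf) and returns that stored mean.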


def main():
    """
    In this demonstration we're generating a sample data set from the sin function in numpy.
    We then train a decision tree on the data set and use the decision tree to predict the
    label of 10 different test values. The mean squared error over this test set is then displayed.
    """
    X = np.arange(-1.0, 1.0, 0.005)
    y = np.sin(X)

    tree = Decision_Tree(depth=10, min_leaf_size=10)
    tree.train(X, y)

    test_cases = (np.random.rand(10) * 2) - 1
    predictions = np.array([tree.predict(x) for x in test_cases])
    # Measure error against the true labels sin(x), not the raw inputs.
    avg_error = np.mean((predictions - np.sin(test_cases)) ** 2)

    print("Test values: " + str(test_cases))
    print("Predictions: " + str(predictions))
    print("Average error: " + str(avg_error))


if __name__ == '__main__':
    main()
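
# A minimal usage sketch, assuming this file is importable as a module (the
# values below are illustrative, not from the demo above):
#
#     import numpy as np
#     tree = Decision_Tree(depth=5, min_leaf_size=5)
#     X = np.arange(0.0, 1.0, 0.01)
#     tree.train(X, X ** 2)      # fit to y = x^2 on a sorted grid
#     print(tree.predict(0.5))   # should be close to 0.25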