"""
Implementation of a basic regression decision tree.
Input data set: the input data set must be one-dimensional with continuous labels.
Output: the decision tree maps a real number input to a real number output.
"""

import numpy as np

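# A minimal usage sketch (illustrative only; mirrors the demonstration in
# main() below):
#
#     tree = Decision_Tree(depth=5, min_leaf_size=5)
#     tree.train(X, y)    # X, y: equal-length one-dimensional numpy arrays,
#                         # with X sorted ascending (assumed by the split logic)
#     tree.predict(0.25)  # returns a float
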
class Decision_Tree:
    def __init__(self, depth=5, min_leaf_size=5):
        self.depth = depth                  # maximum remaining depth of this subtree
        self.decision_boundary = 0          # threshold on x used to route predictions
        self.left = None                    # subtree for x < decision_boundary
        self.right = None                   # subtree for x >= decision_boundary
        self.min_leaf_size = min_leaf_size  # minimum number of samples in a leaf
        self.prediction = None              # mean label, set only at leaf nodes

    def mean_squared_error(self, labels, prediction):
        """
        mean_squared_error:
        @param labels: a one-dimensional numpy array
        @param prediction: a floating point value
        return value: the mean squared error incurred when prediction is used
        to estimate the labels
        """
        if labels.ndim != 1:
            print("Error: Input labels must be one dimensional")

        return np.mean((labels - prediction) ** 2)
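
    # Worked example (illustrative, not part of the original code): for
    # labels np.array([1.0, 2.0, 3.0]) and prediction 2.0, the squared
    # deviations are [1.0, 0.0, 1.0], so the method returns 2/3.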

    def train(self, X, y):
        """
        train:
        @param X: a one-dimensional numpy array, assumed to be sorted in
        ascending order so that an index split corresponds to a threshold on X
        @param y: a one-dimensional numpy array.
        The contents of y are the labels for the corresponding X values

        train does not have a return value
        """

        # Check that the inputs conform to our dimensionality constraints.
        if X.ndim != 1:
            print("Error: Input data set must be one dimensional")
            return
        if len(X) != len(y):
            print("Error: X and y have different lengths")
            return
        if y.ndim != 1:
            print("Error: Data set labels must be one dimensional")
            return

        if len(X) < 2 * self.min_leaf_size:
            self.prediction = np.mean(y)
            return

        if self.depth == 1:
            self.prediction = np.mean(y)
            return

        best_split = 0
        # The baseline is the error of the unsplit node, measured on the
        # labels y rather than on the inputs X.
        min_error = self.mean_squared_error(y, np.mean(y)) * 2

        # Loop over all possible splits for the decision tree and find the
        # best one. If no split achieves an error below 2 * the error of the
        # entire array, the data set is not split and the mean of the entire
        # array is used as the predictor.
        for i in range(len(X)):
            if len(X[:i]) < self.min_leaf_size:
                continue
            elif len(X[i:]) < self.min_leaf_size:
                continue
            else:
                error_left = self.mean_squared_error(y[:i], np.mean(y[:i]))
                error_right = self.mean_squared_error(y[i:], np.mean(y[i:]))
                error = error_left + error_right
                if error < min_error:
                    best_split = i
                    min_error = error

        if best_split != 0:
            left_X = X[:best_split]
            left_y = y[:best_split]
            right_X = X[best_split:]
            right_y = y[best_split:]

            self.decision_boundary = X[best_split]
            self.left = Decision_Tree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
            self.right = Decision_Tree(depth=self.depth - 1, min_leaf_size=self.min_leaf_size)
            self.left.train(left_X, left_y)
            self.right.train(right_X, right_y)
        else:
            self.prediction = np.mean(y)

        return
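
    # Note on cost (an observation, not original commentary): each candidate
    # split above recomputes np.mean and the squared error over both halves,
    # so training a single node is O(n^2) in the number of samples. That is
    # fine for this demo, but a production tree would keep running sums.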

    def predict(self, x):
        """
        predict:
        @param x: a floating point value to predict the label of
        the prediction function works by recursively calling the predict
        function of the appropriate subtree based on the tree's decision
        boundary
        """
        if self.prediction is not None:
            return self.prediction
        elif self.left is not None and self.right is not None:
            if x >= self.decision_boundary:
                return self.right.predict(x)
            else:
                return self.left.predict(x)
        else:
            print("Error: Decision tree not yet trained")
            return None
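
    # Traversal example (hypothetical values for illustration): if the root's
    # decision_boundary is 0.0, predict(0.3) descends into self.right and
    # predict(-0.2) into self.left, recursing until a node with a stored
    # prediction (a leaf) returns its mean label.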

def main():
    """
    In this demonstration we're generating a sample data set from the sin function in numpy.
    We then train a decision tree on the data set and use the decision tree to predict the
    label of 10 different test values. The mean squared error over this test set is then displayed.
    """
    X = np.arange(-1.0, 1.0, 0.005)
    y = np.sin(X)

    tree = Decision_Tree(depth=10, min_leaf_size=10)
    tree.train(X, y)

    test_cases = (np.random.rand(10) * 2) - 1
    predictions = np.array([tree.predict(x) for x in test_cases])
    # Compare predictions against the true labels sin(x), not the raw inputs.
    avg_error = np.mean((predictions - np.sin(test_cases)) ** 2)

    print("Test values: " + str(test_cases))
    print("Predictions: " + str(predictions))
    print("Average error: " + str(avg_error))


if __name__ == '__main__':
    main()