Skip to content

Commit 4590363

Browse files
hrishi84cclauss
authored andcommitted
Added Pytests for Decission Tree mean_squared_error method (TheAlgorithms#1374)
* Added Pytests for Decission Tree Modified the mean_squared_error to be a static method Created the Test_Decision_Tree class Consists of two methods 1. helper_mean_squared_error_test: This method calculates the mean squared error manually without using numpy. Instead a for loop is used for the same. 2. test_one_mean_squared_error: This method considers a simple test case and compares the results by the helper function and the original mean_squared_error method of Decision_Tree class. This is done using asert keyword. Execution: PyTest installation pip3 install pytest OR pip install pytest Test function execution pytest decision_tree.py * Modified the pytests to be compatible with the doctest Added 2 doctest in the mean_squared_error method For its verification a static method helper_mean_squared_error(labels, prediction) is used It uses a for loop to calculate the error instead of the numpy inbuilt methods Execution ``` pytest .\decision_tree.py --doctest-modules ```
1 parent 179284a commit 4590363

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

machine_learning/decision_tree.py

+32
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ def mean_squared_error(self, labels, prediction):
2121
@param labels: a one dimensional numpy array
2222
@param prediction: a floating point value
2323
return value: mean_squared_error calculates the error if prediction is used to estimate the labels
24+
>>> tester = Decision_Tree()
25+
>>> test_labels = np.array([1,2,3,4,5,6,7,8,9,10])
26+
>>> test_prediction = np.float(6)
27+
>>> assert tester.mean_squared_error(test_labels, test_prediction) == Test_Decision_Tree.helper_mean_squared_error_test(test_labels, test_prediction)
28+
>>> test_labels = np.array([1,2,3])
29+
>>> test_prediction = np.float(2)
30+
>>> assert tester.mean_squared_error(test_labels, test_prediction) == Test_Decision_Tree.helper_mean_squared_error_test(test_labels, test_prediction)
31+
2432
"""
2533
if labels.ndim != 1:
2634
print("Error: Input labels must be one dimensional")
@@ -117,6 +125,27 @@ def predict(self, x):
117125
print("Error: Decision tree not yet trained")
118126
return None
119127

128+
class Test_Decision_Tree:
129+
"""Decision Tres test class
130+
"""
131+
132+
@staticmethod
133+
def helper_mean_squared_error_test(labels, prediction):
134+
"""
135+
helper_mean_squared_error_test:
136+
@param labels: a one dimensional numpy array
137+
@param prediction: a floating point value
138+
return value: helper_mean_squared_error_test calculates the mean squared error
139+
"""
140+
squared_error_sum = np.float(0)
141+
for label in labels:
142+
squared_error_sum += ((label-prediction) ** 2)
143+
144+
return np.float(squared_error_sum/labels.size)
145+
146+
147+
148+
120149

121150
def main():
122151
"""
@@ -141,3 +170,6 @@ def main():
141170

142171
if __name__ == "__main__":
143172
main()
173+
import doctest
174+
175+
doctest.testmod(name="mean_squarred_error", verbose=True)

0 commit comments

Comments
 (0)