diff --git a/machine_learning/gaussian_naive_bayes.py b/machine_learning/gaussian_naive_bayes.py index 77e7326626c4..d54ecb21c119 100644 --- a/machine_learning/gaussian_naive_bayes.py +++ b/machine_learning/gaussian_naive_bayes.py @@ -1,7 +1,13 @@ # Gaussian Naive Bayes Example +import time + +import seaborn as sns # For plotting a heatmap of the confusion matrix. This way will not throw any warning. from matplotlib import pyplot as plt from sklearn.datasets import load_iris -from sklearn.metrics import plot_confusion_matrix +from sklearn.metrics import ( # The plot_confusion_matrix method will be deprecated in the developed versions of Python 3.10.x according to the warning that it throws when this code is run. + accuracy_score, + confusion_matrix, +) from sklearn.model_selection import train_test_split from sklearn.naive_bayes import GaussianNB @@ -25,20 +31,22 @@ def main(): # Gaussian Naive Bayes nb_model = GaussianNB() - nb_model.fit(x_train, y_train) - + model_fit = nb_model.fit(x_train, y_train) + y_pred = model_fit.predict(x_test) # Display Confusion Matrix - plot_confusion_matrix( - nb_model, - x_test, - y_test, - display_labels=iris["target_names"], - cmap="Blues", - normalize="true", - ) - plt.title("Normalized Confusion Matrix - IRIS Dataset") + Conf_Matrix = confusion_matrix(y_true=y_test, y_pred=y_pred) + sns.heatmap(data=Conf_Matrix, annot=True, cmap="Greys_r") plt.show() + # Printing the seen confusion matrix on the console + time.sleep(1.2) + print("The confusion matrix is:\n", Conf_Matrix) + + time.sleep(1.8) + # Declaring the overall accuracy of the model + final_accuracy = 100 * accuracy_score(y_true=y_test, y_pred=y_pred) + print(f"The final accuracy of the model is: {round(final_accuracy, 2)}%") + if __name__ == "__main__": main()