Confusion Matrix in Machine Learning: Everything You Need to Know by@balapriya


Bala Priya C

A self-taught programmer, and a technical writer who authors tutorials, how-to guides, and more to help developers


In Machine Learning, the problem of classification involves predicting the categorical class label to which the query data point belongs. The confusion matrix is a tabular representation of the classification model’s performance.


This tutorial will help you understand the confusion matrix and the various metrics that you can calculate from it.


We’ll start by explaining what classification is, the types of classification problems, and how to interpret the confusion matrix for a binary classification problem.


Let's get started.

Table of Contents

  1. What is Classification?

  2. Types of Classification

    2.1. Binary Classification

    2.2. Multiclass Classification

  3. General Structure of the Confusion Matrix

  4. How to Calculate Evaluation Metrics from Confusion Matrix

    4.1. Accuracy

    4.2. Recall

    4.3. Precision

  5. High Precision vs High Recall - When to Choose What?

  6. Generating the Confusion Matrix in scikit-learn

  7. Generating the Classification Report in scikit-learn


What is Classification?

In essence, classification algorithms aim at answering the question:


“Given labeled training data points, what’s the class label of a previously unseen test (or query) data point?”


A classification problem could be as simple as classifying a given image as that of a cat or a dog.


Or it could be as complex as examining brain scans to detect the presence or absence of tumors.

Types of Classification

Binary Classification

In this tutorial, we’ll focus on the binary classification problem. In binary classification, the class labels 1 and 0 are used.

Suppose you’re given a large dataset of student loans containing features such as the name of the university, tuition, and employment details.

You’d like to predict whether or not a new student with a specific tuition fee and employment status will default on the student loan. Notice how you’re trying to answer the question “Will the student default on the loan?”, and the answer is either ‘Yes’ or ‘No’.

You can think of other examples as well, such as identifying spam emails, where the answer is either ‘Spam’ or ‘Not Spam’.




In these examples,

  • the answers ‘Yes’ and ‘Spam’ indicate the relevant classes, and in practice are encoded as class 1, and
  • the answers ‘No’ and ‘Not Spam’ are encoded as class 0.


Using disease diagnosis as another example, if the problem is to detect the presence of a disease, label 1 indicates that the patient has the disease, and label 0 indicates the absence of the disease.


A classification problem in which each data point belongs to one of two classes is called binary classification. We’ll build on binary classification in this tutorial.

Multiclass Classification

You can also have classification problems with more than two classes. This is called multiclass classification.


For instance, classifying an email as ‘Spam’ or ‘Not Spam’ is a binary classification problem, whereas categorizing emails as ‘School’, ‘Work’ or ‘Personal’ is a multiclass classification problem.




Now that you’ve gained an understanding of the types of classification, let’s proceed to understand the confusion matrix.


General Structure of the Confusion Matrix

The general structure of the confusion matrix for binary classification is shown below:


Confusion Matrix for Binary Classification (Image by the author)



Let’s now define a few terms:

  • True Positive (TP): When the actual label is 1, and the classifier also predicted the label to be 1
  • False Positive (FP): When the actual label is 0, but the classifier falsely predicted it to be 1
  • True Negative (TN): When the actual label is 0, and the classifier also predicted the label to be 0
  • False Negative (FN): When the actual label is 1, but the classifier predicted the label to be 0
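To make the four terms concrete, here is a minimal sketch that counts them in plain Python. The lists `y_true` and `y_pred` and the helper `confusion_counts` are hypothetical, made up for illustration, and are not taken from the tutorial's own code.

```python
# Hypothetical labels for illustration only: 1 = positive class, 0 = negative class.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]  # actual labels
y_pred = [1, 0, 0, 1, 1, 0, 1, 0]  # labels predicted by some classifier


def confusion_counts(y_true, y_pred):
    """Count TP, FP, TN, FN for binary labels encoded as 1 and 0."""
    tp = fp = tn = fn = 0
    for actual, predicted in zip(y_true, y_pred):
        if actual == 1 and predicted == 1:
            tp += 1  # actual 1, predicted 1
        elif actual == 0 and predicted == 1:
            fp += 1  # actual 0, falsely predicted 1
        elif actual == 0 and predicted == 0:
            tn += 1  # actual 0, predicted 0
        else:
            fn += 1  # actual 1, falsely predicted 0
    return tp, fp, tn, fn


tp, fp, tn, fn = confusion_counts(y_true, y_pred)
print(f"TP={tp}, FP={fp}, TN={tn}, FN={fn}")  # TP=3, FP=1, TN=3, FN=1
```

Every prediction falls into exactly one of the four cells, so the four counts always add up to the total number of predictions.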


Let’s now head over to the next section to understand the evaluation metrics for classification.


You’ll learn them by asking questions and following up with answers—and the answers explain what the metric signifies.

How to Calculate Evaluation Metrics from Confusion Matrix

Accuracy

Accuracy answers the question:


“How often is the model correct?”


The model is correct whenever it predicts class 1 for a data point whose actual label is 1, or class 0 for a data point whose actual label is 0. So the number of correct predictions is the number of times the classifier correctly predicted class 1, plus the number of times it correctly predicted class 0.


Looking at the matrix above, that’s the count of True Positives (TP) + True Negatives (TN). And the total number of predictions is the sum of the counts in all 4 quadrants.


This leads to the formula for accuracy given below:


Accuracy = (TP + TN) / (Total Predictions)

where Total Predictions = TP + TN + FP + FN
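As a quick check of the accuracy formula, here is a small sketch in Python. The function name `accuracy` and the counts passed to it are assumptions chosen for illustration. Note the parentheses around the numerator: the sum of TP and TN is divided by the total, not TN alone.

```python
def accuracy(tp, tn, fp, fn):
    """Accuracy = (TP + TN) / (TP + TN + FP + FN)."""
    total_predictions = tp + tn + fp + fn
    return (tp + tn) / total_predictions


# Hypothetical counts: 3 true positives, 3 true negatives,
# 1 false positive, 1 false negative.
score = accuracy(tp=3, tn=3, fp=1, fn=1)
print(score)  # 6 correct out of 8 predictions -> 0.75
```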


At first glance, accuracy may seem like a good metric for evaluation. However, it is not a reliable metric when you have an imbalance in the class labels.


Suppose you’re designing a model to predict if a person has a particular medical condition that is rare—say, it affects only 0.5% of the population.


So in a population of 1000 people, about 5 people will likely have the disease. You clearly have a class imbalance in this case! The majority class is class 0 indicating that the person doesn’t have that particular medical condition.


In this case, a naïve model that predicts the majority class all the time will be 99.5% accurate. However, such a model clearly isn't very helpful.


Can you see why this is the case? The confusion matrix for this example will look like this:

Confusion Matrix for 1000 predictions (Image by the author)



  • You’re making 1000 predictions, and for all of them the predicted label is class 0.

  • 995 of them are correct (True Negatives!).

  • 5 of them are wrong: these are the people who actually have the condition (False Negatives).

  • The accuracy score still works out to 995/1000 = 0.995.
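The numbers in this example can be reproduced with a short sketch. It assumes a population of exactly 1000 people with 5 positive cases, as in the example above. The variable names are hypothetical.

```python
# 5 people have the condition (class 1), 995 do not (class 0).
y_true = [1] * 5 + [0] * 995

# The naive model predicts the majority class (0) for everyone.
y_pred = [0] * len(y_true)

pairs = list(zip(y_true, y_pred))
tp = sum(1 for actual, predicted in pairs if actual == 1 and predicted == 1)
fp = sum(1 for actual, predicted in pairs if actual == 0 and predicted == 1)
tn = sum(1 for actual, predicted in pairs if actual == 0 and predicted == 0)
fn = sum(1 for actual, predicted in pairs if actual == 1 and predicted == 0)

naive_accuracy = (tp + tn) / len(y_true)

print(tp, fp, tn, fn)   # 0 0 995 5
print(naive_accuracy)   # 0.995
```

The model never finds a single positive case, yet its accuracy is 0.995 because the 995 True Negatives dominate the count.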


To sum up, imbalanced class labels distort accuracy scores, and make the model appear to perform better than it actually does.


Examples of problems with imbalanced class labels include:

  • Credit card transactions that are potentially fraudulent
  • A medical condition that affects a very small fraction of the total population


If the minority class makes up a fraction p of the data, a model that predicts the majority class all the time will have an accuracy score of 1 - p. In the example above, p = 0.005 and the accuracy score is 0.995.


As you might have guessed by now, the error rate is 1 - accuracy score.

Instead of saying “My model is correct 98% of the time”, if you’d like to say “My model is wrong 2% of the time”, then you’re talking error rates!
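Here is a small sketch of these two relationships. The helper names `majority_baseline_accuracy` and `error_rate`, and the minority fractions in the loop, are hypothetical and chosen only for illustration.

```python
def majority_baseline_accuracy(minority_fraction):
    """Accuracy of a model that always predicts the majority class.

    minority_fraction is a fraction between 0 and 1 (0.005 means 0.5%).
    """
    return 1 - minority_fraction


def error_rate(accuracy_score):
    """Error rate = 1 - accuracy."""
    return 1 - accuracy_score


# Hypothetical minority-class fractions: 0.5%, 2%, and 10% of the data.
for p in (0.005, 0.02, 0.10):
    acc = majority_baseline_accuracy(p)
    print(f"minority fraction = {p:.3f}, "
          f"baseline accuracy = {acc:.3f}, "
          f"error rate = {error_rate(acc):.3f}")
```

The rarer the minority class, the higher the accuracy of this do-nothing baseline, which is why it helps to compare any model's accuracy against it.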


So it’s now time to learn about other metrics that are more useful in measuring a model’s performance.

Recall

Recall answers the question:


“When it actually is a positive case, how often is the model correct? Or, what fraction of the positive labels does the model predict correctly?”


In essence, it’s the fraction of relevant cases that the model has found.


Now, go back to the confusion matrix and look up the Actual row to identify which predictions correspond to an actually positive label—that is, class 1.


Calculating Recall (Image by the author)



As you can see, the number of actually positive cases is the TP + FN count.


The number of times the model got it right is the TP count. So here’s our formula for recall:


Recall = TP/ (TP + FN)


Our previous model for disease detection did not identify any positive cases—so the TP count = 0. And that leaves us with a recall score of 0. So the model has a recall score of 0 even though its accuracy score is 0.995.
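The recall calculation can be sketched as follows. The function name `recall` is hypothetical, and returning `None` when there are no actual positives is an assumption made for this sketch, since the ratio is undefined in that case.

```python
def recall(tp, fn):
    """Recall = TP / (TP + FN), the fraction of actual positives found."""
    actual_positives = tp + fn
    if actual_positives == 0:
        # Assumption for this sketch: report None when recall is undefined.
        return None
    return tp / actual_positives


# The naive disease-detection model: 0 true positives, 5 false negatives.
print(recall(tp=0, fn=5))  # 0.0

# A hypothetical model that finds 3 of 4 positive cases.
print(recall(tp=3, fn=1))  # 0.75
```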

Precision

Precision answers the following question:


“When the prediction is positive, how often is it correct?”


Once again, go back to the confusion matrix and look up under the Predicted column to identify which predictions correspond to a predicted positive label. And it’s the TP + FP count, as shown below:


Calculating Precision (Image by author)



Here’s our formula for precision:


Precision = TP/ (TP + FP)
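A matching sketch for precision is given below. The function name `precision` is hypothetical. A model that never predicts the positive class, like the naive disease-detection model, has TP + FP = 0, so its precision is undefined. Returning `None` in that case is an assumption made for this sketch.

```python
def precision(tp, fp):
    """Precision = TP / (TP + FP), the fraction of positive predictions that are correct."""
    predicted_positives = tp + fp
    if predicted_positives == 0:
        # Assumption for this sketch: report None when precision is undefined.
        return None
    return tp / predicted_positives


# A hypothetical model with 3 true positives and 1 false positive.
print(precision(tp=3, fp=1))  # 0.75

# The naive model never predicts class 1, so there is nothing to divide by.
print(precision(tp=0, fp=0))  # None
```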


In practice, you’ll often hear people talk about the Precision-Recall Trade-off.

This means you usually cannot maximize both precision and recall at the same time: improving one tends to lower the other. So you’ll have to prioritize one over the other, depending on the problem at hand.


Let’s discuss that in the next section.

High Precision vs High Recall - When to Choose What?

For the problem that you’re solving, ask yourself the question: Which is worse - a False Positive (FP) or a False Negative (FN)?


If you cannot have a False Negative (FN) – Maximize recall

If you cannot have a False Positive (FP) – Maximize precision


Let's revisit the previous examples of disease detection and spam detection.


In which of the above cases would you prefer a higher recall?


Well, you probably guessed it right. It’s in the case of disease detection that you cannot afford to have a False Negative—therefore, you’ll need a high recall.


Why?🤔

You would rather misclassify a healthy patient as having the disease, which is a False Positive (FP), and follow up with additional medical examinations, than misclassify someone with the disease as healthy, which is a False Negative (FN). In the worst case, a False Negative could cost the person’s life.


📧 Let us now look at the example of spam detection. Here, False Positives (FP) can be costly.


  • Recall that in the problem of spam classification, tagging an email as spam amounts to predicting a positive label.
  • A spam email or two in your inbox does not cost much. But what if an email from a recruiter was misclassified as spam, and you never looked at it?😟
  • You’d lose a potential job opportunity. This is where you should maximize precision.


Not detecting a spam email (False Negative) is not as impactful as predicting a recruiter’s email to be spam (False Positive). So remember to ask yourself the above questions, and choose accordingly.
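One common way this choice shows up in practice is through a decision threshold. The sketch below is not from the tutorial: it assumes a classifier that outputs a score between 0 and 1 for each data point, and predicts class 1 when the score reaches the threshold. The scores, labels, and the helper `precision_recall_at` are all hypothetical.

```python
# Hypothetical actual labels and model scores (higher score = more likely class 1).
y_true = [0, 0, 0, 1, 0, 1, 1, 1]
scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9]


def precision_recall_at(threshold, y_true, scores):
    """Predict class 1 when score >= threshold, then compute precision and recall."""
    y_pred = [1 if score >= threshold else 0 for score in scores]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for actual, predicted in pairs if actual == 1 and predicted == 1)
    fp = sum(1 for actual, predicted in pairs if actual == 0 and predicted == 1)
    fn = sum(1 for actual, predicted in pairs if actual == 1 and predicted == 0)
    precision = tp / (tp + fp) if (tp + fp) > 0 else None
    recall = tp / (tp + fn) if (tp + fn) > 0 else None
    return precision, recall


for threshold in (0.35, 0.55, 0.85):
    precision, recall = precision_recall_at(threshold, y_true, scores)
    print(f"threshold={threshold:.2f}  precision={precision:.2f}  recall={recall:.2f}")
# threshold=0.35  precision=0.80  recall=1.00
# threshold=0.55  precision=1.00  recall=0.75
# threshold=0.85  precision=1.00  recall=0.25
```

With this toy data, a low threshold catches every positive case at the cost of one False Positive, which suits a problem like disease detection. A higher threshold removes the False Positive but misses positive cases, which suits a problem like spam detection.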


That concludes our discussion. It’s time to write some code.⏳


Generating the Confusion Matrix in scikit-learn

Download the code used in this tutorial from my GitHub repo.


Now, let’s see how you can generate the confusion matrix in scikit-learn.

  • You’ll have ground truth labels y_true.
  • And you’ll have the predicted labels y_pred.


Here are the steps:

  1. Generate the arrays

  2. Generate the confusion matrix


To learn more about the implementation of the confusion matrix in sklearn, see the scikit-learn documentation for sklearn.metrics.confusion_matrix.
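The original code for these two steps is not reproduced here, so the following is a minimal sketch of what they might look like. The label arrays are hypothetical values made up for illustration. `confusion_matrix` is the scikit-learn function from `sklearn.metrics`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Step 1: generate the arrays (hypothetical values for illustration).
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])  # ground truth labels
y_pred = np.array([1, 0, 0, 1, 1, 0, 1, 0])  # predicted labels

# Step 2: generate the confusion matrix.
# Rows correspond to actual labels and columns to predicted labels,
# with labels in sorted order (0 first, then 1).
cm = confusion_matrix(y_true, y_pred)
print(cm)

# For binary labels 0 and 1, the flattened matrix reads TN, FP, FN, TP.
tn, fp, fn, tp = cm.ravel()
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")  # TN=3, FP=1, FN=1, TP=3
```

Note that scikit-learn places class 0 in the first row and column, so the True Negatives appear in the top-left cell. This may differ from how a confusion matrix is drawn in a diagram, so check the ordering before reading off the counts.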


Generating the Classification Report in scikit-learn

▶️ Here’s how you can generate the classification report with metrics like accuracy, precision, recall, and F1-score in scikit-learn.
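The original code is not reproduced here, so the following is a minimal sketch under the same assumption of hypothetical label arrays. It uses `classification_report`, `accuracy_score`, `precision_score`, and `recall_score` from `sklearn.metrics`. The `target_names` values are display names chosen for this example.

```python
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    precision_score,
    recall_score,
)

# Hypothetical labels for illustration only.
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 1, 0, 1, 0]

# Text report with precision, recall, F1-score, and support per class,
# plus overall accuracy and averaged scores.
report = classification_report(y_true, y_pred, target_names=["class 0", "class 1"])
print(report)

# The individual metrics for the positive class (label 1) can also be
# computed one at a time.
acc = accuracy_score(y_true, y_pred)
prec = precision_score(y_true, y_pred)
rec = recall_score(y_true, y_pred)
print(f"accuracy={acc:.2f}, precision={prec:.2f}, recall={rec:.2f}")
```

By default, `precision_score` and `recall_score` treat label 1 as the positive class, which matches the encoding used throughout this tutorial.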




To summarize

In this tutorial, you’ve learned:

  • What classification is, and its types,
  • The general structure of the confusion matrix,
  • How to calculate various metrics from the confusion matrix, and
  • When to choose precision over recall and vice-versa.


Congratulations on making it this far!🎉



If you’re looking to get started with Machine Learning, I wish you the very best in your journey! Happy learning and coding.🎉
