
Embeddings for RAG - A Complete Overview

by Shrinivasan Sankar, November 30th, 2024

Too Long; Didn't Read

Embedding is a crucial and fundamental step towards building a Retrieval Augmented Generation (RAG) pipeline. BERT and SBERT are state-of-the-art embedding models. Sentence Transformers is the Python library that implements both models. This article dives deep into both the theory and a hands-on example.

This article starts with the transformer and looks at its shortcomings as an embedding model. It then gives an overview of BERT and dives deep into Sentence BERT (SBERT), which is the state of the art in sentence embeddings for LLMs and RAG pipelines.

Visual Explanation

If you are a visual person like me and would like to watch a visual explanation, please check this video out:

Transformers

Transformers need no introduction. Though they were initially designed for language translation tasks, they are the workhorses behind almost all the LLMs today.


At a high level, they are composed of two blocks: the encoder and the decoder. The encoder block takes in the input and outputs a matrix representation. The decoder block takes in the output of the last encoder layer, together with the tokens generated so far, and produces the output. Both blocks can be composed of several layers; the original transformer has 6 layers in each block.


Each layer is built around multi-headed self-attention. The encoder and the decoder differ in two ways. First, the output of the encoder is fed to each layer of the decoder. Second, the decoder's self-attention layers are masked, so the output at any position can only be influenced by the previous positions.


Each encoder and decoder layer also contains layer normalization and a feed-forward neural network.


Unlike earlier models such as RNNs and LSTMs, which processed tokens one at a time and carried context forward only through a hidden state, transformers relate each token directly to every other token in the sequence. As a result, they capture a lot more context than any previous architecture designed for language processing.
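To make "the context of each token with respect to the entire sequence" concrete, here is a minimal NumPy sketch of single-head scaled dot-product self-attention. The toy sizes and random weights are assumptions for illustration only; a real transformer uses multiple heads and learned weights, and wraps attention with layer norm and feed-forward layers.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention.

    X: (seq_len, d_model) token representations.
    Returns the attended outputs and the (seq_len, seq_len) attention weights.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d_k = K.shape[-1]
    # scores[i, j]: how strongly token i attends to token j (any j, left or right)
    scores = Q @ K.T / np.sqrt(d_k)
    weights = softmax(scores, axis=-1)
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 16, 8            # toy sizes, not the real model's
X = rng.normal(size=(seq_len, d_model))     # stand-in for token embeddings
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape)               # one output vector per token
print(weights.sum(axis=1))     # each row of attention weights sums to 1
```

Every output row is a weighted mix of all value vectors, which is how each token picks up context from the whole sequence in a single step.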

What's Wrong With Transformers?

Transformers are the most successful architectures driving the AI revolution today, so I may be shown the door for pointing out their limitations. However, in the decoder-style transformers behind most LLMs, the self-attention layers are masked so that each token can only attend to the tokens before it. This is what lets them generate text one token at a time, and it is fine for most tasks. But it may not be sufficient for a task like question answering. Let's take the example below.


John came with Milo for the party. Milo had a lot of fun at the party. He is a beautiful, white cat with fur.


Let's say we ask the question, “Did Milo drink at the party with John?” Based only on the first 2 sentences in the example above, the LLM may well answer, “Milo had a lot of fun, which suggests that Milo drank at the party.”


However, a model trained with forward context would also be aware of the 3rd sentence, “He is a beautiful, white cat with fur.” So it would reply, “Milo is a cat, so it is unlikely that he drank at the party.”


Though this is a hypothetical example, you get the idea. In a question-answering task, learning both forward and backward context becomes crucial. This is where the BERT model comes in.

BERT

BERT stands for Bidirectional Encoder Representations from Transformers. As the name suggests, it is based on transformers and incorporates both forward and backward context. Though it was initially published for tasks like question answering and natural language inference, its bidirectional nature gives it the potential to produce powerful embeddings.

BERT Model

BERT is essentially a stack of transformer encoders. The key difference from GPT-style models is that BERT uses bidirectional self-attention, while the GPT transformer uses constrained self-attention, where every token can only attend to the context to its left.
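The difference between bidirectional and constrained (causal) self-attention comes down to a mask applied to the attention scores before the softmax. The following is a minimal NumPy sketch with random toy scores (an assumption for illustration), not code from BERT or GPT.

```python
import numpy as np

def attention_weights(scores, causal):
    """Turn raw attention scores into weights, optionally applying a causal mask."""
    scores = scores.copy()
    if causal:
        n = scores.shape[0]
        # Block every position j > i so token i only sees tokens to its left
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        scores[future] = -np.inf
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))   # toy scores for a 4-token sequence

bidirectional = attention_weights(scores, causal=False)  # BERT-style: sees both sides
causal = attention_weights(scores, causal=True)          # GPT-style: left context only

print(np.round(bidirectional, 2))  # every entry is non-zero
print(np.round(causal, 2))         # upper triangle is all zeros
```

In the causal version the first token can only attend to itself, which is exactly why it cannot "see" a later sentence such as the one revealing that Milo is a cat.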


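Note: sequence vs. sentence. A quick note on terminology to avoid confusion when dealing with the BERT model. In the BERT paper, a “sentence” can be any span of contiguous text rather than an actual linguistic sentence ending in a period, while a “sequence” is the full input token sequence fed to BERT, which may be a single sentence or two sentences packed together.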


To understand BERT, let's take the example of question answering. As question answering involves a minimum of two sentences, BERT is designed to accept a pair of sentences in the format <question-answer>. This requires special tokens. A [CLS] token is placed at the beginning of the sequence, and its final output serves as an aggregate representation for classification tasks. The [SEP] token is then used to separate the question and the answer.


So, a simple input now becomes, [CLS]<question>[SEP]<answer>[SEP] as shown in the below figure.

The two sentences A and B, along with the [CLS] and [SEP] tokens, are tokenized with WordPiece and mapped to token embeddings. As we have two sentences, the model needs additional embeddings to differentiate them. This comes in the form of segment embeddings, which are added together with position embeddings.


Segment embedding shown in green below indicates if the input tokens belong to sentence A or B. Then comes position embedding which indicates the position of each token in the sequence.

Figure taken from the BERT paper showing the input representation of the model.


All three embeddings are summed together and fed to the BERT model which is bidirectional as shown in the earlier figure. It captures not only the forward context but also the backward context before giving us the outputs for each token.
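Here is a minimal sketch of how the [CLS]/[SEP] packing, segment ids, and position ids fit together, and how the three embeddings are summed. Whitespace splitting and the tiny vocabulary stand in for WordPiece, and the random toy tables (hidden size 16 instead of BERT-base's 768) are assumptions for illustration. The real model also applies layer normalization and dropout after the sum.

```python
import numpy as np

def build_bert_input(tokens_a, tokens_b):
    """Pack two token lists as [CLS] A [SEP] B [SEP] with segment and position ids."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0: [CLS], sentence A and its [SEP]; segment 1: sentence B and the final [SEP]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    position_ids = list(range(len(tokens)))
    return tokens, segment_ids, position_ids

# Whitespace splitting and this tiny vocabulary are stand-ins for WordPiece
question = "did milo drink at the party".split()
answer = "milo is a cat".split()
tokens, segment_ids, position_ids = build_bert_input(question, answer)
vocab = {tok: i for i, tok in enumerate(dict.fromkeys(tokens))}
token_ids = [vocab[t] for t in tokens]

# Toy embedding tables (hypothetical sizes; BERT-base uses 768 dims, 512 positions)
rng = np.random.default_rng(0)
hidden, max_len = 16, 32
token_table = rng.normal(size=(len(vocab), hidden))
segment_table = rng.normal(size=(2, hidden))
position_table = rng.normal(size=(max_len, hidden))

# The input representation is the element-wise sum of the three embeddings
input_embeddings = (
    token_table[token_ids] + segment_table[segment_ids] + position_table[position_ids]
)
print(tokens)
print(segment_ids)
print(input_embeddings.shape)   # one summed vector per token
```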

Pre-Training BERT

BERT is pre-trained using two unsupervised tasks:

  • Masked language model (MLM). Here we mask a percentage of the tokens in the sequence and let the model predict the masked tokens. It's also known as the cloze task. In practice, 15% of the tokens are masked for this task.

  • Next Sentence Prediction (NSP). Here, the model predicts whether sentence B actually follows sentence A. When it does, we use the label IsNext, and when it does not, we use the label NotNext. In the paper, B is the actual next sentence 50% of the time and a random sentence from the corpus otherwise.

    Pre-training of the BERT model with NSP and MLM tokens at the output.


As can be seen from the above figure from the paper, the output at the first position (the [CLS] token) is used for the NSP task, and the outputs at the masked positions in the middle are used for the MLM task.
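A minimal sketch of how training examples for the two tasks could be prepared, using the Milo example as a toy corpus. The function names are hypothetical, whitespace splitting stands in for WordPiece, and the masking is simplified: the BERT paper replaces a selected token with [MASK] only 80% of the time (10% a random token, 10% unchanged), which this sketch omits.

```python
import random

def mask_tokens(tokens, rng, mask_rate=0.15):
    """Simplified MLM masking: hide ~15% of tokens and remember the originals."""
    masked, targets = list(tokens), {}
    n_to_mask = max(1, round(len(tokens) * mask_rate))
    for i in rng.sample(range(len(tokens)), n_to_mask):
        targets[i] = tokens[i]      # label the model must predict at position i
        masked[i] = "[MASK]"
    return masked, targets

def make_nsp_pair(sentences, idx, rng):
    """NSP example: the true next sentence half the time, a random one otherwise."""
    if rng.random() < 0.5:
        return sentences[idx], sentences[idx + 1], "IsNext"
    others = [s for j, s in enumerate(sentences) if j not in (idx, idx + 1)]
    return sentences[idx], rng.choice(others), "NotNext"

corpus = [
    "John came with Milo for the party.",
    "Milo had a lot of fun at the party.",
    "He is a beautiful, white cat with fur.",
]
rng = random.Random(0)
tokens = corpus[1].split()          # whitespace split stands in for WordPiece
masked, targets = mask_tokens(tokens, rng)
print(masked, targets)
print(make_nsp_pair(corpus, 0, rng))
```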


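As we are training at the token level, each input token produces an output vector. For MLM, the outputs at the masked positions are classified over the vocabulary, and for NSP, the [CLS] output is classified as IsNext or NotNext. As with any classification task, cross-entropy loss is used to train the model.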
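The sketch below shows how the two cross-entropy losses could be computed from model outputs. The random logits, toy vocabulary size, and chosen masked positions are assumptions for illustration. Only the masked positions contribute to the MLM loss, and the BERT paper describes the pre-training loss as the sum of the MLM and NSP losses.

```python
import numpy as np

def cross_entropy(logits, target_ids):
    """Mean softmax cross-entropy; logits (n, num_classes), target_ids (n,)."""
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(target_ids)), target_ids].mean()

rng = np.random.default_rng(0)
seq_len, vocab_size = 10, 50                          # toy sizes
mlm_logits = rng.normal(size=(seq_len, vocab_size))   # one prediction per position
token_targets = rng.integers(0, vocab_size, size=seq_len)
masked_positions = np.array([2, 7])   # only these positions contribute to the MLM loss

mlm_loss = cross_entropy(mlm_logits[masked_positions], token_targets[masked_positions])

nsp_logits = rng.normal(size=(1, 2))  # from the [CLS] output: IsNext vs NotNext
nsp_loss = cross_entropy(nsp_logits, np.array([0]))   # label 0 = IsNext here

total_loss = mlm_loss + nsp_loss
print(mlm_loss, nsp_loss, total_loss)
```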

What's Wrong With BERT?

While BERT is good at capturing both forward and backward context, it may not be best suited to finding similarities among thousands of sentences. Let's consider the task of finding the most similar pair of sentences in a collection of 10,000 sentences. Retrieval has the same flavor: we want the sentence that is most similar to a query sentence A, and with BERT every candidate has to be scored by feeding it through the network together with A.


To do this, we need to pair every possible combination of 2 sentences from 10,000. That would be n * (n - 1) / 2 = 49,995,000 pairs! Damn, that's quadratic complexity. Each pair has to go through the full BERT network, and according to the SBERT paper, this takes about 65 hours on a V100 GPU.
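The arithmetic is easy to check. The sketch below counts BERT forward passes for the cross-encoder setup versus encoding each sentence once, which is the approach SBERT takes.

```python
from math import comb

n = 10_000
cross_encoder_passes = comb(n, 2)   # n * (n - 1) / 2 sentence pairs, one BERT pass each
bi_encoder_passes = n               # one pass per sentence, then cheap vector math
print(f"{cross_encoder_passes:,}")  # 49,995,000
print(f"{bi_encoder_passes:,}")     # 10,000
```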


Simply said, the BERT model isn’t the best for similarity search. But retrieval and similarity search are at the heart of any RAG pipeline. The solution lies with SBERT.

SBERT — Sentence Level BERT

The limitation of BERT largely stems from its cross-encoder architecture, where we feed two sentences together in one sequence with a [SEP] token in between. If each sentence were instead encoded separately, we could pre-compute the embeddings and use them to compute similarity as and when needed. This is exactly the proposition of Sentence BERT, or SBERT for short.
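Once embeddings are pre-computed, finding the most similar pair is a single matrix product instead of millions of network passes. In this sketch, random vectors stand in for SBERT embeddings (an assumption), and 1,000 sentences are used instead of 10,000 to keep the similarity matrix small in memory.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for SBERT output: one 768-dim embedding per sentence, computed once
embeddings = rng.normal(size=(1_000, 768)).astype(np.float32)

# L2-normalize so that a dot product equals cosine similarity
unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
sim = unit @ unit.T                 # all pairwise similarities in one matrix product
np.fill_diagonal(sim, -np.inf)      # ignore each sentence's similarity with itself

i, j = np.unravel_index(np.argmax(sim), sim.shape)
print(f"most similar pair: ({i}, {j}), cosine = {sim[i, j]:.3f}")

# Retrieval for one query is a matrix-vector product over the stored vectors
query = unit[0]
top5 = np.argsort(-(unit @ query))[:5]   # the query itself ranks first
print(top5)
```

For much larger collections, the same idea is usually applied in chunks or through a vector index rather than a full n × n matrix.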


SBERT introduces the Siamese network to the BERT architecture. The word means twin or closely related.

The meaning of Siamese taken from dictionary.com


So, in SBERT we have the same BERT network connected as “twins.” The model embeds the first sentence and then the second one independently, instead of processing them together as a single concatenated sequence.

Note: It's quite common to draw 2 networks side by side to visualize Siamese networks. But in practice, it's a single network with shared weights taking two different inputs.

SBERT Architecture

Below is a diagram that gives an overview of the SBERT architecture.

The Siamese network architecture with the classification objective for the loss. The outputs U and V from the two branches are concatenated along with their element-wise absolute difference |U - V|.

First, we can notice that SBERT adds a pooling layer right after BERT. BERT produces one 768-dimensional vector per token, so for BERT-base its output can be as large as 512 × 768 (512 being the maximum sequence length). The pooling layer reduces this to a single 1 × 768 vector, giving a fixed-size sentence embedding regardless of sentence length. The default pooling strategy is mean pooling, though using the [CLS] token output or max pooling also works.
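Here is a minimal NumPy sketch of mean and max pooling over BERT's token outputs. Padding positions are excluded using the attention mask. Random vectors stand in for BERT's output, and the sentence lengths are made up for illustration.

```python
import numpy as np

def mean_pool(token_embeddings, attention_mask):
    """Average token vectors, ignoring padding positions (mask == 0)."""
    mask = attention_mask[..., None].astype(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)   # avoid division by zero
    return summed / counts

def max_pool(token_embeddings, attention_mask):
    """Element-wise max over real (non-padding) tokens."""
    masked = np.where(attention_mask[..., None] == 1, token_embeddings, -np.inf)
    return masked.max(axis=1)

rng = np.random.default_rng(0)
batch, seq_len, hidden = 2, 512, 768
token_embeddings = rng.normal(size=(batch, seq_len, hidden))  # stand-in for BERT output
attention_mask = np.zeros((batch, seq_len), dtype=int)
attention_mask[0, :12] = 1    # first sentence: 12 real tokens, the rest is padding
attention_mask[1, :30] = 1    # second sentence: 30 real tokens

sentence_embeddings = mean_pool(token_embeddings, attention_mask)
print(sentence_embeddings.shape)   # one fixed-size 768-dim vector per sentence
```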


Next, let's look at the training approach where SBERT diverges from BERT.

Fine-Tuning

SBERT proposes three objectives to fine-tune the model. Let's look at each of them.


Natural Language Inference (NLI) — Classification Objective

For this objective, SBERT is fine-tuned on the Stanford Natural Language Inference (SNLI) and Multi-Genre NLI (MNLI) datasets. SNLI consists of 570K sentence pairs and MNLI has 430K. Each pair has a premise (P) and a hypothesis (H) with one of 3 labels:


  • Entailment: the premise implies the hypothesis
  • Neutral: both could be true, but the premise neither implies nor contradicts the hypothesis
  • Contradiction: the premise and the hypothesis contradict each other


Given the two sentences P and H, the SBERT model produces two embeddings U and V. These are then concatenated as (U, V, |U - V|).


The concatenated output is used to train SBERT with the classification objective. It is fed to a trainable linear layer with 3 outputs (Entailment, Neutral, and Contradiction), and softmax cross-entropy is used for training, just as for any other classification task.
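The classification head can be sketched in a few lines. The SBERT paper writes it as softmax(W_t (u, v, |u - v|)) with a trainable W_t of shape 3n × k, where k is the number of labels. The random embeddings and weights below are placeholders for illustration, not trained values.

```python
import numpy as np

def softmax_cross_entropy(logits, label):
    logits = logits - logits.max()                  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())
    return -log_probs[label]

LABELS = ["entailment", "neutral", "contradiction"]
rng = np.random.default_rng(0)
n = 768                                # embedding size of u and v
u = rng.normal(size=n)                 # stand-in SBERT embedding of the premise
v = rng.normal(size=n)                 # stand-in SBERT embedding of the hypothesis

features = np.concatenate([u, v, np.abs(u - v)])        # (u, v, |u - v|): 3n values
W = rng.normal(scale=0.01, size=(3 * n, len(LABELS)))   # trainable weights W_t
logits = features @ W
loss = softmax_cross_entropy(logits, LABELS.index("entailment"))
print(logits, loss)
```

The head is only needed during training; at inference time, the pooled embeddings u and v are used directly.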


Sentence Similarity — Regression objective

Instead of concatenating U and V, we directly compute the cosine similarity between the two vectors. As in any standard regression problem, we train with a mean-squared error loss against the gold similarity scores. During inference, the same network can be used directly to compare any two sentences: the cosine similarity of their embeddings gives a score for how similar the two sentences are.
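A minimal sketch of the regression objective on a batch of sentence pairs. The random embeddings stand in for SBERT outputs, and the gold scores are hypothetical labels scaled to the 0 to 1 range.

```python
import numpy as np

def cosine_similarity(a, b):
    """Row-wise cosine similarity between two (batch, dim) arrays."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return (a * b).sum(axis=1)

rng = np.random.default_rng(0)
u = rng.normal(size=(4, 768))          # stand-in embeddings of the first sentences
v = rng.normal(size=(4, 768))          # stand-in embeddings of the second sentences
gold = np.array([0.9, 0.1, 0.5, 0.7])  # hypothetical gold similarity labels

pred = cosine_similarity(u, v)
mse_loss = np.mean((pred - gold) ** 2)  # the quantity minimized during training
print(pred, mse_loss)
```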


Triplet Similarity — Triplet Objective

The triplet objective was popularized in face recognition and has since been adapted to other areas of AI such as text and robotics.


Here 3 inputs are fed to SBERT instead of 2: an anchor, a positive, and a negative. The dataset has to be built accordingly. For example, we can take any text corpus, use two consecutive sentences as the anchor and the positive, and pick a random sentence from a different paragraph as the negative.


A triplet loss is then calculated by comparing the anchor's distance to the positive with its distance to the negative. Training pushes the positive closer to the anchor than the negative by at least a fixed margin.
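In the SBERT paper, the triplet loss is max(||s_a - s_p|| - ||s_a - s_n|| + ε, 0) with Euclidean distance and a margin ε = 1. The sketch below implements that formula in NumPy. The 2-D points are made up so that the effect of the margin is easy to see.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(||a - p|| - ||a - n|| + margin, 0), averaged over the batch."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)   # distance anchor -> positive
    d_neg = np.linalg.norm(anchor - negative, axis=1)   # distance anchor -> negative
    return np.maximum(d_pos - d_neg + margin, 0.0).mean()

a = np.array([[0.0, 0.0]])
p = np.array([[0.5, 0.0]])        # close to the anchor
n_far = np.array([[3.0, 0.0]])    # far from the anchor: margin satisfied
n_near = np.array([[0.0, 1.0]])   # too close to the anchor: loss is positive

print(triplet_loss(a, p, n_far))   # 0.0
print(triplet_loss(a, p, n_near))  # 0.5
```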

With that introduction to BERT and SBERT, let's do a quick hands-on to understand how we can get embeddings of any given sentence(s) using these models.

Hands-on SBERT

Ever since its publication, the official library for SBERT, sentence-transformers, has gained popularity and matured. It is good enough to be used in production RAG use cases, so let's use it out of the box.


To get started, let's install it in a fresh Python environment (the leading ! runs the command from a Jupyter notebook; drop it in a regular terminal).

!pip install sentence-transformers


There are several variants of the SBERT model that we can load from the library. Let's load one of them for illustration.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('bert-base-nli-mean-tokens')


We can simply create a list of sentences and invoke the encode function of the model to create the embeddings. It's that simple!

sentences = [
        "The weather is lovely today.",
        "It's so sunny outside!",
        "He drove to the stadium.",
]
embeddings = model.encode(sentences)
print(embeddings.shape)


We can then compute the pairwise similarity scores between the embeddings with a single line:

similarities = model.similarity(embeddings, embeddings)
print(similarities)


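Note that the similarity of each sentence with itself is 1, as expected, and the two weather-related sentences are far more similar to each other than to the third one: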

tensor([[1.0000, 0.6660, 0.1046],
        [0.6660, 1.0000, 0.1411],
        [0.1046, 0.1411, 1.0000]])
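To see what model.similarity computes, here is a NumPy sketch of a cosine-similarity matrix, which to our understanding is the default metric in recent sentence-transformers versions. Random vectors stand in for the output of model.encode here; passing the real embeddings array from above should reproduce the printed matrix up to floating-point differences.

```python
import numpy as np

def cosine_similarity_matrix(embeddings):
    """Pairwise cosine similarities: normalize rows, then take all dot products."""
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    return unit @ unit.T

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(3, 768))   # stand-in for model.encode(sentences)
similarities = cosine_similarity_matrix(embeddings)
print(np.round(similarities, 4))         # diagonal is 1: each sentence matches itself
```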

Conclusion

Embedding is a crucial and fundamental step in getting a RAG pipeline to work at its best. I hope this was useful and shed some light on what’s going on under the hood whenever we use Sentence Transformers out of the box.


Stay tuned for upcoming articles on RAG and its inner workings coupled with hands-on tutorials too.