
Chatbot Memory: Implement Your Own Algorithm From Scratch

by sebDtSci, November 2nd, 2024

Too Long; Didn't Read

We present an approach to managing short-term memory in chatbots. We use a combination of storage and automatic summarization techniques to optimize conversational context. This approach not only improves the fluidity of interactions but also ensures contextual continuity during long dialogue sessions. The code uses PyTorch and Hugging Face's transformers to manage and compress the conversation history.

Introduction

When you implement your own chatbot, one problem you quickly run into is managing memory over the course of a conversation. Of course, you can use ready-made libraries such as LangChain or Ollama, but what if you want to implement your own algorithm from scratch?


We present here an approach to managing short-term memory in chatbots, using a combination of storage and automatic summarization techniques to optimize conversational context. The method relies on a dynamic memory structure that limits data size while preserving essential information through automatic summaries.


This approach not only improves the fluidity of interactions but also helps maintain contextual continuity during long dialogue sessions. Moreover, running memory management asynchronously can keep these operations from blocking the chatbot's responses.
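The code later in this article runs synchronously. As a minimal sketch of the asynchronous idea, assuming Python 3.9+ for asyncio.to_thread, the blocking compression step can run in a worker thread while the reply is produced on the event loop. The functions slow_compress, generate_reply and handle_turn are hypothetical placeholders, not part of the original code.

```python
import asyncio
import time
from typing import List, Tuple


def slow_compress(history: List[str]) -> List[str]:
    """Hypothetical stand-in for a blocking summarization call."""
    time.sleep(0.1)  # simulate model latency
    return [" ".join(history)[:50]]  # crude "summary": first 50 characters


async def generate_reply(user_input: str) -> str:
    """Hypothetical placeholder for an async call to the chat model."""
    await asyncio.sleep(0)
    return f"echo: {user_input}"


async def handle_turn(history: List[str], user_input: str) -> Tuple[str, List[str]]:
    # Run the blocking compression in a worker thread while the reply coroutine
    # runs on the event loop; neither has to wait for the other to start.
    reply, compressed = await asyncio.gather(
        generate_reply(user_input),
        asyncio.to_thread(slow_compress, list(history)),
    )
    return reply, compressed


reply, memory = asyncio.run(handle_turn(["hello there", "how are you"], "hi"))
print(reply, memory)  # echo: hi ['hello there how are you']
```

Passing a copy of the history (list(history)) keeps the worker thread from reading a list that the event loop might modify at the same time.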

Mathematical Modeling of Conversation Management

In this section, we mathematically formalize the management of conversation memory in the chatbot. The memory is structured as a list of pairs representing the exchanges between the user and the bot.


Don't worry, this part is straightforward: it simply lets us model the basic mechanisms so we can translate them into code.

Conversation Memory Structure

The conversation memory can be defined as an ordered list of pairs (u_i, d_i), where u_i represents the user's input and d_i is the bot's response for the i-th exchange. This list is denoted as C:


C = [(u_1, d_1), (u_2, d_2), \ldots, (u_n, d_n)]


Where n is the total number of exchanges in the current history.
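In Python, C maps naturally onto a list of (user, bot) tuples, with n given by len(C). This is only an illustrative sketch with made-up messages; the class later in this article flattens each pair into a single string instead.

```python
from typing import List, Tuple

# Conversation memory C as an ordered list of (u_i, d_i) pairs
C: List[Tuple[str, str]] = [
    ("Hi, who are you?", "I'm a chatbot."),
    ("What can you do?", "I can answer questions."),
]

n = len(C)       # total number of exchanges in the current history
u_1, d_1 = C[0]  # first exchange: user input u_1 and bot response d_1
print(n, u_1, d_1)
```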

Memory Update

When a new exchange occurs, a new pair (u_{n+1}, d_{n+1}) is appended to the memory. If C already holds the predefined maximum of M_max exchanges, the oldest exchange (u_1, d_1) is removed first, so that the size of C never exceeds M_max:


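C \leftarrow
\begin{cases}
[(u_1, d_1), \ldots, (u_n, d_n), (u_{n+1}, d_{n+1})] & \text{if } |C| < M_{\max} \\
[(u_2, d_2), \ldots, (u_n, d_n), (u_{n+1}, d_{n+1})] & \text{otherwise}
\end{cases}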
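The update rule can be sketched with Python's collections.deque: when maxlen is set to M_max, appending to a full deque discards the oldest element, which is exactly the "drop (u_1, d_1), then append" case. M_MAX = 3 is an arbitrary value chosen for illustration.

```python
from collections import deque

M_MAX = 3  # illustrative M_max (the article's code uses MAX_MEMORY_SIZE = 2000)

# With maxlen set, appending to a full deque discards the oldest pair (u_1, d_1)
C = deque(maxlen=M_MAX)

for i in range(1, 5):
    C.append((f"u{i}", f"d{i}"))  # add the new pair (u_{n+1}, d_{n+1})

print(list(C))  # [('u2', 'd2'), ('u3', 'd3'), ('u4', 'd4')]
```

A deque also drops its oldest element in constant time, whereas list.pop(0) has to shift every remaining entry.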

Word Count

To manage memory space and decide when compression is necessary, we calculate the total number of words W(C) in the memory:


W(C) = \sum_{i=1}^{n} \left( |u_i| + |d_i| \right)


Where |u_i| and |d_i| are the numbers of words in u_i and d_i, respectively.
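A direct translation of W(C), assuming C is stored as a list of (u_i, d_i) tuples and that words are whitespace-separated. The name word_count is hypothetical; the article's own version, memory_counter, works on flattened strings.

```python
from typing import List, Tuple


def word_count(C: List[Tuple[str, str]]) -> int:
    """W(C): total number of whitespace-separated words over all pairs (u_i, d_i)."""
    return sum(len(u.split()) + len(d.split()) for u, d in C)


C = [("Hello there", "Hi! How can I help?"), ("Tell me a joke", "Why not?")]
print(word_count(C))  # 2 + 5 + 4 + 2 = 13
```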

Memory Compression

When W(C) exceeds a threshold W_max, the memory is compressed to maintain relevant context. This compression is performed by a summarization model S, such as BART:


C_compressed = S(C)


Where C_compressed is the summarized version of the memory, reducing the total number of words while preserving the essence of past interactions.
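Combining W(C) with S gives the threshold rule. To keep this sketch runnable without downloading a model, S below is a hypothetical stand-in (naive_summarize) that keeps the first few words of the joined memory; in the article, S is BART. The names maybe_compress, naive_summarize and keep_words are illustrative.

```python
from typing import Callable, List


def naive_summarize(text: str, keep_words: int = 5) -> str:
    """Hypothetical stand-in for S: keep only the first keep_words words."""
    return " ".join(text.split()[:keep_words])


def maybe_compress(
    C: List[str], w_max: int, S: Callable[[str], str] = naive_summarize
) -> List[str]:
    # W(C): total number of words across all memory entries
    total_words = sum(len(entry.split()) for entry in C)
    if total_words <= w_max:
        return C  # threshold not exceeded: keep the memory as-is
    # Threshold exceeded: replace the memory by the summary S(C)
    return [S(" ".join(C))]


C = ["user: tell me about the weather", "bot: it is sunny and warm today"]
print(maybe_compress(C, w_max=20))  # unchanged: 13 words <= 20
print(maybe_compress(C, w_max=5))   # ['user: tell me about the']
```

Passing S as a parameter makes it easy to swap the stand-in for a real summarization model later.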

Code Implementation for Chatbot Memory Management

In this section, we will examine a Python code example that illustrates memory management in a chatbot. The code uses PyTorch and Hugging Face's transformers to manage and compress the conversation history.

Environment Setup

We begin by checking whether a GPU is available, so that the summarization model runs faster when one is present.

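import logging

import torch
from transformers import pipeline

# Hugging Face pipelines take device=0 for the first GPU and device=-1 for the CPU
device: int = 0 if torch.cuda.is_available() else -1

# Maximum number of entries kept in the conversation history (M_max)
MAX_MEMORY_SIZE: int = 2000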


Defining the ChatbotMemory Class

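The ChatbotMemory class manages the conversation history and performs the update and compression operations. Each call to update_memory appends the new exchange, compresses the history when it exceeds 1,000 words, and drops the oldest entry when the history holds more than MAX_MEMORY_SIZE entries.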


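class ChatbotMemory:
    def __init__(self, conv: list | None = None):
        # Use a fresh list per instance instead of a shared mutable default
        self.conversation_history: list = conv if conv is not None else []

    def update_memory(self, user_input: str, bot_response: str) -> None:
        # Store the new exchange (u_{n+1}, d_{n+1}) as one string
        self.conversation_history.append(f"'user': {user_input}, 'bot': {bot_response}")

        # Compress when the word count W(C) exceeds W_max = 1000 words
        if memory_counter(self.conversation_history) > 1000:
            self.conversation_history = compressed_memory(self.conversation_history)
            logging.info("Memory compressed.")

        # Drop the oldest entry when the history exceeds M_max entries
        if len(self.conversation_history) > MAX_MEMORY_SIZE:
            self.conversation_history.pop(0)
            logging.info("Memory reduced.")

    def get_memory(self) -> list:
        return self.conversation_history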
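Because ChatbotMemory calls the module-level memory_counter and compressed_memory, testing it requires loading BART. One possible design, sketched here under that assumption, is to inject the counter and compressor as parameters so the update logic can be exercised with trivial stand-ins. InjectableMemory, count_words and keep_last are hypothetical names, not part of the original code.

```python
import logging
from typing import Callable, List, Optional


class InjectableMemory:
    """Hypothetical variant of ChatbotMemory with pluggable counting and compression."""

    def __init__(
        self,
        counter: Callable[[List[str]], int],
        compressor: Callable[[List[str]], List[str]],
        max_words: int = 1000,
        max_entries: int = 2000,
        conv: Optional[List[str]] = None,
    ):
        self.conversation_history = conv if conv is not None else []
        self.counter = counter
        self.compressor = compressor
        self.max_words = max_words
        self.max_entries = max_entries

    def update_memory(self, user_input: str, bot_response: str) -> None:
        self.conversation_history.append(f"'user': {user_input}, 'bot': {bot_response}")
        # Compress when the word count exceeds max_words (W_max)
        if self.counter(self.conversation_history) > self.max_words:
            self.conversation_history = self.compressor(self.conversation_history)
            logging.info("Memory compressed.")
        # Drop the oldest entry when the list exceeds max_entries (M_max)
        if len(self.conversation_history) > self.max_entries:
            self.conversation_history.pop(0)
            logging.info("Memory reduced.")


def count_words(conv: List[str]) -> int:
    return sum(len(entry.split()) for entry in conv)


def keep_last(conv: List[str]) -> List[str]:
    # Trivial stand-in "compressor": keep only the most recent entry
    return conv[-1:]


memory = InjectableMemory(count_words, keep_last, max_words=10, max_entries=2)
memory.update_memory("hi", "hello")          # 4 words in total
memory.update_memory("how are you", "fine")  # 10 words: not above the threshold
memory.update_memory("bye", "see you")       # 15 words: compressed to the last entry
print(memory.conversation_history)  # ["'user': bye, 'bot': see you"]
```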

Memory Compression and Counting

The _get_compressed_memory function uses the BART model facebook/bart-large-cnn to summarize a piece of the conversation history.

_summarizer = None  # BART pipeline, loaded once on first use and then reused

def _get_compressed_memory(sentence: str) -> str:
    global _summarizer
    if _summarizer is None:
        _summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
    summary = _summarizer(sentence, max_length=50, min_length=5, do_sample=False)
    return summary[0]['summary_text']


The compressed_memory function applies _get_compressed_memory to the conversation history in chunks of five entries: each chunk is joined into one string and summarized, so the history shrinks to roughly a fifth of its length while each summarizer input stays reasonably short. This function is kept separate from _get_compressed_memory so that new compression methods can be introduced later.

def compressed_memory(conv_hist: list) -> list:
    # Summarize the history in chunks of 5 entries; each chunk becomes one summary
    return [_get_compressed_memory(' '.join(conv_hist[i:i + 5]))
            for i in range(0, len(conv_hist), 5)]
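To see what the chunking does without loading BART, here is the same grouping with a hypothetical stand-in summarizer (first_word) that keeps only the first word of each chunk. The names compress_in_chunks, chunk_size and first_word are illustrative.

```python
from typing import Callable


def first_word(text: str) -> str:
    """Hypothetical stand-in for _get_compressed_memory: keep only the first word."""
    words = text.split()
    return words[0] if words else ""


def compress_in_chunks(
    conv_hist: list, chunk_size: int = 5, summarize: Callable[[str], str] = first_word
) -> list:
    # Slice the history into consecutive groups of chunk_size entries (the last
    # group may be shorter) and replace each group by one summary string.
    return [
        summarize(" ".join(conv_hist[i:i + chunk_size]))
        for i in range(0, len(conv_hist), chunk_size)
    ]


history = [f"e{i}" for i in range(12)]  # 12 entries -> chunks of 5, 5 and 2
print(compress_in_chunks(history))  # ['e0', 'e5', 'e10']
```

A history of n entries therefore becomes ceil(n / 5) summaries, here 3 for 12 entries.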


The memory_counter function counts the total number of words in the history. (Counting tokens instead of words may be more accurate, since the summarization model's input limit is expressed in tokens.)

def memory_counter(conv_hist: list) -> int:
    # Count words per entry; joining with '' would glue adjacent entries' words together
    return sum(len(entry.split()) for entry in conv_hist)
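Two details matter when counting. Joining entries with an empty string glues the last word of one entry to the first word of the next, which undercounts; summing per-entry counts avoids this. And to follow the suggestion of counting tokens, the counter can take a tokenize callable. A minimal sketch; count_units is a hypothetical name, and the default str.split counts words.

```python
from typing import Callable, List


def count_units(conv_hist: List[str], tokenize: Callable[[str], List[str]] = str.split) -> int:
    """Sum per-entry counts: words by default, or tokens with another tokenize callable."""
    return sum(len(tokenize(entry)) for entry in conv_hist)


history = ["user: hello there", "bot: hi"]
print(len("".join(history).split()))  # 4: "there" and "bot:" glue into "therebot:"
print(count_units(history))           # 5
```

With transformers installed, passing AutoTokenizer.from_pretrained("facebook/bart-large-cnn").tokenize as tokenize would count BART tokens instead; that call downloads the tokenizer and is not exercised in this sketch.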

Conclusion

This code provides a simple framework for memory management in a chatbot, using summarization to keep the relevant context within a bounded size. Summarization models like BART help preserve the essential context even when the memory is compressed.


I hope this little demonstration has helped you understand that it's quite simple to model common processes and implement them in your code. It's a basic but essential step that allows you to grasp how things work:


  • 1: I define all the action steps
  • 2: I model them mathematically
  • 3: code!

Code on GitHub