One recurring problem when implementing your own chatbot is managing memory during the conversation. Of course, you can use ready-made libraries such as LangChain or Ollama, but what if you want to implement your own algorithm from scratch?
We present here an approach to managing short-term memory in chatbots, combining storage and automatic summarization to keep the conversational context compact. The method relies on a dynamic memory structure that bounds the amount of stored data while preserving essential information through summaries.
This approach not only improves the fluidity of interactions but also ensures contextual continuity during long dialogue sessions. In addition, asynchronous techniques can be used so that memory management operations do not interfere with the chatbot's responsiveness.
In this section, we mathematically formalize the management of conversation memory in the chatbot. The memory is structured as a list of pairs representing the exchanges between the user and the bot.
Don't worry, this part is pretty straightforward: it simply models the basic mechanisms so that we can translate them into code.
The conversation memory can be defined as an ordered list of pairs (u_i, d_i), where u_i represents the user’s input and d_i is the bot’s response for the i^th exchange. This list is denoted as C:
C = [(u_1, d_1), (u_2, d_2), ..., (u_n, d_n)]
Where n is the total number of exchanges in the current history.
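For example, after two exchanges the memory might be C = [("Hi", "Hello! How can I help?"), ("What's the weather like?", "I can't access live weather data.")], so n = 2.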
When a new exchange occurs, a new pair (u_{n+1}, d_{n+1}) is added to the memory. If adding it would exceed a predefined maximum size M_max, the oldest exchange is removed first:
C ← C ∪ {(u_{n+1}, d_{n+1})}                              if |C| < M_max
C ← (C \ {(u_1, d_1)}) ∪ {(u_{n+1}, d_{n+1})}             otherwise
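As a minimal illustration of this update rule (not the implementation used later, which relies on a plain list), Python's collections.deque with a maxlen behaves exactly this way: appending to a full deque silently discards the oldest pair. The limit M_MAX = 3 below is chosen only for the example:

from collections import deque

M_MAX = 3  # illustrative value of M_max

# A deque with maxlen drops the oldest pair automatically once the limit is reached.
memory = deque(maxlen=M_MAX)

for i in range(5):
    memory.append((f"user message {i}", f"bot response {i}"))

print(list(memory))  # keeps only the last M_MAX exchanges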
To manage memory space and decide when compression is necessary, we calculate the total number of words W(C) in the memory:
W(C) = Σ_{i=1}^{n} (|u_i| + |d_i|)
Where |u_i| and |d_i| are the number of words in u_i and d_i, respectively.
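For instance, the single pair ("Hi", "Hello! How can I help?") contributes |u_i| + |d_i| = 1 + 5 = 6 words to W(C).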
When W(C) exceeds a threshold W_max, the memory is compressed to maintain relevant context. This compression is performed by a summarization model S, such as BART:
C_compressed = S(C)
Where C_compressed is the summarized version of the memory, reducing the total number of words while preserving the essence of past interactions.
In this section, we will examine a Python code example that illustrates memory management in a chatbot. The code uses PyTorch and Hugging Face's transformers to manage and compress the conversation history.
We begin by checking whether a GPU is available, which allows for faster processing when one is present.
import torch
from transformers import pipeline
import logging

# Use the first GPU if one is available, otherwise fall back to the CPU.
device: int = 0 if torch.cuda.is_available() else -1

# Maximum number of entries kept in the conversation history.
MAX_MEMORY_SIZE: int = 2000
The ChatbotMemory class manages the conversation history and performs the update and compression operations. Each time update_memory is called, the number of words in memory is counted and the history is compressed or truncated as needed.
class ChatbotMemory:
    def __init__(self, conv: list | None = None):
        # Avoid a mutable default argument: each instance gets its own history list.
        self.conversation_history = conv if conv is not None else []

    def update_memory(self, user_input: str, bot_response: str) -> None:
        # Store the new exchange as a single string entry.
        self.conversation_history.append(f"'user': {user_input}, 'bot': {bot_response}")

        # Compress the history once it grows beyond roughly 1000 words.
        if memory_counter(self.conversation_history) > 1000:
            self.conversation_history = compressed_memory(self.conversation_history)
            logging.info("Memory compressed.")

        # Drop the oldest entry if the history holds too many exchanges.
        if len(self.conversation_history) > MAX_MEMORY_SIZE:
            self.conversation_history.pop(0)
            logging.info("Memory reduced.")

    def get_memory(self) -> list:
        return self.conversation_history
The _get_compressed_memory function uses the BART summarization model to condense a segment of the conversation history into a short summary.
def _get_compressed_memory(sentence: str) -> str:
    # Building the pipeline on every call is expensive; in practice you may
    # want to create it once at module level and reuse it.
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
    summary = summarizer(sentence, max_length=50, min_length=5, do_sample=False)
    return summary[0]['summary_text']
The compressed_memory function applies _get_compressed_memory to the conversation history in chunks of five entries at a time. It is kept separate from _get_compressed_memory so that other compression methods can be introduced later.
def compressed_memory(conv_hist: list) -> list:
    # Summarize the history in chunks of five entries at a time.
    return [_get_compressed_memory(' '.join(conv_hist[i:i+5])) for i in range(0, len(conv_hist), 5)]
The memory_counter function counts the total number of words in the history. (Note that it might be interesting to perform this step with tokens instead of words.)
def memory_counter(conv_hist: list) -> int:
    # Join entries with a space so that words at entry boundaries are not merged.
    st = ' '.join(conv_hist)
    return len(st.split())
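Following the note above, a token-based count can match more closely what the summarization model actually processes. Here is a minimal sketch, assuming the same facebook/bart-large-cnn checkpoint; token_counter is a hypothetical helper, not part of the code above:

from transformers import AutoTokenizer

# Hypothetical token-based variant of memory_counter, reusing the same BART checkpoint.
_bart_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

def token_counter(conv_hist: list) -> int:
    # Count tokens as the summarization model's tokenizer would see them.
    text = ' '.join(conv_hist)
    return len(_bart_tokenizer.encode(text, add_special_tokens=False))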
This code establishes an efficient framework for memory management in a chatbot, using compression techniques to maintain relevant context and improve the overall system performance. Using summarization models like BART ensures that even when memory is compressed, the essential context is preserved.
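To see how the pieces fit together, here is a small, hypothetical usage example (the exchanges are invented for illustration):

memory = ChatbotMemory()
memory.update_memory("Hello, who are you?", "I'm a small demo chatbot.")
memory.update_memory("Can you remember what I say?", "Yes, I keep a short-term history of our exchanges.")

# Inspect the stored history.
for entry in memory.get_memory():
    print(entry)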
I hope this little demonstration has helped you see that it's quite simple to model common processes and implement them in your own code. It's a basic but essential step that allows you to grasp how things really work.