Learn core database concepts by implementing a Python key-value store with crash recovery and efficient writes.
In this tutorial, we’ll build a simple but functional database from scratch with Python. Through this hands-on project, we’ll explore core database concepts like Write-Ahead Logging (WAL), Sorted String Tables (SSTables), Log-Structured Merge (LSM) trees, and other optimization techniques. By the end, you’ll have a deeper understanding of how real databases work under the hood.
Before diving into implementation, let’s understand what a database is and why we’re building one. Building a database from scratch is not just an educational exercise — it helps us understand the tradeoffs and decisions that go into database design, making us better at choosing and using databases in our own applications. Whether you’re using MongoDB, PostgreSQL, or any other database, the concepts we’ll explore form the foundation of their implementations.
A database is like a digital filing cabinet: it stores our data safely, lets us find it quickly, and keeps it intact even when things go wrong.
In this tutorial, we’ll build a simple database that stores key-value pairs. Think of it like a Python dictionary that saves to disk.
For example:
# Regular Python dictionary (loses data when program ends)
my_dict = {
"user:1": {"name": "Alice", "age": 30},
"user:2": {"name": "Bob", "age": 25}
}
# Our database (keeps data even after program ends)
our_db.set("user:1", {"name": "Alice", "age": 30})
our_db.get("user:1") # Returns: {"name": "Alice", "age": 30}
You might wonder: “Why not just save a Python dictionary to a file?” Let’s start there to understand the problems we’ll need to solve.
import os
import pickle
from typing import Any, Dict, Optional
class SimpleStore:
def __init__(self, filename: str):
self.filename = filename
self.data: Dict[str, Any] = {}
self._load()
def _load(self):
"""Load data from disk if it exists"""
if os.path.exists(self.filename):
with open(self.filename, 'rb') as f:
self.data = pickle.load(f)
def _save(self):
"""Save data to disk"""
with open(self.filename, 'wb') as f:
pickle.dump(self.data, f)
def set(self, key: str, value: Any):
self.data[key] = value
self._save() # Write to disk immediately
def get(self, key: str) -> Optional[Any]:
return self.data.get(key)
Let’s break down what this code does:
__init__(self, filename)
: Creates a new database backed by a file. filename is where we'll save our data; self.data is our in-memory dictionary.
_load(self)
: Reads saved data from disk at startup, using Python’s pickle module to convert the saved bytes back into Python objects. The leading underscore (_) marks it as an internal method not meant to be called directly.
_save(self)
: Writes all data to disk, using pickle to convert Python objects into bytes. Called every time we change data.
set(self, key, value) and get(self, key)
: Work like a dictionary’s [] operator: set updates memory and immediately saves to disk, and get returns None if the key doesn't exist.
This simple implementation has several problems:
Crash risk:
# If the program crashes here, we lose data:
db.data["key"] = "value" # Changed in memory
# ... crash before _save() ...
Performance issue:
# This will be very slow because it writes the entire database for each set
for i in range(1000):
db.set(f"key:{i}", f"value:{i}") # Writes ALL data 1000 times!
Memory limitation:
# All data must fit in RAM:
huge_data = {"key": "x" * 1000000000} # 1GB of data
db.set("huge", huge_data) # Might run out of memory!
Concurrency issue:
# If two programs do this at once:
db.set("counter", db.get("counter") + 1)
# They might both read "5" and both write "6"
# Instead of one writing "6" and one writing "7"
First, to ensure data persistence and recovery capabilities, we will use Write-Ahead Logging. Write-ahead logging is like keeping a diary of everything you’re going to do before you do it. If something goes wrong halfway through, you can look at your diary and finish the job. WAL is a reliability mechanism that records all changes before they are applied to the database.
Example:
# WAL entries are like diary entries:
Entry 1: "At 2024-03-20 14:30:15, set user:1 to {"name": "Alice"}"
Entry 2: "At 2024-03-20 14:30:16, set user:2 to {"name": "Bob"}"
Entry 3: "At 2024-03-20 14:30:17, delete user:1"
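The replay logic behind these diary entries can be sketched as a small standalone function (the name replay_wal is illustrative, not part of the tutorial's API): apply each logged operation in order and you recover the final state.

```python
import json

def replay_wal(lines):
    """Rebuild a key-value dict by applying logged operations in order."""
    data = {}
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        entry = json.loads(line)
        if entry["operation"] == "set":
            data[entry["key"]] = entry["value"]
        elif entry["operation"] == "delete":
            data.pop(entry["key"], None)
    return data

log = [
    '{"operation": "set", "key": "user:1", "value": {"name": "Alice"}}',
    '{"operation": "set", "key": "user:2", "value": {"name": "Bob"}}',
    '{"operation": "delete", "key": "user:1"}',
]
print(replay_wal(log))  # {'user:2': {'name': 'Bob'}}
```

Notice that replaying the log twice is harmless: each operation is idempotent, which is exactly what makes crash recovery safe.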
import json
from datetime import datetime

class WALEntry:
def __init__(self, operation: str, key: str, value: Any):
self.timestamp = datetime.utcnow().isoformat()
self.operation = operation # 'set' or 'delete'
self.key = key
self.value = value
def serialize(self) -> str:
"""Convert the entry to a string for storage"""
return json.dumps({
'timestamp': self.timestamp,
'operation': self.operation,
'key': self.key,
'value': self.value
})
Let’s understand what each part means:
timestamp
: When the operation happened
# Example timestamp
"2024-03-20T14:30:15.123456" # ISO format: readable and sortable
operation
: What we're doing
# Example operations
{"operation": "set", "key": "user:1", "value": {"name": "Alice"}}
{"operation": "delete", "key": "user:1", "value": null}
serialize
: Converts to string for storage
# Instead of binary pickle format, we use JSON for readability:
{
"timestamp": "2024-03-20T14:30:15.123456",
"operation": "set",
"key": "user:1",
"value": {"name": "Alice"}
}
class DatabaseError(Exception):
"""Base class for database exceptions"""
pass
import shutil

class WALStore:
def __init__(self, data_file: str, wal_file: str):
self.data_file = data_file
self.wal_file = wal_file
self.data: Dict[str, Any] = {}
self._recover()
def _append_wal(self, entry: WALEntry):
"""Write an entry to the log file"""
try:
with open(self.wal_file, "a") as f:
f.write(entry.serialize() + "\n")
f.flush() # Ensure it's written to disk
os.fsync(f.fileno()) # Force disk write
except IOError as e:
raise DatabaseError(f"Failed to write to WAL: {e}")
def _recover(self):
"""Rebuild database state from WAL if needed"""
try:
# First load the last checkpoint
if os.path.exists(self.data_file):
with open(self.data_file, "rb") as f:
self.data = pickle.load(f)
# Then replay any additional changes from WAL
if os.path.exists(self.wal_file):
with open(self.wal_file, "r") as f:
for line in f:
if line.strip(): # Skip empty lines
entry = json.loads(line)
if entry["operation"] == "set":
self.data[entry["key"]] = entry["value"]
elif entry["operation"] == "delete":
self.data.pop(entry["key"], None)
except (IOError, json.JSONDecodeError, pickle.PickleError) as e:
raise DatabaseError(f"Recovery failed: {e}")
def set(self, key: str, value: Any):
"""Set a key-value pair with WAL"""
entry = WALEntry("set", key, value)
self._append_wal(entry)
self.data[key] = value
def delete(self, key: str):
"""Delete a key with WAL"""
entry = WALEntry("delete", key, None)
self._append_wal(entry)
self.data.pop(key, None)
def checkpoint(self):
"""Create a checkpoint of current state"""
temp_file = f"{self.data_file}.tmp"
try:
# Write to temporary file first
with open(temp_file, "wb") as f:
pickle.dump(self.data, f)
f.flush()
os.fsync(f.fileno())
# Atomically replace old file
shutil.move(temp_file, self.data_file)
# Clear WAL - just truncate instead of opening in 'w' mode
if os.path.exists(self.wal_file):
with open(self.wal_file, "r+") as f:
f.truncate(0)
f.flush()
os.fsync(f.fileno())
except IOError as e:
if os.path.exists(temp_file):
os.remove(temp_file)
raise DatabaseError(f"Checkpoint failed: {e}")
In this implementation, we keep track of two files: the data_file and the wal_file. The data_file serves as our permanent storage where we save all data periodically (like a database backup), while the wal_file acts as our transaction log where we record every operation before executing it (like a diary of changes).
The recovery process works like this: we first load the last checkpoint from data_file, then replay every operation still recorded in wal_file on top of it.
Example of a recovery sequence:
# data_file contains:
{"user:1": {"name": "Alice"}}
# wal_file contains:
{"operation": "set", "key": "user:2", "value": {"name": "Bob"}}
{"operation": "delete", "key": "user:1"}
# After recovery, self.data contains:
{"user:2": {"name": "Bob"}}
Another important method in this implementation is checkpoint. It creates a permanent snapshot of the current state. Here is an example of the checkpointing process:
1. Current state:
data_file: {"user:1": {"name": "Alice"}}
wal_file: [set user:2 {"name": "Bob"}]
2. During checkpoint:
data_file: {"user:1": {"name": "Alice"}}
data_file.tmp: {"user:1": {"name": "Alice"}, "user:2": {"name": "Bob"}}
wal_file: [set user:2 {"name": "Bob"}]
3. After checkpoint:
data_file: {"user:1": {"name": "Alice"}, "user:2": {"name": "Bob"}}
wal_file: [] # Empty
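The core of the checkpoint method is the write-temp-then-rename pattern: by writing to a temporary file first, a crash mid-checkpoint leaves the old data_file untouched. Here is a minimal sketch of just that pattern (the standalone checkpoint function here is illustrative, not the class method above):

```python
import os
import pickle
import shutil
import tempfile

def checkpoint(data, data_file):
    tmp = f"{data_file}.tmp"
    with open(tmp, "wb") as f:
        pickle.dump(data, f)
        f.flush()
        os.fsync(f.fileno())  # force the bytes to disk before the rename
    shutil.move(tmp, data_file)  # atomic replacement on POSIX filesystems

workdir = tempfile.mkdtemp()
path = os.path.join(workdir, "data.db")
checkpoint({"user:1": {"name": "Alice"}}, path)
with open(path, "rb") as f:
    restored = pickle.load(f)
print(restored)  # {'user:1': {'name': 'Alice'}}
```

The rename is the commit point: readers see either the complete old snapshot or the complete new one, never a half-written file.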
MemTables serve as our database’s fast write path, providing quick access to recently written data. They maintain sorted order in memory, enabling efficient reads and range queries while preparing data for eventual persistent storage. You can think of a MemTable as a sorting tray on your desk:
import bisect
from typing import Any, Iterator, List, Optional, Tuple

class MemTable:
def __init__(self, max_size: int = 1000):
self.entries: List[Tuple[str, Any]] = []
self.max_size = max_size
def add(self, key: str, value: Any):
"""Add or update a key-value pair"""
idx = bisect.bisect_left([k for k, _ in self.entries], key)
if idx < len(self.entries) and self.entries[idx][0] == key:
self.entries[idx] = (key, value)
else:
self.entries.insert(idx, (key, value))
def get(self, key: str) -> Optional[Any]:
"""Get value for key"""
idx = bisect.bisect_left([k for k, _ in self.entries], key)
if idx < len(self.entries) and self.entries[idx][0] == key:
return self.entries[idx][1]
return None
def is_full(self) -> bool:
"""Check if memtable has reached max size"""
return len(self.entries) >= self.max_size
def range_scan(self, start_key: str, end_key: str) -> Iterator[Tuple[str, Any]]:
"""Scan entries within key range"""
start_idx = bisect.bisect_left([k for k, _ in self.entries], start_key)
end_idx = bisect.bisect_right([k for k, _ in self.entries], end_key)
return iter(self.entries[start_idx:end_idx])
Let’s see how the memtable stays sorted:
# Starting state:
entries = [
("apple", 1),
("cherry", 2),
("zebra", 3)
]
# Adding "banana" = 4:
# 1. Find insertion point (between "apple" and "cherry")
# 2. Insert new entry
# 3. Result:
entries = [
("apple", 1),
("banana", 4),
("cherry", 2),
("zebra", 3)
]
While our implementation uses a simple sorted list with binary search, production databases like LevelDB and RocksDB typically use more sophisticated data structures such as red-black trees or skip lists. We used a simpler approach here to focus on the core concepts, but keep in mind that real databases need these optimizations for better performance.
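One easy optimization worth noting: our MemTable rebuilds the key list ([k for k, _ in self.entries]) on every add and get, which is O(n) per call. A variant that keeps a parallel sorted key list avoids that rebuild. This FastMemTable is a hypothetical sketch, not part of the tutorial's code:

```python
import bisect
from typing import Any, List, Optional

class FastMemTable:
    """Keeps keys and values in two parallel lists, sorted by key."""
    def __init__(self):
        self.keys: List[str] = []
        self.values: List[Any] = []

    def add(self, key: str, value: Any):
        idx = bisect.bisect_left(self.keys, key)  # binary search, O(log n)
        if idx < len(self.keys) and self.keys[idx] == key:
            self.values[idx] = value  # update existing key in place
        else:
            self.keys.insert(idx, key)  # insert keeps both lists aligned and sorted
            self.values.insert(idx, value)

    def get(self, key: str) -> Optional[Any]:
        idx = bisect.bisect_left(self.keys, key)
        if idx < len(self.keys) and self.keys[idx] == key:
            return self.values[idx]
        return None

mt = FastMemTable()
for k, v in [("apple", 1), ("cherry", 2), ("zebra", 3), ("banana", 4)]:
    mt.add(k, v)
print(mt.keys)  # ['apple', 'banana', 'cherry', 'zebra']
```

Inserts still cost O(n) because of list shifting; skip lists and balanced trees are what remove that last bottleneck in real engines.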
While MemTables provide fast writes, they can’t grow indefinitely: memory is limited, so when a MemTable fills up we need to persist it to disk efficiently. We use SSTables for this. SSTables (Sorted String Tables) are like sorted, immutable folders of data:
class SSTable:
def __init__(self, filename: str):
self.filename = filename
self.index: Dict[str, int] = {}
if os.path.exists(filename):
self._load_index()
def _load_index(self):
"""Load index from existing SSTable file"""
try:
with open(self.filename, "rb") as f:
# Read index position from start of file
f.seek(0)
index_pos = int.from_bytes(f.read(8), "big")
# Read index from end of file
f.seek(index_pos)
self.index = pickle.load(f)
except (IOError, pickle.PickleError) as e:
raise DatabaseError(f"Failed to load SSTable index: {e}")
def write_memtable(self, memtable: MemTable):
"""Save memtable to disk as SSTable"""
temp_file = f"{self.filename}.tmp"
try:
with open(temp_file, "wb") as f:
                # Reserve space for the index position (filled in after the index is written)
index_pos = f.tell()
f.write(b"\0" * 8) # Placeholder for index position
# Write data
for key, value in memtable.entries:
offset = f.tell()
self.index[key] = offset
entry = pickle.dumps((key, value))
f.write(len(entry).to_bytes(4, "big"))
f.write(entry)
# Write index at end
index_offset = f.tell()
pickle.dump(self.index, f)
# Update index position at start of file
f.seek(index_pos)
f.write(index_offset.to_bytes(8, "big"))
f.flush()
os.fsync(f.fileno())
# Atomically rename temp file
shutil.move(temp_file, self.filename)
except IOError as e:
if os.path.exists(temp_file):
os.remove(temp_file)
raise DatabaseError(f"Failed to write SSTable: {e}")
def get(self, key: str) -> Optional[Any]:
"""Get value for key from SSTable"""
if key not in self.index:
return None
try:
with open(self.filename, "rb") as f:
f.seek(self.index[key])
size = int.from_bytes(f.read(4), "big")
entry = pickle.loads(f.read(size))
return entry[1]
except (IOError, pickle.PickleError) as e:
raise DatabaseError(f"Failed to read from SSTable: {e}")
def range_scan(self, start_key: str, end_key: str) -> Iterator[Tuple[str, Any]]:
"""Scan entries within key range"""
keys = sorted(k for k in self.index.keys() if start_key <= k <= end_key)
for key in keys:
value = self.get(key)
if value is not None:
yield (key, value)
File format explanation:
File layout:
[index_pos: 8 bytes][size1][entry1][size2][entry2]...[index]
The first 8 bytes store the offset where the index begins, each entry is prefixed with its 4-byte size, and the pickled key-to-offset index sits at the end of the file.
Example:
[0x00000020][("apple", 1)][0x00000024][("banana", 4)]...
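The length-prefixed record format can be sketched as an in-memory round trip (helper names here are illustrative): each record is a 4-byte big-endian size followed by the pickled (key, value) pair, and the index maps keys to record offsets.

```python
import io
import pickle

def write_records(buf, items):
    """Append length-prefixed pickled records; return a key -> offset index."""
    index = {}
    for key, value in items:
        index[key] = buf.tell()  # remember where this record starts
        blob = pickle.dumps((key, value))
        buf.write(len(blob).to_bytes(4, "big"))  # 4-byte size prefix
        buf.write(blob)
    return index

def read_record(buf, offset):
    """Seek to a record, read its size prefix, then unpickle it."""
    buf.seek(offset)
    size = int.from_bytes(buf.read(4), "big")
    return pickle.loads(buf.read(size))

buf = io.BytesIO()
index = write_records(buf, [("apple", 1), ("banana", 4)])
print(read_record(buf, index["banana"]))  # ('banana', 4)
```

The size prefix is what lets us jump straight to any record via the index and read exactly the right number of bytes, with no scanning.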
Now, we’ll combine everything we’ve learned (WAL, MemTable, and SSTables) into a simplified version of a Log-Structured Merge Tree (LSM Tree). While a true LSM tree uses multiple levels, we’ll start with a basic flat structure in Part 1 and upgrade to a proper leveled implementation in Part 2 of this series.
Imagine you’re organizing papers in an office.
New papers first land in an inbox on your desk (the MemTable):
- This is like the fast “write path” of the database.
- Papers can be added quickly without worrying about organization yet.
- The inbox is kept in memory for quick access.
When the inbox fills up, you sort the papers and file them into folders (SSTables):
- The papers are now sorted and stored efficiently.
- Once in a folder, these papers never change (immutable).
- Each folder has its own index for quick lookups.
When too many folders pile up, you merge them (compaction):
- In our Part 1 implementation, we’ll simply merge all folders into one.
- This is a simplified approach compared to real databases. While functional, it’s not as efficient as proper leveled merging.
Here’s how we implement our simplified version:
from pathlib import Path
from threading import RLock

class LSMTree:
def __init__(self, base_path: str):
self.base_path = Path(base_path)
try:
# Check if path exists and is a file
if self.base_path.exists() and self.base_path.is_file():
raise DatabaseError(f"Cannot create database: '{base_path}' is a file")
self.base_path.mkdir(parents=True, exist_ok=True)
except (OSError, FileExistsError) as e:
raise DatabaseError(
f"Failed to initialize database at '{base_path}': {str(e)}"
)
# Our "Inbox" for new data
self.memtable = MemTable(max_size=1000)
# Our "Folders" of sorted data
self.sstables: List[SSTable] = []
self.max_sstables = 5 # Limit on number of SSTables
self.lock = RLock()
self.wal = WALStore(
str(self.base_path / "data.db"), str(self.base_path / "wal.log")
)
self._load_sstables()
if len(self.sstables) > self.max_sstables:
self._compact()
def _load_sstables(self):
"""Load existing SSTables from disk"""
self.sstables.clear()
for file in sorted(self.base_path.glob("sstable_*.db")):
self.sstables.append(SSTable(str(file)))
To ensure thread safety in our database, we use locks to prevent multiple threads from causing problems:
# Example of why we need locks:
# Without locks:
Thread 1: reads data["x"] = 5
Thread 2: reads data["x"] = 5
Thread 1: writes data["x"] = 6
Thread 2: writes data["x"] = 7 # Last write wins, first update lost!
# With locks:
Thread 1: acquires lock
Thread 1: reads data["x"] = 5
Thread 1: writes data["x"] = 6
Thread 1: releases lock
Thread 2: acquires lock
Thread 2: reads data["x"] = 6
Thread 2: writes data["x"] = 7
Thread 2: releases lock
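The counter scenario above is runnable as-is: each increment is a read-modify-write, so it must happen atomically under the lock. A minimal sketch (the names counter and incr are illustrative):

```python
from threading import Thread, Lock

counter = {"x": 0}
lock = Lock()

def incr(n: int):
    for _ in range(n):
        with lock:  # only one thread can read-modify-write at a time
            counter["x"] = counter["x"] + 1

threads = [Thread(target=incr, args=(1000,)) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter["x"])  # 5000 -- no lost updates
```

Our LSMTree uses an RLock rather than a plain Lock so that a method holding the lock (like delete) can safely call another locked method (like set) without deadlocking.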
Let’s see how writing data works step by step:
def set(self, key: str, value: Any):
"""Set a key-value pair"""
with self.lock:
if not isinstance(key, str):
raise ValueError("Key must be a string")
# 1. Safety first: Write to WAL
self.wal.set(key, value)
# 2. Write to memory table (fast!)
self.memtable.add(key, value)
# 3. If memory table is full, save to disk
if self.memtable.is_full():
self._flush_memtable()
def _flush_memtable(self):
"""Flush memtable to disk as new SSTable"""
if not self.memtable.entries:
return # Skip if empty
# Create new SSTable with a unique name
sstable = SSTable(str(self.base_path / f"sstable_{len(self.sstables)}.db"))
sstable.write_memtable(self.memtable)
# Add to our list of SSTables
self.sstables.append(sstable)
# Create fresh memory table
self.memtable = MemTable()
# Create a checkpoint in WAL
self.wal.checkpoint()
# Compact if we have too many SSTables
if len(self.sstables) > self.max_sstables:
self._compact()
Example of how data flows:
# Starting state:
memtable: empty
sstables: []
# After db.set("user:1", {"name": "Alice"})
memtable: [("user:1", {"name": "Alice"})]
sstables: []
# After 1000 more sets (memtable full)...
memtable: empty
sstables: [sstable_0.db] # Contains sorted data
# After 1000 more sets...
memtable: empty
sstables: [sstable_0.db, sstable_1.db]
Reading needs to check multiple places, newest to oldest:
def get(self, key: str) -> Optional[Any]:
"""Get value for key"""
with self.lock:
if not isinstance(key, str):
raise ValueError("Key must be a string")
# 1. Check memtable first (newest data)
value = self.memtable.get(key)
if value is not None:
return value
# 2. Check each SSTable, newest to oldest
for sstable in reversed(self.sstables):
value = sstable.get(key)
if value is not None:
return value
# 3. Not found anywhere
return None
def range_query(self, start_key: str, end_key: str) -> Iterator[Tuple[str, Any]]:
    """Perform a range query"""
    with self.lock:
        seen_keys = set()
        # Get from memtable first (newest data)
        for key, value in self.memtable.range_scan(start_key, end_key):
            seen_keys.add(key)
            if value is not None:  # Skip tombstones
                yield (key, value)
        # Then from each SSTable, newest to oldest
        for sstable in reversed(self.sstables):
            for key, value in sstable.range_scan(start_key, end_key):
                if key not in seen_keys:
                    seen_keys.add(key)
                    if value is not None:  # Skip tombstones
                        yield (key, value)
Example of reading:
# Database state:
memtable: [("user:3", {"name": "Charlie"})]
sstables: [
sstable_0.db: [("user:1", {"name": "Alice"})],
sstable_1.db: [("user:2", {"name": "Bob"})]
]
# Reading "user:3" -> Finds it in memtable
# Reading "user:1" -> Checks memtable, then finds in sstable_0.db
# Reading "user:4" -> Checks everywhere, returns None
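The read path above can be sketched with plain dicts standing in for the memtable and SSTables (the lookup function is illustrative, not the tutorial's API):

```python
from typing import Any, Dict, List, Optional

def lookup(key: str, memtable: Dict[str, Any],
           sstables: List[Dict[str, Any]]) -> Optional[Any]:
    if key in memtable:                   # 1. newest data lives in the memtable
        return memtable[key]
    for table in reversed(sstables):      # 2. then SSTables, newest to oldest
        if key in table:
            return table[key]
    return None                           # 3. not found anywhere

memtable = {"user:3": {"name": "Charlie"}}
sstables = [{"user:1": {"name": "Alice"}}, {"user:2": {"name": "Bob"}}]
print(lookup("user:3", memtable, sstables))  # {'name': 'Charlie'}
print(lookup("user:1", memtable, sstables))  # {'name': 'Alice'}
print(lookup("user:4", memtable, sstables))  # None
```

Stopping at the first hit is what makes "newest wins" work: a fresher value in the memtable shadows any stale copy in older SSTables. The downside is that a missing key touches every table, which is why real LSM engines add bloom filters.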
def _compact(self):
"""Merge multiple SSTables into one"""
try:
# Create merged memtable
merged = MemTable(max_size=float("inf"))
# Merge all SSTables
for sstable in self.sstables:
for key, value in sstable.range_scan("", "~"): # Full range
merged.add(key, value)
# Write merged data to new SSTable
new_sstable = SSTable(str(self.base_path / "sstable_compacted.db"))
new_sstable.write_memtable(merged)
# Remove old SSTables
old_files = [sst.filename for sst in self.sstables]
self.sstables = [new_sstable]
# Delete old files
for file in old_files:
try:
os.remove(file)
except OSError:
pass # Ignore deletion errors
except Exception as e:
raise DatabaseError(f"Compaction failed: {e}")
# Before compaction:
sstables: [
sstable_0.db: [("apple", 1), ("banana", 2)],
sstable_1.db: [("banana", 3), ("cherry", 4)],
sstable_2.db: [("apple", 5), ("date", 6)]
]
# After compaction:
sstables: [
sstable_compacted.db: [
("apple", 5), # Latest value wins
("banana", 3), # Latest value wins
("cherry", 4),
("date", 6)
]
]
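The merge rule behind compaction boils down to: apply tables oldest-to-newest so the latest value for each key wins, then emit the result in sorted key order. A sketch with plain dicts (the compact function here is illustrative):

```python
def compact(tables):
    """Merge SSTable contents; later (newer) tables overwrite older values."""
    merged = {}
    for table in tables:          # oldest first
        merged.update(table)      # newer values win on key collisions
    return sorted(merged.items())  # SSTables store entries in sorted key order

tables = [
    {"apple": 1, "banana": 2},   # sstable_0 (oldest)
    {"banana": 3, "cherry": 4},  # sstable_1
    {"apple": 5, "date": 6},     # sstable_2 (newest)
]
print(compact(tables))  # [('apple', 5), ('banana', 3), ('cherry', 4), ('date', 6)]
```

This mirrors the example above: apple keeps 5 and banana keeps 3 because those values came from newer tables. Tombstones could be dropped at this stage too, reclaiming space for deleted keys.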
We also need a couple of other methods: one for deleting keys and one for closing the database instance:
def delete(self, key: str):
    """Delete a key"""
    with self.lock:
        # set() logs the tombstone to the WAL for us, so we don't log twice
        self.set(key, None)  # Use None as tombstone
def close(self):
"""Ensure all data is persisted to disk"""
with self.lock:
if self.memtable.entries: # If there's data in memtable
self._flush_memtable()
self.wal.checkpoint() # Ensure WAL is up-to-date
Here’s how to use what we’ve built:
# Create database in the 'mydb' directory
db = LSMTree("./mydb")
# Store some user data
db.set("user:1", {
"name": "Alice",
"email": "[email protected]",
"age": 30
})
# Read it back
user = db.get("user:1")
print(user['name']) # Prints: Alice
# Store many items
for i in range(1000):
db.set(f"item:{i}", {
"name": f"Item {i}",
"price": random.randint(1, 100)
})
# Range query example
print("\nItems 10-15:")
for key, value in db.range_query("item:10", "item:15"):
print(f"{key}: {value}")
Let’s put our database through some basic tests to understand its behavior:
def test_basic_operations(db_path):
db = LSMTree(db_path)
# Test single key-value
db.set("test_key", "test_value")
assert db.get("test_key") == "test_value"
# Test overwrite
db.set("test_key", "new_value")
assert db.get("test_key") == "new_value"
# Test non-existent key
assert db.get("missing_key") is None
def test_delete_operations(db_path):
db = LSMTree(db_path)
# Test delete existing key
db.set("key1", "value1")
assert db.get("key1") == "value1"
db.delete("key1")
assert db.get("key1") is None
# Test delete non-existent key
db.delete("nonexistent_key") # Should not raise error
# Test set after delete
db.set("key1", "new_value")
assert db.get("key1") == "new_value"
def test_range_query(db_path):
db = LSMTree(db_path)
# Insert test data
test_data = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}
for k, v in test_data.items():
db.set(k, v)
# Test range query
results = list(db.range_query("b", "d"))
assert len(results) == 3
assert results == [("b", 2), ("c", 3), ("d", 4)]
# Test empty range
results = list(db.range_query("x", "z"))
assert len(results) == 0
def test_concurrent_operations(db_path):
db = LSMTree(db_path)
def writer_thread():
for i in range(100):
db.set(f"thread_key_{i}", f"thread_value_{i}")
sleep(0.001) # Small delay to increase chance of concurrency issues
# Create multiple writer threads
threads = [Thread(target=writer_thread) for _ in range(5)]
# Start all threads
for thread in threads:
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Verify all data was written correctly
for i in range(100):
assert db.get(f"thread_key_{i}") == f"thread_value_{i}"
def test_recovery(db_path):
"""Test database recovery with optimized operations"""
# Create a database instance
db1 = LSMTree(db_path)
# Add some data to the database
db1.set("key1", "value1")
db1.set("key2", "value2")
assert db1.get("key1") == "value1"
assert db1.get("key2") == "value2"
# Close the database to force a flush
db1.close()
# Create a new instance of the database
db2 = LSMTree(db_path)
# Verify data is still present after recovery
value1 = db2.get("key1")
value2 = db2.get("key2")
assert value1 == "value1"
assert value2 == "value2"
db2.close()
Running these tests helps ensure our database is working as expected!
In this first part of our journey to build a database from scratch, we’ve implemented a basic but functional key-value store with several important database concepts:
Our implementation can handle basic operations like setting values, retrieving them, and performing range queries while maintaining data consistency and surviving crashes. By converting random writes into sequential ones, LSM Trees excel at quickly ingesting large amounts of data. This is why databases like RocksDB (Facebook), Apache Cassandra (Netflix, Instagram), and LevelDB (Google) use LSM Trees for their storage engines.
In contrast, traditional B-Tree structures (used by PostgreSQL and MySQL) offer better read performance but may struggle with heavy write loads. We’ll explore B-Tree implementations in future posts to better understand these trade-offs.
Current Limitations
While our database works, it has several limitations that we’ll address in the coming parts:
Storage and Performance: a lookup for a missing key must check the memtable and every SSTable (we have no bloom filters yet), and compaction rewrites all data in a single pass.
Concurrency: a single global lock serializes every read and write, and there is no transaction support.
In Part 2, we’ll tackle some of these limitations by implementing level-based compaction, bloom filters for faster lookups, and basic transaction support. Stay tuned.
The complete source code for this tutorial is available on
If you want to dive deeper into these concepts before Part 2, here are some resources: