
Improving Text-to-SQL with a Fine-Tuned 7B LLM for DB Interactions

by Yi Ai, October 2nd, 2024

Too Long; Didn't Read

In this article, I’ll walk you through the process of fine-tuning a 7B model to handle SQL generation tasks more effectively.

I have run into challenges when using 7B LLMs for SQL generation, particularly against my company’s databases. These models often struggle to generate accurate SQL queries, even when the database schema and table relationships are provided in the context. An effective way to address this is to fine-tune a 7B model with QLoRA on a custom dataset tailored to your specific database schema.


In this article, I’ll walk you through the process of fine-tuning a 7B model to handle SQL generation tasks more effectively, and how you can integrate the fine-tuned model into a LangChain-based application for real-time database interactions.

Overview


Before we dive into the details, let’s outline the key steps we’ll be following in this guide:

  1. Prepare a custom dataset based on a database schema.
  2. Fine-tune the 7B model using the QLoRA technique.
  3. Evaluate the performance of the fine-tuned model.
  4. Integrate the model into a LangChain application for SQL-based database interaction.


By following this guide, you’ll be able to build a question-answering application for SQL databases using a fine-tuned Mistral 7B model, optimized for generating SQL queries based on your specific database schema.

Step 1: Preparing Your Custom Dataset

To fine-tune the model effectively, you need a high-quality dataset that reflects your database’s structure. Let’s consider a simple customer management database with the following tables:


  1. Customer
  2. Address
  3. Contact

Sample DDL for the customer, address, and contact tables


CREATE TABLE customer (
    customer_key INT PRIMARY KEY,
    source VARCHAR(50),
    full_name VARCHAR(100),
    created_date DATETIME,
    updated_date DATETIME,
    gender VARCHAR(10),
    dateofbirth DATE
);

CREATE TABLE address (
    address_key INT PRIMARY KEY,
    customer_key INT,
    street_address VARCHAR(200),
    city VARCHAR(100),
    state VARCHAR(50),
    postal_code VARCHAR(20),
    country VARCHAR(50),
    is_primary BOOLEAN,
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);

CREATE TABLE contact (
    contact_key INT PRIMARY KEY,
    customer_key INT,
    email VARCHAR(100),
    phone VARCHAR(20),
    created_date DATETIME,
    updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);

Creating Text-to-SQL Samples

To generate the dataset for fine-tuning, use Claude Sonnet or another LLM to create text-to-SQL samples. Here’s a prompt you can use to guide the model in creating SQL queries:

Sample Format:
{
  "instruction": "I want you to act as a SQL terminal in front of an example database. You need only to return the SQL command to me. Below is an instruction that describes a task. Write a response that appropriately completes the request.
## Instruction:
[Database description]
",
  "input": "### Input:
[Natural language question]
### Response:",
  "output": "[Corresponding SQL query]"
}
Sample Data Point:
    {
      "instruction": "You are a powerful text-to-SQL model. Your task is to generate SQL queries based on the following schema for a customer database:\n\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY,\n    source VARCHAR(50),\n    full_name VARCHAR(100),\n    created_date DATETIME,\n    updated_date DATETIME,\n    gender VARCHAR(10),\n    dateofbirth DATE\n);\n\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT,\n    street_address VARCHAR(200),\n    city VARCHAR(100),\n    state VARCHAR(50),\n    postal_code VARCHAR(20),\n    country VARCHAR(50),\n    is_primary BOOLEAN,\n    created_date DATETIME,\n    updated_date DATETIME,\n    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)\n);\n\nCREATE TABLE contact (\n    contact_key INT PRIMARY KEY,\n    customer_key INT,\n    email VARCHAR(100),\n    phone VARCHAR(20),\n    created_date DATETIME,\n    updated_date DATETIME,\n    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)\n);",
      "input": "List the full names of customers who have both an email and a phone number.",
      "output": "SELECT DISTINCT c.full_name\nFROM customer c\nJOIN contact ct ON c.customer_key = ct.customer_key\nWHERE ct.email IS NOT NULL AND ct.phone IS NOT NULL;"
    }
The database contains three tables: customer, address, and contact.

Table 'customer' has columns:
customer_key (INT, primary key)
source (VARCHAR(50))
full_name (VARCHAR(100))
created_date (DATETIME)
updated_date (DATETIME)
gender (VARCHAR(10))
dateofbirth (DATE)

Table 'address' has columns:
address_key (INT, primary key)
customer_key (INT, foreign key referencing customer(customer_key))
street_address (VARCHAR(200))
city (VARCHAR(100))
state (VARCHAR(50))
postal_code (VARCHAR(20))
country (VARCHAR(50))
is_primary (BOOLEAN)
created_date (DATETIME)
updated_date (DATETIME)

Table 'contact' has columns:
contact_key (INT, primary key)
customer_key (INT, foreign key referencing customer(customer_key))
email (VARCHAR(100))
phone (VARCHAR(20))
created_date (DATETIME)
updated_date (DATETIME)

Please generate 100 samples in a JSON file based on the provided database schema and example. For each sample, ensure that:
Instruction: Includes only the table definitions required by the SQL query in the output.
Input: Contains a natural language question about the data.
Output: Provides the corresponding SQL query that answers the question.
The questions should cover topics such as data analysis, aggregation, address searches, customer searches, contact searches, and reporting.


Repeat this process to generate about 200 to 500 samples covering various SQL tasks, including customer queries, address lookups, and data aggregation. Save the samples in JSON format and split them into train.json (for training) and eval.json (for evaluation).
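Because the samples arrive in several LLM-generated batches, it helps to merge them, drop malformed records and repeated questions, and then make a reproducible split. The sketch below is one way to do that under stated assumptions: load_samples and write_splits are hypothetical helpers, batch file names such as batch_01.json are placeholders, and treating two questions as duplicates when they match after lowercasing and whitespace collapsing is a simple heuristic.

```python
import json
import random
from pathlib import Path

REQUIRED_KEYS = ("instruction", "input", "output")


def load_samples(paths):
    """Merge generated batch files, skipping malformed records and repeated questions."""
    seen, samples = set(), []
    for path in paths:
        for record in json.loads(Path(path).read_text(encoding="utf-8")):
            # Every required field must be a non-empty string
            if not all(isinstance(record.get(k), str) and record[k].strip() for k in REQUIRED_KEYS):
                continue
            # Treat questions that differ only in case or spacing as duplicates
            question = " ".join(record["input"].lower().split())
            if question in seen:
                continue
            seen.add(question)
            samples.append(record)
    return samples


def write_splits(batch_paths, out_dir=".", eval_fraction=0.1, seed=42):
    """Shuffle deterministically, write train.json and eval.json, and return their sizes."""
    samples = load_samples(batch_paths)
    random.Random(seed).shuffle(samples)
    n_eval = max(1, int(len(samples) * eval_fraction))
    train, evaluation = samples[n_eval:], samples[:n_eval]
    out = Path(out_dir)
    (out / "train.json").write_text(json.dumps(train, indent=2), encoding="utf-8")
    (out / "eval.json").write_text(json.dumps(evaluation, indent=2), encoding="utf-8")
    return len(train), len(evaluation)
```

For example, write_splits(["batch_01.json", "batch_02.json"]) writes both files to the current directory. A fixed seed keeps the evaluation set stable across reruns, so scores from different fine-tuning runs remain comparable.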


Dataset example:

[
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about the customer database based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY,\n    full_name VARCHAR(100),\n    dateofbirth DATE\n);",
    "input": "List the full names of all customers born after January 1, 1990.",
    "output": "SELECT full_name FROM customer WHERE dateofbirth > '1990-01-01';"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about the address data based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT,\n    city VARCHAR(100)\n);",
    "input": "Find all unique cities where customers reside.",
    "output": "SELECT DISTINCT city FROM address;"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about customer emails based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY\n);\n\nCREATE TABLE contact (\n    contact_key INT PRIMARY KEY,\n    customer_key INT,\n    email VARCHAR(100)\n);",
    "input": "Find customers who have not provided an email address.",
    "output": "SELECT c.customer_key FROM customer c LEFT JOIN contact ct ON c.customer_key = ct.customer_key WHERE ct.email IS NULL;"
  },
  {
    "instruction": "You are a powerful text-to-SQL model. Your job is to answer questions about customers without addresses based on the provided SCHEMA.\nYou must output the SQL query that answers the question. SCHEMA:\nCREATE TABLE customer (\n    customer_key INT PRIMARY KEY\n);\n\nCREATE TABLE address (\n    address_key INT PRIMARY KEY,\n    customer_key INT\n);",
    "input": "List customer keys of customers who have no address on file.",
    "output": "SELECT c.customer_key FROM customer c LEFT JOIN address a ON c.customer_key = a.customer_key WHERE a.customer_key IS NULL;"
  },
...
]
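LLM-generated SQL can reference columns that do not exist in your schema, so it is worth checking every output before training. A minimal sketch, assuming the Step 1 DDL and using Python’s built-in sqlite3 as a stand-in database: SQLite accepts the VARCHAR, DATETIME, and BOOLEAN type names, and EXPLAIN compiles a statement (catching unknown tables and columns) without running it. find_invalid_sql is a hypothetical helper, and since SQLite is not your production dialect, PostgreSQL-only syntax will also be flagged, so review failures by hand rather than dropping them automatically.

```python
import sqlite3

# The Step 1 DDL; SQLite accepts these type names via its type-affinity rules
SCHEMA_DDL = """
CREATE TABLE customer (
    customer_key INT PRIMARY KEY, source VARCHAR(50), full_name VARCHAR(100),
    created_date DATETIME, updated_date DATETIME, gender VARCHAR(10), dateofbirth DATE
);
CREATE TABLE address (
    address_key INT PRIMARY KEY, customer_key INT, street_address VARCHAR(200),
    city VARCHAR(100), state VARCHAR(50), postal_code VARCHAR(20), country VARCHAR(50),
    is_primary BOOLEAN, created_date DATETIME, updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);
CREATE TABLE contact (
    contact_key INT PRIMARY KEY, customer_key INT, email VARCHAR(100), phone VARCHAR(20),
    created_date DATETIME, updated_date DATETIME,
    FOREIGN KEY (customer_key) REFERENCES customer(customer_key)
);
"""


def find_invalid_sql(samples, ddl=SCHEMA_DDL):
    """Return (index, error) pairs for samples whose "output" SQL fails to compile."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(ddl)
        failures = []
        for i, sample in enumerate(samples):
            sql = sample["output"].strip().rstrip(";")
            try:
                # EXPLAIN compiles the statement without executing it
                conn.execute("EXPLAIN " + sql)
            except (sqlite3.Error, sqlite3.Warning) as exc:
                # sqlite3.Warning covers "one statement at a time" on older Pythons
                failures.append((i, str(exc)))
        return failures
    finally:
        conn.close()
```

Run it on the loaded contents of train.json and eval.json; an empty list means every query at least compiles against the schema.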

Step 2: Fine-Tuning the Model

Once your dataset is ready, the next step is to fine-tune the 7B model using QLoRA. There are several tools available for fine-tuning:


  1. Hugging Face TRL & SFTTrainer.
  2. DB-GPT-Hub: A good choice for fine-tuning models on public text-to-SQL datasets like Spider.
  3. LitGPT: A lightweight and efficient framework for fast fine-tuning, pretraining, and deployment of LLMs.


For our use case, we’ll use LitGPT to fine-tune our 7B model on the customer database schema.


Steps to Fine-Tune the Model:

  1. Install LitGPT.
pip install litgpt


2. Download model weights.

litgpt download mistralai/Mistral-7B-Instruct-v0.3 --access_token=xxxxxx


3. Run the fine-tuning process using 4-bit quantization.

litgpt finetune_lora \
    checkpoints/mistralai/Mistral-7B-Instruct-v0.3 \
    --data JSON \
    --data.json_path train.json \
    --out_dir finetuned \
    --precision bf16-true \
    --quantize "bnb.nf4" \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --train.global_batch_size 4 \
    --train.micro_batch_size 1 \
    --train.max_steps 1000 \
    --train.save_interval 200 \
    --eval.interval 50 \
    --train.lr_warmup_steps 100 \
    --train.max_seq_length 2048 \
    --optimizer.learning_rate 2e-4 \
    --optimizer.weight_decay 0.01 \
    --optimizer.betas 0.9 \
    --data.val_split_fraction 0.1
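A few quantities follow directly from these flags: --train.global_batch_size 4 with --train.micro_batch_size 1 means four micro-batches are accumulated per optimizer step, and --lora_alpha 16 with --lora_r 8 gives the usual LoRA scaling factor alpha / r = 2. The sketch below computes these plus a rough epoch count. num_train_samples=450 is a hypothetical figure (500 generated samples minus the 10% held out by --data.val_split_fraction 0.1), and treating each of the 1000 steps as one optimizer update over global_batch_size samples is an assumption; check how your LitGPT version counts steps and whether another limit ends training first.

```python
def lora_run_summary(global_batch_size, micro_batch_size, lora_r, lora_alpha,
                     max_steps, num_train_samples):
    """Back-of-the-envelope numbers for the finetune_lora flags above."""
    return {
        # Micro-batches accumulated per optimizer step
        "grad_accum_steps": global_batch_size // micro_batch_size,
        # LoRA scales the low-rank update B @ A by alpha / r
        "lora_scaling": lora_alpha / lora_r,
        # Samples processed, assuming one optimizer update per step
        "samples_seen": max_steps * global_batch_size,
        "approx_epochs": max_steps * global_batch_size / num_train_samples,
    }


def lora_params_per_matrix(d_out, d_in, r):
    """Trainable LoRA parameters for one d_out x d_in weight: B (d_out x r) plus A (r x d_in)."""
    return r * (d_out + d_in)


summary = lora_run_summary(4, 1, 8, 16, 1000, num_train_samples=450)
print(summary)
print(lora_params_per_matrix(4096, 4096, 8))  # 65536, versus 16,777,216 frozen weights
```

With a few hundred samples, the hypothetical figures above correspond to roughly nine passes over the data, which is one reason the periodic evaluation (--eval.interval) is worth watching for overfitting.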


For more detailed instructions, refer to the official LitGPT documentation.
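To my understanding, LitGPT’s JSON data module wraps each instruction/input/output record in a prompt template before training, and its default prompt style is Alpaca (check litgpt/prompts.py and the JSON data module in your installed version). The sketch below reproduces that Alpaca layout so you can see what the model is trained on; alpaca_prompt is a hypothetical helper, not a LitGPT API. The evaluation script in Step 3 prompts through the Mistral chat template instead, so formatting inference prompts this way is one option for keeping training and inference consistent.

```python
def alpaca_prompt(instruction: str, input_text: str = "") -> str:
    """Render one dataset record in the Alpaca layout (assumed LitGPT default)."""
    if input_text:
        return (
            "Below is an instruction that describes a task, paired with an input that "
            "provides further context. Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        )
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n### Response:\n"
    )


record = {
    "instruction": "You are a powerful text-to-SQL model. SCHEMA:\n"
                   "CREATE TABLE address (address_key INT PRIMARY KEY, customer_key INT, city VARCHAR(100));",
    "input": "Find all unique cities where customers reside.",
    "output": "SELECT DISTINCT city FROM address;",
}
prompt = alpaca_prompt(record["instruction"], record["input"])
# During training, the model learns to continue `prompt` with record["output"]
print(prompt + record["output"])
```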

Step 3: Evaluating the Fine-Tuned Model

After fine-tuning, it’s crucial to evaluate the model’s performance. We’ll use a token match score for this purpose, defined as the fraction of the reference query’s distinct tokens that also appear in the generated query:

  1. Use the eval.json file created in Step 1, which contains held-out questions and their reference SQL queries.


  2. Convert the merged weights produced by LitGPT (finetuned/final/lit_model.pth) into the Hugging Face Transformers format.

litgpt convert_from_litgpt finetuned/final out/hf-mistral-7b/converted


  3. Run the evaluation script.
import json
import re

import sqlparse
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

login(token="XXXXXXX")  # Hugging Face access token for the gated Mistral repo

# Load the merged fine-tuned weights converted from LitGPT
model_path = "out/hf-mistral-7b/converted"
base_model = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(base_model)
state_dict = torch.load(f"{model_path}/model.pth")
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    state_dict=state_dict,
    torch_dtype=torch.bfloat16,  # Half the memory of the default float32
)
model.to("cuda:0")
model.eval()

# Load the evaluation dataset
with open("eval.json", "r") as f:
    eval_data = json.load(f)


def normalize_sql(sql):
    # Remove single-line comments
    sql = re.sub(r'--.*$', '', sql, flags=re.MULTILINE)
    # Collapse all whitespace into single spaces
    sql = ' '.join(sql.split())
    if not sql:
        return ""  # sqlparse.parse("") returns an empty tuple
    parsed = sqlparse.parse(sql)[0]
    return str(parsed).lower()


def exact_match_score(prediction, reference):
    return normalize_sql(prediction) == normalize_sql(reference)


def token_match_score(prediction, reference):
    # Fraction of distinct reference tokens that also appear in the prediction
    pred_tokens = set(re.findall(r'\b\w+\b', normalize_sql(prediction)))
    ref_tokens = set(re.findall(r'\b\w+\b', normalize_sql(reference)))
    return len(pred_tokens & ref_tokens) / len(ref_tokens) if ref_tokens else 0


def evaluate_model(model, tokenizer, eval_data):
    exact_matches = []
    token_match_scores = []

    for item in eval_data:
        instruction = item["instruction"]
        input_question = item.get("input", "")
        expected_output = item["output"]

        # Note: LitGPT's JSON data module applies its own prompt style during
        # training (Alpaca by default); consider matching it here.
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": input_question},
        ]
        model_inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to("cuda:0")

        # Greedy decoding keeps the evaluation deterministic
        with torch.no_grad():
            outputs = model.generate(model_inputs, max_new_tokens=150, do_sample=False)

        # Decode only the newly generated tokens, not the prompt
        decoded = tokenizer.batch_decode(outputs[:, model_inputs.shape[1]:], skip_special_tokens=True)
        sql_query = decoded[0].strip()

        print(f"Instruction: {instruction}")
        print(f"Input Question: {input_question}")
        print(f"Expected Output: {expected_output}")
        print(f"Generated Output: {sql_query}")
        print("=" * 50)

        # Compute metrics
        exact_matches.append(exact_match_score(sql_query, expected_output))
        token_match_scores.append(token_match_score(sql_query, expected_output))

    avg_exact_match = sum(exact_matches) / len(exact_matches)
    avg_token_match_score = sum(token_match_scores) / len(token_match_scores)

    print(f"Exact Match Accuracy: {avg_exact_match:.4f}")
    print(f"Average Token Match Score: {avg_token_match_score:.4f}")


evaluate_model(model, tokenizer, eval_data)


This evaluation method can give you a better understanding of your model’s performance on text-to-SQL tasks.


On my evaluation set, the fine-tuned model reached an average token match score of 0.8786, which is quite good for a text-to-SQL model.
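Because token_match_score treats each query as a bag of distinct tokens, a query with its table and column names swapped can still score 1.0. A common complement is execution match: run the predicted and reference queries against the same database and compare the returned rows. The sketch below is a swapped-in technique, not the method behind the 0.8786 figure. It uses a simplified normalizer without sqlparse and a hypothetical three-row SQLite table (make_toy_db); a real check would run against a read-only copy of your own database in its own dialect.

```python
import re
import sqlite3


def simple_normalize(sql):
    """Simplified stand-in for normalize_sql (no sqlparse): drop -- comments, collapse whitespace, lowercase."""
    sql = re.sub(r"--.*$", "", sql, flags=re.MULTILINE)
    return " ".join(sql.split()).lower()


def token_match(prediction, reference):
    """Same formula as token_match_score: share of distinct reference tokens found in the prediction."""
    pred = set(re.findall(r"\b\w+\b", simple_normalize(prediction)))
    ref = set(re.findall(r"\b\w+\b", simple_normalize(reference)))
    return len(pred & ref) / len(ref) if ref else 0.0


def execution_match(prediction, reference, conn):
    """True if both queries run and return the same rows, ignoring row order."""
    try:
        got = conn.execute(prediction).fetchall()
        want = conn.execute(reference).fetchall()
    except (sqlite3.Error, sqlite3.Warning):
        return False
    # Compare sorted reprs so rows with mixed types (e.g. None and str) still sort
    return sorted(map(repr, got)) == sorted(map(repr, want))


def make_toy_db():
    """Hypothetical customer table with three rows, for illustration only."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE customer (customer_key INT PRIMARY KEY, full_name VARCHAR(100), dateofbirth DATE)")
    conn.executemany(
        "INSERT INTO customer VALUES (?, ?, ?)",
        [(1, "Ann", "1985-03-02"), (2, "Bo", "1992-07-15"), (3, "Cy", "1999-12-31")],
    )
    return conn


reference = "SELECT full_name FROM customer WHERE dateofbirth > '1990-01-01';"
swapped = "SELECT customer FROM full_name WHERE dateofbirth > '1990-01-01';"
conn = make_toy_db()
print(token_match(swapped, reference), execution_match(swapped, reference, conn))  # 1.0 False
```

Reporting both numbers side by side makes it easier to tell whether a high token match score reflects queries that actually return the right data.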

Step 4: Building a Database Interaction RAG using LangChain

With the model fine-tuned and evaluated, we can now integrate it into a LangChain application for interacting with the database.


Key Components for the LangChain Integration:

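  1. Install llama-cpp-python (the Python bindings for llama.cpp) with CUDA support. The command below uses Windows cmd syntax and the prebuilt CUDA 12.4 wheels:
set FORCE_CMAKE=1 && set CMAKE_ARGS=-DGGML_CUDA=on && pip install llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu124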


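2. Convert the Hugging Face model to the GGUF format using the convert_hf_to_gguf.py script from the llama.cpp repository. The script expects a complete Hugging Face model directory (config.json, weights, and tokenizer files); if you only have the model.pth produced by convert_from_litgpt, load it as in the evaluation script and write the directory out with save_pretrained first.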

python convert_hf_to_gguf.py /path/to/hf-model --outfile custom-mistral-7b.gguf --outtype f16


3. (Optional) Quantize the model with llama.cpp’s llama-quantize tool to reduce its size.

./llama-quantize ./custom-mistral-7b.gguf ./custom-mistral-7b-Q5_K_M.gguf Q5_K_M


I suggest Q5_K_M because it preserves most of the model’s quality. Alternatively, choose Q4_K_M if you want to save more memory.
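A rough size estimate shows what the choice is worth: weight storage is approximately parameters × bits per weight / 8 bytes. The sketch below uses 7e9 as a round parameter count, which is an assumption; K-quant formats such as Q5_K_M and Q4_K_M also store per-block scaling data, so their real files are somewhat larger than the nominal 5-bit and 4-bit figures, and these numbers are best read as lower bounds.

```python
def approx_weight_size_gb(n_params: float, bits_per_weight: float) -> float:
    """Rough weight-storage size in GB (1 GB = 1e9 bytes), ignoring metadata and scales."""
    return n_params * bits_per_weight / 8 / 1e9


N_PARAMS = 7e9  # Round figure for a "7B" model (assumption)
for label, bits in [("f16", 16), ("~5-bit", 5), ("~4-bit", 4)]:
    print(f"{label}: {approx_weight_size_gb(N_PARAMS, bits):.2f} GB")
```

Going from f16 to a 5-bit quantization cuts the weights to roughly a third of their size, and the extra step down to 4 bits saves comparatively little more.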


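  4. Install LangChain and the required libraries (psycopg2-binary provides the PostgreSQL driver used by the connection URL in the code):
pip install langchain langchain-community psycopg2-binary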


  5. Implement the LangChain SQL chain.

import logging
import re

from langchain.chains import create_sql_query_chain
from langchain_community.llms import LlamaCpp
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

logging.basicConfig(level=logging.INFO)

# Initialize the LlamaCpp language model with specified parameters
llm = LlamaCpp(
    model_path="./custom-mistral-7b-Q5_K_M.gguf",  # Path to the model file
    max_tokens=2048,  # Maximum number of tokens in the response
    n_ctx=6144,  # Context size
    n_gpu_layers=-1,  # Offload all layers to the GPU (requires the CUDA build)
    verbose=True,
    temperature=0,
)

# Define the database schema name and SQLAlchemy connection URL
dataset = "customer"
sqlalchemy_url = (
    "postgresql://db_user:db_pass@db_host:5432/db_name"  # Replace with actual credentials, host, and database
)

# Initialize the SQLDatabase object with specified schema and tables
db = SQLDatabase.from_uri(
    sqlalchemy_url,
    schema=dataset,
    include_tables=['customer', 'address', 'contact'],  # Tables to include in the database
)

# Create the SQL query generation chain using the language model and database
gen_query = create_sql_query_chain(llm, db)

# Convert datetime objects in strings to a specific format
def convert_dates(obj):
    response_str = re.sub(
        r'datetime\.date\((\d+),\s*(\d+),\s*(\d+)\)',
        r"'\1-\2-\3'",
        obj
    )
    response_str = re.sub(
        r'datetime\.datetime\((\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+)(?:,\s*(\d+))?\)',
        r"'\1-\2-\3 \4:\5:\6.\7'",
        response_str
    )
    return response_str

# Route the SQL query based on its content
def route(sql_query: dict):
    logging.info(f"Routing query: {sql_query}")  # Log the query for debugging
    if sql_query.get("query") == "I don't know":
        logging.warning("Unknown query detected.")  # Warn if the query is unknown
        return sql_query  # Return the original input unchanged
    else:
        return db_opt_chain  # A returned Runnable is invoked with the same input

# Handle cases where the agent responds with "I don't know"
def handle_dont_know(result):
    if isinstance(result, dict) and result.get("query") == "I don't know":
        return "I can only provide information related to our customer data."  # Custom response
    return result  # Return the original result if not "I don't know"

# Execute the SQL query and add the values the answer prompt needs
def custom_execute_query_runnable(result: dict) -> dict:
    return {
        **result,  # Keep "question" and "query"
        "schema": db.get_table_info(),  # Fills {schema} in the template
        "response": convert_dates(
            db.run_no_throw(command=result["query"], include_columns=True)  # Execute the query without throwing exceptions
        ),  # Fills {response} in the template
    }

# Template for generating the final natural language response
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""

# Define the prompt for generating the response
prompt_response = ChatPromptTemplate.from_template(template)

# Define the database operation chain with the custom execution function and response prompt
db_opt_chain = (
    RunnableLambda(custom_execute_query_runnable)  # Execute the query
    | prompt_response  # Build the answer prompt from the query result
    | llm  # Use the language model to write the answer
)

# Combine all components into the full execution chain
full_chain = (
    RunnablePassthrough.assign(query=gen_query)  # Add the generated SQL under "query"
    | RunnableLambda(route)  # Route the query appropriately
    | RunnableLambda(handle_dont_know)  # Handle "I don't know" responses
    | StrOutputParser()  # Parse the final output as a string
)

# Example user question to invoke the chain
user_question = 'Get the minimum and maximum age of customers'
print(full_chain.invoke({"question": user_question}))  # Execute the chain with the user question


Note: The sample code provided in this guide is intended as an example. You need to adjust the code to fit your specific use case, such as modifying database connection strings, schema definitions, or tuning model parameters based on your data and infrastructure requirements.


This implementation creates a chain that uses the fine-tuned LLM to generate SQL, executes the query against your database, and then uses the LLM again to interpret and summarize the results.
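The convert_dates helper in the chain rewrites date and datetime reprs in the query result into quoted strings, but it does not zero-pad, it leaves a trailing dot when microseconds are absent, and it misses datetimes whose seconds are zero, because Python’s repr omits trailing zero seconds and microseconds. A hedged variant that emits ISO-style strings is sketched below, assuming results arrive as the repr of plain date and datetime objects; convert_dates_iso is a hypothetical name, and timezone-aware values (with tzinfo=...) are not handled.

```python
import re

_DATE_RE = re.compile(r"datetime\.date\((\d+),\s*(\d+),\s*(\d+)\)")
# Hour and minute are always present in the repr; seconds and microseconds are optional
_DATETIME_RE = re.compile(
    r"datetime\.datetime\((\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+)(?:,\s*(\d+))?(?:,\s*(\d+))?\)"
)


def _fmt_date(m):
    y, mo, d = (int(g) for g in m.groups())
    return f"'{y:04d}-{mo:02d}-{d:02d}'"


def _fmt_datetime(m):
    y, mo, d, h, mi = (int(g) for g in m.groups()[:5])
    s = int(m.group(6) or 0)
    text = f"'{y:04d}-{mo:02d}-{d:02d} {h:02d}:{mi:02d}:{s:02d}"
    if m.group(7):
        text += f".{int(m.group(7)):06d}"  # Microseconds, zero-padded to six digits
    return text + "'"


def convert_dates_iso(text: str) -> str:
    """Replace date/datetime reprs in a result string with quoted ISO-style literals."""
    text = _DATE_RE.sub(_fmt_date, text)
    return _DATETIME_RE.sub(_fmt_datetime, text)
```

For example, the result string "[('Ann', datetime.date(1990, 1, 5))]" becomes "[('Ann', '1990-01-05')]", a form the model can read directly when writing its answer.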

Conclusion

By fine-tuning a 7B LLM with a custom dataset tailored to your specific database schema, you can greatly enhance its SQL generation capabilities. When combined with LangChain, this allows you to build a Q&A application for database interactions, even when using smaller language models.


Keep in mind that the effectiveness of your fine-tuned model depends on the quality and diversity of your training dataset. Continuously refining your dataset and the fine-tuning process will lead to improved results over time.