
Initial Training

This page walks through the complete initial training process — bootstrapping ChromaDB from zero to a fully functional, schema-aware SQL generator.

Already trained?

If chroma.sqlite3 already exists and contains data (as in the current POC), you can skip to Retraining to add or update examples.


Training Sequence

Follow this order for best results. Each step builds on the previous one:

Step 1: Train on DDL (schema)
Step 2: Train on Documentation (business rules)
Step 3: Train on SQL (question–answer examples)

Step 1: Auto-Train from the Database Schema

The fastest way to bootstrap DDL training is to let the API read it directly from the connected SQLite database.

POST /api/v1/training/from-database HTTP/1.1

Response:

{
  "ddl_count": 1,
  "message": "Successfully ingested 1 DDL statement(s) from the database."
}

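# Direct Vanna equivalent. Assumes `vn` is an already-configured Vanna
# instance (ChromaDB vector store + your LLM class).
vn.connect_to_sqlite('dbn-poc-database.db')

# Every schema object that has stored SQL text in sqlite_master
df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql IS NOT NULL")
for ddl in df_ddl['sql'].to_list():
    vn.train(ddl=ddl)  # embeds the statement into ChromaDB
    print(f"Added DDL: {ddl[:60]}...")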

Under the hood, this queries sqlite_master for all CREATE TABLE, CREATE INDEX, and CREATE VIEW statements and embeds each one into ChromaDB.
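To make that lookup concrete, here is a minimal sketch using only Python's standard `sqlite3` module. It illustrates the idea and is not the API's actual implementation: `extract_ddl` is a hypothetical helper, and the `type IN (...)` filter is an assumption added to restrict results to the three statement kinds named above (the bare `sql IS NOT NULL` query would also return triggers).

```python
from __future__ import annotations

import sqlite3


def extract_ddl(conn: sqlite3.Connection) -> list[str]:
    """Return the stored CREATE statements for tables, indexes, and views.

    Hypothetical helper. Rows with sql = NULL (such as indexes SQLite creates
    automatically for UNIQUE constraints) are skipped.
    """
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        "WHERE sql IS NOT NULL AND type IN ('table', 'index', 'view')"
    ).fetchall()
    # Each row is a 1-tuple; unpack it to get the statement text
    return [sql for (sql,) in rows]
```

Calling `extract_ddl(sqlite3.connect('dbn-poc-database.db'))` lists the statements that are candidates for embedding, which is a handy cross-check against the `ddl_count` the endpoint reports.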


Step 2: Train on Business Documentation

Add plain-English context that helps Vanna interpret domain terminology.

POST /api/v1/training/documentation HTTP/1.1
Content-Type: application/json

{
  "documentation": "A loan is classified as high-risk when pred_default_prob > 0.5. Low risk is pred_default_prob < 0.25."
}

Response:

{
  "id": "d3f1b2e0-...-doc",
  "message": "Documentation successfully added with id=d3f1b2e0-...-doc"
}

DOCS=(
  "PFI refers to a Participating Financial Institution: a bank or microfinance lender."
  "StartUp_norm = 1 means the borrower's business is less than 3 years old."
  "FirstTimeAccessToCredit_norm = 1 means this is the borrower's first formal loan."
  "ES_Rating is the Environmental and Social risk rating assigned by the PFI."
  "MSMEAnnualTurnover is the borrower's self-reported annual business revenue."
  "loan_to_turnover is AmountGranted divided by MSMEAnnualTurnover."
  "employees_bucket groups NumberOfEmployees: Solo (0), Micro (1-9), Small (10-49)."
)

for doc in "${DOCS[@]}"; do
  # Let jq build the JSON body so quotes and special characters are escaped safely
  jq -nc --arg doc "$doc" '{documentation: $doc}' |
    curl -s -X POST "http://127.0.0.1:8000/api/v1/training/documentation" \
      -H "Content-Type: application/json" \
      -d @- | jq -r .id
done
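The same loop can also be written with Python's standard library, which avoids the `jq` dependency. This is a sketch that assumes the API is running at the base URL used throughout this page and returns an `id` field as in the response above; `build_documentation_request` and `add_documentation` are hypothetical helper names. `json.dumps` handles quoting, so documentation strings containing double quotes or non-ASCII text are sent intact.

```python
from __future__ import annotations

import json
import urllib.request

BASE_URL = "http://127.0.0.1:8000/api/v1/training"


def build_documentation_request(doc: str, base_url: str = BASE_URL) -> urllib.request.Request:
    """Build a POST /documentation request with a properly escaped JSON body."""
    body = json.dumps({"documentation": doc}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/documentation",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def add_documentation(docs: list[str], base_url: str = BASE_URL) -> list[str]:
    """Send each doc and return the ids reported by the API.

    urlopen raises urllib.error.HTTPError on a non-2xx status, which stops the loop.
    """
    ids = []
    for doc in docs:
        with urllib.request.urlopen(build_documentation_request(doc, base_url), timeout=30) as resp:
            ids.append(json.load(resp)["id"])
    return ids
```

Splitting request construction from sending keeps the payload logic testable without a running server.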

Step 3: Train on Question–SQL Pairs

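These are the most valuable training records: when a new question comes in, Vanna retrieves the most similar stored question–SQL pairs and includes them in the LLM prompt as examples. Use verified SQL written by your data analysts.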

POST /api/v1/training/sql HTTP/1.1
Content-Type: application/json

{
  "question": "Which sector has the highest average predicted default probability?",
  "sql": "SELECT Sector, AVG(pred_default_prob) AS avg_risk\nFROM msmeloans\nGROUP BY Sector\nORDER BY avg_risk DESC\nLIMIT 5;"
}

Response:

{
  "id": "b1304bfd-1f4d-5401-9361-e2e48ff68cc7-sql",
  "message": "Question–SQL pair successfully added with id=b1304bfd-...-sql"
}

scripts/seed_training.py
import httpx

BASE_URL = "http://127.0.0.1:8000/api/v1/training"

training_pairs = [
    {
        "question": "Which sector has the highest average predicted default probability?",
        "sql": """
            SELECT Sector, AVG(pred_default_prob) AS avg_risk
            FROM msmeloans
            GROUP BY Sector
            ORDER BY avg_risk DESC
            LIMIT 5;
        """
    },
    {
        "question": "Show the top 10 borrowers with the highest predicted default risk.",
        "sql": """
            SELECT FullNames, Sector, State, AmountGranted, pred_default_prob
            FROM msmeloans
            ORDER BY pred_default_prob DESC
            LIMIT 10;
        """
    },
    {
        "question": "What percentage of loans are classified as high-risk (probability > 0.5)?",
        "sql": """
            SELECT
                (COUNT(*) FILTER (WHERE pred_default_prob > 0.5) * 100.0 / COUNT(*))
                AS high_risk_share
            FROM msmeloans;
        """
    },
    {
        "question": "What is the average loan amount by sector?",
        "sql": """
            SELECT Sector, AVG(AmountGranted) AS avg_loan
            FROM msmeloans
            GROUP BY Sector
            ORDER BY avg_loan DESC;
        """
    },
    {
        "question": "Which states have the largest number of startup borrowers?",
        "sql": """
            SELECT State, COUNT(*) AS startup_count
            FROM msmeloans
            WHERE StartUp_norm = 1
            GROUP BY State
            ORDER BY startup_count DESC;
        """
    },
    {
        "question": "What is the distribution of loans by age group?",
        "sql": """
            SELECT age_group, COUNT(*) AS loan_count
            FROM msmeloans
            GROUP BY age_group
            ORDER BY loan_count DESC;
        """
    },
    {
        "question": "Show the average loan-to-turnover ratio by number of employees bucket.",
        "sql": """
            SELECT employees_bucket, AVG(loan_to_turnover) AS avg_ratio
            FROM msmeloans
            GROUP BY employees_bucket
            ORDER BY avg_ratio DESC;
        """
    },
    {
        "question": "Which PFIs have the highest average predicted default risk?",
        "sql": """
            SELECT "PFI ID", AVG(pred_default_prob) AS avg_risk
            FROM msmeloans
            GROUP BY "PFI ID"
            ORDER BY avg_risk DESC
            LIMIT 5;
        """
    },
    {
        "question": "Compare average default probability between startups and non-startups.",
        "sql": """
            SELECT StartUp_norm, AVG(pred_default_prob) AS avg_risk
            FROM msmeloans
            GROUP BY StartUp_norm;
        """
    },
    {
        "question": "What is the average turnover of borrowers flagged as high-risk (>0.5)?",
        "sql": """
            SELECT AVG(MSMEAnnualTurnover) AS avg_turnover_high_risk
            FROM msmeloans
            WHERE pred_default_prob > 0.5;
        """
    },
]

# Post each pair; raise_for_status() stops the run on the first non-2xx response
for pair in training_pairs:
    resp = httpx.post(f"{BASE_URL}/sql", json=pair)
    resp.raise_for_status()
    print(f"  ✓ [{resp.json()['id']}] {pair['question'][:60]}...")

print(f"\nDone. {len(training_pairs)} SQL examples added.")
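Since this step calls for verified SQL, a cheap pre-flight check is to confirm that every query at least compiles against the live schema before it is posted. A minimal sketch, assuming direct read access to `dbn-poc-database.db`: prefixing a statement with `EXPLAIN` makes SQLite prepare it (checking syntax, tables, and column names) and return its bytecode instead of running the query. `find_invalid_sql` is a hypothetical helper. It does not prove that a query answers its question, and by default SQLite treats an unknown double-quoted identifier (for example a misspelled `"PFI ID"`) as a string literal, so that kind of typo can slip through. The `COUNT(*) FILTER (...)` syntax in the high-risk example needs SQLite 3.30.0 or later, so an older build would flag it here.

```python
from __future__ import annotations

import sqlite3


def find_invalid_sql(conn: sqlite3.Connection, pairs: list[dict]) -> list[tuple[str, str]]:
    """Return (question, error message) for every pair whose SQL fails to compile."""
    failures = []
    for pair in pairs:
        try:
            # EXPLAIN compiles the statement but does not execute the query itself
            conn.execute("EXPLAIN " + pair["sql"].strip()).fetchall()
        except sqlite3.Error as exc:
            failures.append((pair["question"], str(exc)))
    return failures
```

For example, call `find_invalid_sql(sqlite3.connect('dbn-poc-database.db'), training_pairs)` before the POST loop and abort if it returns anything.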

Step 4: Verify the Training Data

Inspect what was stored to confirm everything was ingested correctly:

GET /api/v1/training/data HTTP/1.1

Response:

{
  "count": 12,
  "records": [
    {
      "id": "c1ada977-...-ddl",
      "question": null,
      "content": "CREATE TABLE \"msmeloans\" (...)",
      "training_data_type": "ddl"
    },
    {
      "id": "b1304bfd-...-sql",
      "question": "Which sector has the highest average predicted default probability?",
      "content": "SELECT Sector, AVG(pred_default_prob) AS avg_risk FROM msmeloans ...",
      "training_data_type": "sql"
    }
  ]
}
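To compare what was stored with what you meant to add, it helps to count records per `training_data_type`. A minimal standard-library sketch, assuming the endpoint and response shape shown above; `fetch_training_data` and `summarize` are hypothetical names, and documentation records are assumed to carry their own type label alongside `ddl` and `sql`.

```python
from __future__ import annotations

import json
import urllib.request
from collections import Counter


def fetch_training_data(base_url: str = "http://127.0.0.1:8000/api/v1/training") -> dict:
    """GET /training/data and return the decoded JSON body."""
    with urllib.request.urlopen(f"{base_url}/data", timeout=30) as resp:
        return json.load(resp)


def summarize(payload: dict) -> Counter:
    """Count stored records per training_data_type (missing types count as 0)."""
    return Counter(rec["training_data_type"] for rec in payload["records"])
```

`summarize(fetch_training_data())` returns a `Counter` keyed by record type, which you can check against the number of DDL statements, docs, and SQL pairs you submitted.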


Initial Training Checklist

  • DDL for all tables ingested (via POST /training/from-database)
  • Business terminology documented (at least 5–10 docs)
  • At least 10 verified question–SQL pairs added
  • Training data verified with GET /training/data
  • Test query with POST /chat/ask returns accurate results
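The numeric items on this checklist can be checked automatically from the `GET /training/data` payload. This sketch assumes documentation records are labelled `documentation` (only `ddl` and `sql` appear in the example response above), and `unmet_requirements` is a hypothetical helper. It cannot confirm that every table's DDL is present or that `/chat/ask` answers are accurate; those items still need a manual review.

```python
from __future__ import annotations

from collections import Counter

# Minimum counts taken from the checklist; the type labels are an assumption
# about the training_data_type values returned by GET /training/data.
THRESHOLDS = {"ddl": 1, "documentation": 5, "sql": 10}


def unmet_requirements(payload: dict) -> list[str]:
    """Return one message per record type that is below its minimum count."""
    counts = Counter(rec["training_data_type"] for rec in payload["records"])
    return [
        f"{kind}: have {counts[kind]}, need at least {minimum}"
        for kind, minimum in THRESHOLDS.items()
        if counts[kind] < minimum
    ]
```

An empty list means the count-based items are satisfied.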