Commit 13648ac

Merge pull request #31 from RubyRyn/feature/RAG-pipeline-improvement-v2
feat: add BM25 hybrid search with RRF, update chunking for Notion ing…
2 parents a039a5b + 31f2ecc commit 13648ac

17 files changed

Lines changed: 1088 additions & 326 deletions

.env.example

Lines changed: 7 additions & 0 deletions
@@ -6,3 +6,10 @@ JWT_ALGORITHM=HS256
 JWT_EXPIRE_MINUTES=1440
 FRONTEND_URL=http://localhost:5173
 DATABASE_URL=sqlite:///./workmate.db
+
+# AI / LLM
+GEMINI_API_KEY=your-gemini-api-key
+VOYAGE_API_KEY=your-voyageai-api-key
+
+# Notion (direct API access — optional, JSON import works without this)
+NOTION_TOKEN=your-notion-api-token

.gitignore

Lines changed: 7 additions & 1 deletion
@@ -46,4 +46,10 @@ chroma_db/
 
 # Antigravity SKILLS
 .agent/
-CLAUDE.md
+CLAUDE.md
+.claude/
+
+# Debugging
+Tasks_for_Claude.txt
+debug_rag.py
+generate_token.py

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -19,4 +19,6 @@ dependencies = [
     "python-multipart>=0.0.22",
     "pypdf2>=3.0.1",
     "sse-starlette>=3.3.2",
+    "voyageai>=0.3.7",
+    "bm25s>=0.2.12",
 ]

src/backend/app.py

Lines changed: 1 addition & 2 deletions
@@ -3,7 +3,7 @@
 
 from src.backend.config import settings
 from src.backend.database import Base, engine
-from src.backend.routers import admin, auth, chat, conversations, upload
+from src.backend.routers import admin, auth, conversations, upload
 
 
 def create_app() -> FastAPI:
@@ -21,7 +21,6 @@ def create_app() -> FastAPI:
 
     app.include_router(auth.router)
     app.include_router(admin.router)
-    app.include_router(chat.router)
     app.include_router(conversations.router)
     app.include_router(upload.router)
 

src/backend/config.py

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ class Settings(BaseSettings):
     FRONTEND_URL: str = "http://localhost:5173"
     DATABASE_URL: str = "sqlite:///./workmate.db"
     GEMINI_API_KEY: str = ""
+    VOYAGE_API_KEY: str = ""
     NOTION_TOKEN: str = ""
 
     model_config = {"env_file": ".env", "extra": "ignore"}

src/backend/dependencies/services.py

Lines changed: 36 additions & 0 deletions
@@ -1,14 +1,21 @@
 import logging
+import os
 
 from fastapi import HTTPException
 
+from src.backend.load.bm25_manager import BM25Manager, BM25_INDEX_PATH
 from src.backend.load.chroma_manager import ChromaManager
+from src.backend.load.hybrid_retriever import HybridRetriever
 from src.backend.llm.gemini_client import GeminiClient
+from src.backend.llm.voyage_reranker import VoyageReranker
 
 logger = logging.getLogger(__name__)
 
 _chroma_manager = None
 _gemini_client = None
+_voyage_reranker = None
+_bm25_manager = None
+_hybrid_retriever = None
 
 
 def get_chroma_manager() -> ChromaManager:
@@ -33,3 +40,32 @@ def get_gemini_client() -> GeminiClient:
             logger.error(f"Failed to initialize GeminiClient: {e}")
             raise HTTPException(status_code=500, detail="LLM configuration missing.")
     return _gemini_client
+
+
+def get_voyage_reranker() -> VoyageReranker:
+    global _voyage_reranker
+    if _voyage_reranker is None:
+        try:
+            _voyage_reranker = VoyageReranker()
+        except Exception as e:
+            logger.error(f"Failed to initialize VoyageReranker: {e}")
+            raise HTTPException(status_code=500, detail="Reranker configuration missing.")
+    return _voyage_reranker
+
+
+def get_bm25_manager() -> BM25Manager:
+    global _bm25_manager
+    if _bm25_manager is None:
+        _bm25_manager = BM25Manager()
+        if os.path.exists(BM25_INDEX_PATH):
+            _bm25_manager.load(BM25_INDEX_PATH)
+        else:
+            logger.warning(f"BM25 index not found at {BM25_INDEX_PATH}. Run NotionIngestor to build it.")
+    return _bm25_manager
+
+
+def get_hybrid_retriever() -> HybridRetriever:
+    global _hybrid_retriever
+    if _hybrid_retriever is None:
+        _hybrid_retriever = HybridRetriever(get_chroma_manager(), get_bm25_manager())
+    return _hybrid_retriever
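
Note: hybrid_retriever.py is among the 17 changed files but its source is not shown in this excerpt. As orientation only, a minimal reciprocal rank fusion (RRF) sketch, assuming both search paths return lists of chunk dicts keyed by "chunk_id" as BM25Manager.search (later in this diff) does; the project's actual implementation may differ:

    # Hypothetical RRF fusion of dense (Chroma) and sparse (BM25) result lists.
    # Not taken from this diff; function name and constants are assumptions.
    from typing import Any

    def rrf_fuse(
        dense_results: list[dict[str, Any]],
        sparse_results: list[dict[str, Any]],
        k: int = 60,        # conventional RRF damping constant
        top_k: int = 10,
    ) -> list[dict[str, Any]]:
        # Each list contributes 1 / (k + rank) per chunk it ranks; chunks
        # found by both retrievers accumulate a higher fused score.
        scores: dict[str, float] = {}
        by_id: dict[str, dict[str, Any]] = {}
        for results in (dense_results, sparse_results):
            for rank, chunk in enumerate(results, start=1):
                cid = chunk["chunk_id"]
                by_id.setdefault(cid, chunk)
                scores[cid] = scores.get(cid, 0.0) + 1.0 / (k + rank)
        ranked = sorted(scores, key=scores.get, reverse=True)
        return [by_id[cid] for cid in ranked[:top_k]]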

src/backend/llm/gemini_client.py

Lines changed: 37 additions & 29 deletions
@@ -33,14 +33,25 @@ def filter_chunks(
 ) -> List[Dict[str, Any]]:
     """
     LLM Re-ranking: Asks Gemini to filter out irrelevant chunks before generation.
+    Uses sequential chunk_N IDs (chunk_1, chunk_2, ...) for the filter prompt so
+    the LLM and the matching logic agree on the same ID format.
     """
     if not chunks:
         return []
 
-    prompt = prompts.get_filter_prompt(chunks, user_question)
+    # Remap chunks to sequential IDs for the filter prompt so the LLM
+    # returns predictable IDs like "chunk_1, chunk_3" instead of UUIDs.
+    sequential_id_map: Dict[str, Dict[str, Any]] = {}
+    remapped_chunks = []
+    for i, chunk in enumerate(chunks, start=1):
+        seq_id = f"chunk_{i}"
+        sequential_id_map[seq_id] = chunk
+        remapped_chunks.append({**chunk, "chunk_id": seq_id})
+
+    prompt = prompts.get_filter_prompt(remapped_chunks, user_question)
     cfg = types.GenerateContentConfig(
         system_instruction=prompts.FILTER_SYSTEM_INSTRUCTION,
-        temperature=0.0,  # Zero precision for extraction
+        temperature=0.0,
         max_output_tokens=100,
     )
 
@@ -51,40 +62,37 @@ def filter_chunks(
             config=cfg,
         )
         output = getattr(response, "text", "") or ""
-        print(f"🧠 Re-ranker Output: {output}")
-
+        print(f"Re-ranker Output: {output}")
+
         if "NONE" in output.upper():
-            print("🧠 Re-ranker kept 0 chunks.")
+            print("Re-ranker kept 0 chunks.")
             return []
-
-        # Parse the output safely: sometimes Gemini returns a clean comma-separated list of UUIDs
-        # But sometimes it's lazy and returns just the first part e.g. "30, 4f" instead of "302f24..."
-        filtered_chunks = []
+
+        # Match returned IDs (e.g. "chunk_1, chunk_3") back to original chunks.
         output_parts = [p.strip() for p in output.split(",") if p.strip()]
-
-        for chunk in chunks:
-            chunk_id = str(chunk["chunk_id"])
-
-            # Check 1: Is the full chunk_id anywhere in the raw output string?
-            if chunk_id in output:
-                filtered_chunks.append(chunk)
+        filtered_chunks = []
+        seen = set()
+        for part in output_parts:
+            # Exact match: "chunk_3"
+            if part in sequential_id_map and part not in seen:
+                filtered_chunks.append(sequential_id_map[part])
+                seen.add(part)
                 continue
-
-            # Check 2: Did the LLM abbreviate the IDs? Check each comma-separated part.
-            for part in output_parts:
-                if len(part) >= 2 and chunk_id.startswith(part):
-                    filtered_chunks.append(chunk)
-                    break
-
-        print(f"🧠 Re-ranker kept {len(filtered_chunks)}/{len(chunks)} chunks.")
-
-        # Fallback to all chunks if zero were matched but it didn't explicitly say "NONE"
+            # Plain number match: LLM returned "3" instead of "chunk_3"
+            candidate = f"chunk_{part}"
+            if candidate in sequential_id_map and candidate not in seen:
+                filtered_chunks.append(sequential_id_map[candidate])
+                seen.add(candidate)
+
+        print(f"Re-ranker kept {len(filtered_chunks)}/{len(chunks)} chunks.")
+
+        # Fallback: if nothing matched but LLM didn't say NONE, return all chunks.
        if not filtered_chunks:
-            return chunks
+            return chunks
         return filtered_chunks
-
+
     except Exception as e:
-        logger.warning(f"⚠️ Re-ranking failed (falling back to all chunks): {e}")
+        logger.warning(f"Re-ranking failed (falling back to all chunks): {e}")
         return chunks
 
 def ask_workmate(
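
The remapping above makes the filter robust to how Gemini abbreviates IDs. A short illustration with made-up data (the exact call shape of filter_chunks and the chunk fields the filter prompt expects are assumptions beyond what this hunk shows):

    # Illustrative only; real chunk_ids are UUIDs from Chroma/BM25.
    from src.backend.llm.gemini_client import GeminiClient

    gemini = GeminiClient()
    chunks = [
        {"chunk_id": "uuid-1", "text": "Expense policy ..."},
        {"chunk_id": "uuid-2", "text": "Travel booking ..."},
    ]
    # The filter prompt shows Gemini chunk_1 / chunk_2. If it answers "chunk_2"
    # (or just "2"), the second original dict is returned with its UUID intact;
    # "NONE" yields [], and an unparseable answer falls back to all chunks.
    kept = gemini.filter_chunks(chunks, "How do I book a business trip?")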

src/backend/llm/voyage_reranker.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import Any, Dict, List, Tuple
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_RERANK_MODEL = "rerank-2"
+RELEVANCE_THRESHOLD = 0.3
+
+
+class VoyageReranker:
+    """
+    Reranks retrieved chunks using the VoyageAI rerank API.
+    Replaces the LLM-based filter_chunks step in the RAG pipeline.
+    """
+
+    def __init__(
+        self,
+        model: str = DEFAULT_RERANK_MODEL,
+        threshold: float = RELEVANCE_THRESHOLD,
+    ):
+        api_key = os.getenv("VOYAGE_API_KEY")
+        if not api_key:
+            logger.warning(
+                "[VoyageReranker] VOYAGE_API_KEY not set — reranking disabled. "
+                "Set VOYAGE_API_KEY in .env to enable VoyageAI reranking."
+            )
+            self.client = None
+        else:
+            import voyageai
+            self.client = voyageai.Client(api_key=api_key)
+
+        self.model = model
+        self.threshold = threshold
+
+    def rerank(
+        self,
+        chunks: List[Dict[str, Any]],
+        query: str,
+        top_k: int = 5,
+    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+        """
+        Rerank chunks by relevance to the query.
+
+        Returns:
+            final_chunks — top_k chunks above threshold, score field stripped (clean for generation)
+            scored_chunks — all input chunks sorted by score descending, with rerank_score added
+        """
+        if not chunks:
+            return [], []
+
+        if self.client is None:
+            logger.warning("[VoyageReranker] Reranking skipped (no API key). Returning top_k unranked.")
+            return chunks[:top_k], []
+
+        documents = [
+            f"Page: {c['page_title']}\nSection: {c['section']}\n{c['text']}"
+            if c.get("section")
+            else f"Page: {c['page_title']}\n{c['text']}"
+            for c in chunks
+        ]
+
+        try:
+            result = self.client.rerank(
+                query=query,
+                documents=documents,
+                model=self.model,
+                top_k=len(chunks),  # fetch all scores; we apply threshold + top_k ourselves
+            )
+
+            scored_chunks: List[Dict[str, Any]] = []
+            for item in result.results:
+                chunk = chunks[item.index]
+                scored_chunks.append({**chunk, "rerank_score": round(item.relevance_score, 4)})
+
+            scored_chunks.sort(key=lambda x: x["rerank_score"], reverse=True)
+
+            logger.info(
+                f"[VoyageReranker] scores: "
+                f"{[(c['page_title'], c['rerank_score']) for c in scored_chunks]}"
+            )
+
+            above_threshold = [c for c in scored_chunks if c["rerank_score"] >= self.threshold]
+            final_scored = above_threshold[:top_k]
+
+            if not final_scored:
+                logger.warning(
+                    f"[VoyageReranker] All {len(chunks)} chunks below threshold "
+                    f"({self.threshold}). Top score: "
+                    f"{scored_chunks[0]['rerank_score'] if scored_chunks else 'N/A'}."
+                )
+
+            # Strip rerank_score before passing to generation prompt
+            final_chunks = [
+                {k: v for k, v in c.items() if k != "rerank_score"} for c in final_scored
+            ]
+
+            return final_chunks, scored_chunks
+
+        except Exception as e:
+            logger.warning(
+                f"[VoyageReranker] Reranking failed, falling back to top {top_k} unranked: {e}"
+            )
+            return chunks[:top_k], []
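
How these pieces are wired into the answering flow is outside this excerpt; a rough sketch of the intended call order, with the retriever method name assumed rather than taken from the diff:

    # Hypothetical query handler; `question` and retriever.retrieve(...) are
    # placeholders. Only the get_* factories and rerank(...) appear in this diff.
    from src.backend.dependencies.services import get_hybrid_retriever, get_voyage_reranker

    question = "How do I request a new laptop?"
    retriever = get_hybrid_retriever()      # Chroma (dense) + BM25 (sparse), fused
    reranker = get_voyage_reranker()        # VoyageAI rerank-2 scoring

    candidates = retriever.retrieve(question, top_k=20)   # assumed method name
    final_chunks, scored = reranker.rerank(candidates, query=question, top_k=5)
    # final_chunks feed the generation prompt; `scored` keeps rerank_score for logging.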

src/backend/load/bm25_manager.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import logging
+import os
+import pickle
+
+import bm25s
+
+logger = logging.getLogger(__name__)
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
+BM25_INDEX_PATH = os.path.join(PROJECT_ROOT, "workmate_db", "bm25_index.pkl")
+
+
+class BM25Manager:
+    def __init__(self):
+        self.index = None
+        self.chunks: list[str] = []
+        self.metadatas: list[dict] = []
+        self.ids: list[str] = []
+
+    def build_index(self, chunks: list[str], metadatas: list[dict], ids: list[str]):
+        self.chunks = chunks
+        self.metadatas = metadatas
+        self.ids = ids
+
+        corpus_indices = list(range(len(chunks)))
+        indexed_texts = [
+            f"{m.get('title', '')} {c}".lower()
+            for c, m in zip(chunks, metadatas)
+        ]
+        tokenized_corpus = bm25s.tokenize(indexed_texts)
+        self.index = bm25s.BM25(corpus=corpus_indices)
+        self.index.index(tokenized_corpus)
+        logger.info(f"BM25 index built with {len(chunks)} documents")
+
+    def search(self, query: str, top_k: int = 10) -> list[dict]:
+        if self.index is None:
+            logger.warning("BM25 index not built, returning empty results")
+            return []
+
+        k = min(top_k, len(self.chunks))
+        query_tokens = bm25s.tokenize([query.lower()])
+        results, _ = self.index.retrieve(query_tokens, k=k)
+
+        output = []
+        for idx in results[0]:
+            meta = self.metadatas[idx]
+            output.append({
+                "chunk_id": self.ids[idx],
+                "text": self.chunks[idx],
+                "page_title": meta.get("title", "Unknown Source"),
+                "section": meta.get("parent_title", ""),
+                **meta,
+            })
+        return output
+
+    def save(self, path: str = BM25_INDEX_PATH):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        with open(path, "wb") as f:
+            pickle.dump(
+                {
+                    "index": self.index,
+                    "chunks": self.chunks,
+                    "metadatas": self.metadatas,
+                    "ids": self.ids,
+                },
+                f,
+            )
+        logger.info(f"BM25 index saved to {path}")
+
+    def load(self, path: str = BM25_INDEX_PATH):
+        with open(path, "rb") as f:
+            data = pickle.load(f)
+        self.index = data["index"]
+        self.chunks = data["chunks"]
+        self.metadatas = data["metadatas"]
+        self.ids = data["ids"]
+        logger.info(f"BM25 index loaded from {path} ({len(self.chunks)} documents)")
