Skip to content

Commit 3238da1

Browse files
committed
refactor: use pgvector ORM methods instead of raw SQL
1 parent 8f117fc commit 3238da1

2 files changed

Lines changed: 13 additions & 51 deletions

File tree

app/repositories/cached_query_repository.py

Lines changed: 9 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -16,36 +16,20 @@ def __init__(self, session: SyncSession):
1616
def find_similar_query(self, query_embedding: List[float]) -> CachedQuery | None:
1717
"""
1818
Finds a cached query if its embedding is within the similarity threshold.
19-
This is done in a single, efficient query.
19+
This is done in a single, efficient query using the ORM.
2020
"""
2121

2222
distance_threshold = 1 - SIMILARITY_THRESHOLD
2323

24-
stmt = text("""
25-
SELECT id
26-
FROM cached_queries
27-
WHERE (question_embedding <=> CAST(:query_embedding AS vector)) < :distance_threshold
28-
ORDER BY question_embedding <=> CAST(:query_embedding AS vector)
29-
LIMIT 1
30-
""")
24+
stmt = select(CachedQuery).options(
25+
selectinload(CachedQuery.source_chunks)
26+
).where(
27+
CachedQuery.question_embedding.cosine_distance(query_embedding) < distance_threshold
28+
).order_by(
29+
CachedQuery.question_embedding.cosine_distance(query_embedding)
30+
).limit(1)
3131

32-
result = self.session.execute(
33-
stmt,
34-
{
35-
"query_embedding": query_embedding,
36-
"distance_threshold": distance_threshold
37-
}
38-
).scalar_one_or_none()
39-
40-
if result:
41-
similar_query_id = result
42-
final_stmt = select(CachedQuery).options(
43-
selectinload(CachedQuery.source_chunks)
44-
).where(CachedQuery.id == similar_query_id)
45-
46-
return self.session.execute(final_stmt).scalar_one_or_none()
47-
48-
return None
32+
return self.session.execute(stmt).scalar_one_or_none()
4933

5034
def save_query(self, question_text: str, query_embedding: list[float], answer: str, source_chunks: list[Chunk]):
5135
"""Saves a new query and links it to its source chunks."""

app/repositories/chunk_repository.py

Lines changed: 4 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -18,32 +18,10 @@ def delete_by_document_id_sync(self, document_id: str):
1818
def find_relevant_chunks_sync(self, query_embedding: list[float], top_k: int = 5) -> list[Chunk]:
1919
"""
2020
Finds the most relevant chunks using vector similarity.
21-
22-
Strategy: Use raw SQL for fast vector search to get IDs only,
23-
then fetch via ORM to get properly-attached objects that can be used in relationships.
2421
"""
25-
26-
stmt = text("""
27-
SELECT id
28-
FROM chunks
29-
ORDER BY embedding <=> CAST(:query_embedding AS vector)
30-
LIMIT :top_k
31-
""")
32-
33-
result = self.session.execute(
34-
stmt,
35-
{"query_embedding": query_embedding.tolist() if hasattr(query_embedding, 'tolist') else query_embedding, "top_k": top_k}
36-
)
37-
38-
chunk_ids = [row[0] for row in result]
39-
40-
if not chunk_ids:
41-
return []
42-
43-
chunks = self.session.execute(
44-
select(Chunk).where(Chunk.id.in_(chunk_ids))
45-
).scalars().all()
22+
stmt = select(Chunk).order_by(
23+
Chunk.embedding.cosine_distance(query_embedding)
24+
).limit(top_k)
4625

47-
chunks_dict = {chunk.id: chunk for chunk in chunks}
48-
return [chunks_dict[chunk_id] for chunk_id in chunk_ids if chunk_id in chunks_dict]
26+
return self.session.execute(stmt).scalars().all()
4927

0 commit comments

Comments
 (0)