|
3 | 3 | import logging |
4 | 4 | from uuid import UUID |
5 | 5 | from datetime import datetime, timedelta |
6 | | -from sqlalchemy import select, and_, or_ |
| 6 | +from sqlalchemy import select, and_, or_, func |
7 | 7 | from sqlalchemy.ext.asyncio import AsyncSession |
8 | 8 | from sqlalchemy.orm import selectinload |
9 | 9 | from openai import AsyncOpenAI |
@@ -222,25 +222,32 @@ async def get_connections( |
222 | 222 | self, |
223 | 223 | user_id: UUID, |
224 | 224 | limit: int = 20, |
| 225 | + offset: int = 0, |
225 | 226 | unnotified_only: bool = False, |
226 | 227 | undismissed_only: bool = True, |
227 | | - ) -> list[MemoryConnection]: |
228 | | - """Get connections for a user.""" |
229 | | - query = ( |
| 228 | + ) -> tuple[list[MemoryConnection], int]: |
| 229 | + """Get connections for a user with pagination.""" |
| 230 | + # Base query for filtering |
| 231 | + base_query = ( |
230 | 232 | select(MemoryConnection) |
231 | 233 | .where(MemoryConnection.user_id == user_id) |
232 | | - .order_by(MemoryConnection.created_at.desc()) |
233 | 234 | ) |
234 | 235 |
|
235 | 236 | if unnotified_only: |
236 | | - query = query.where(MemoryConnection.notified_at.is_(None)) |
| 237 | + base_query = base_query.where(MemoryConnection.notified_at.is_(None)) |
237 | 238 |
|
238 | 239 | if undismissed_only: |
239 | | - query = query.where(MemoryConnection.dismissed_at.is_(None)) |
| 240 | + base_query = base_query.where(MemoryConnection.dismissed_at.is_(None)) |
| 241 | + |
| 242 | + # Get total count |
| 243 | + count_query = select(func.count()).select_from(base_query.subquery()) |
| 244 | + total_result = await self.db.execute(count_query) |
| 245 | + total = total_result.scalar() or 0 |
240 | 246 |
|
241 | | - query = query.limit(limit) |
| 247 | + # Apply ordering and pagination |
| 248 | + query = base_query.order_by(MemoryConnection.created_at.desc()).offset(offset).limit(limit) |
242 | 249 | result = await self.db.execute(query) |
243 | | - return list(result.scalars().all()) |
| 250 | + return list(result.scalars().all()), total |
244 | 251 |
|
245 | 252 | async def get_connection_with_memories( |
246 | 253 | self, connection_id: UUID, user_id: UUID |
|
0 commit comments