diff --git a/.jules/bolt.md b/.jules/bolt.md index b45f9caa..d781f1e4 100644 --- a/.jules/bolt.md +++ b/.jules/bolt.md @@ -60,3 +60,6 @@ Optimized metric route processing to O(N) by creating a mapping of routes direct ## 2026-06-25 - Avoid Map allocations in frontend ERD loops and mutate asyncpg records in-place **Learning:** The frontend `snapshotToGraph` iterates over thousands of columns to generate the graph, so repeated lookups and redundant collection assignments increase GC pressure. Backend snapshot column dictionaries are freshly instantiated for the payload, so `add_column_examples` can safely fill missing fields in place. **Action:** Reuse existing collections while aggregating relational data, create `Map`/`Set` entries only on first use, and check for missing example fields before calling expensive inference helpers. +## 2026-07-05 - Optimize Node Search Garbage Collection +**Learning:** High-frequency React Flow hooks (like searching through large sets of nodes/columns) can cause severe garbage collection pressure and rendering hitches if they use array allocation methods (`.flatMap()`, `.join()`) inside render loops. +**Action:** Replace array mapping and spreading with direct string concatenation (`+`) for string building inside frequent loops or React `useMemo` hooks to significantly reduce allocations and avoid GC jank. 
diff --git a/backend/app/api/share.py b/backend/app/api/share.py index 33cc8bde..18685235 100644 --- a/backend/app/api/share.py +++ b/backend/app/api/share.py @@ -17,6 +17,7 @@ SchemaSnapshotData, ShareLink, ) +from app.redact import redact_sensitive_schema_data from app.spec.llm import ( LlmConfigurationError, LlmProviderError, @@ -124,7 +125,7 @@ async def get_shared_snapshot( "status": snap.status, "schema_filter": snap.schema_filter, "error_message": snap.error_message, - "snapshot_json": data.snapshot_json if data else None, + "snapshot_json": redact_sensitive_schema_data(data.snapshot_json) if data else None, } @@ -154,7 +155,9 @@ async def export_shared_snapshot_sql( data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) if data is None: return "-- snapshot data not found\n" - return snapshot_json_to_sql(data.snapshot_json, target_dialect=dialect) + + redacted_json = redact_sensitive_schema_data(data.snapshot_json) + return snapshot_json_to_sql(redacted_json, target_dialect=dialect) @router.get( @@ -183,9 +186,11 @@ async def export_shared_snapshot_reversing_spec( data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) if data is None: return "# DB Reversing Specification\n\nSnapshot data not found.\n" + + redacted_json = redact_sensitive_schema_data(data.snapshot_json) if mode == "llm-draft": try: - return await generate_reversing_llm_draft(data.snapshot_json) + return await generate_reversing_llm_draft(redacted_json) except LlmConfigurationError as exc: raise HTTPException( status_code=503, detail="LLM configuration error" @@ -194,7 +199,7 @@ async def export_shared_snapshot_reversing_spec( raise HTTPException( status_code=502, detail="LLM provider request failed" ) from exc - return generate_reversing_spec(data.snapshot_json, mode=mode) + return generate_reversing_spec(redacted_json, mode=mode) @router.get( @@ -223,9 +228,11 @@ async def export_shared_snapshot_index_design( data = await session.get(SchemaSnapshotData, 
schema_snapshot_uuid) if data is None: return "# ERD Index Design\n\nSnapshot data not found.\n" + + redacted_json = redact_sensitive_schema_data(data.snapshot_json) if mode == "llm-draft": try: - return await generate_index_design_llm_draft(data.snapshot_json) + return await generate_index_design_llm_draft(redacted_json) except LlmConfigurationError as exc: raise HTTPException( status_code=503, detail="LLM configuration error" @@ -234,4 +241,4 @@ async def export_shared_snapshot_index_design( raise HTTPException( status_code=502, detail="LLM provider request failed" ) from exc - return generate_index_design_spec(data.snapshot_json, mode=mode) + return generate_index_design_spec(redacted_json, mode=mode) diff --git a/backend/app/redact.py b/backend/app/redact.py new file mode 100644 index 00000000..7bbec6b6 --- /dev/null +++ b/backend/app/redact.py @@ -0,0 +1,24 @@ +import copy +from typing import Any, Dict + + +def redact_sensitive_schema_data(snapshot_json: Dict[str, Any] | None) -> Dict[str, Any]: + """Redacts sensitive properties like comments or example values from snapshot data.""" + if not snapshot_json: + return {} + + redacted = copy.deepcopy(snapshot_json) + + if "tables" in redacted and isinstance(redacted["tables"], list): + for table in redacted["tables"]: + if "comment" in table and table["comment"] is not None: + table["comment"] = "[REDACTED]" + + if "columns" in table and isinstance(table["columns"], list): + for column in table["columns"]: + if "column_comment" in column and column["column_comment"] is not None: + column["column_comment"] = "[REDACTED]" + if "example_value" in column and column["example_value"] is not None: + column["example_value"] = "[REDACTED]" + + return redacted diff --git a/backend/build/lib/app/__init__.py b/backend/build/lib/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/build/lib/app/api/__init__.py b/backend/build/lib/app/api/__init__.py new file mode 100644 index 00000000..e69de29b 
diff --git a/backend/build/lib/app/api/auth_routes.py b/backend/build/lib/app/api/auth_routes.py new file mode 100644 index 00000000..1f0c8fc0 --- /dev/null +++ b/backend/build/lib/app/api/auth_routes.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from fastapi import APIRouter, Request + +from app.auth import revoke_current_request_token + +router = APIRouter(prefix="/api/auth", tags=["auth"]) + + +@router.post("/logout") +async def logout(request: Request) -> dict[str, bool]: + """Invalidate the current bearer token for this app process.""" + + await revoke_current_request_token(request) + return {"ok": True} diff --git a/backend/build/lib/app/api/connections.py b/backend/build/lib/app/api/connections.py new file mode 100644 index 00000000..5689ce62 --- /dev/null +++ b/backend/build/lib/app/api/connections.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import datetime as dt +import uuid + +from fastapi import APIRouter, Depends +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import CurrentUser, get_current_user +from app.db import get_read_session, get_session +from app.models import DbConnection +from app.permissions import require_project_member +from app.schemas import ConnectionCreateIn, ConnectionOut +from app.security import encrypt_text +from app.sanitize import sanitize_for_storage + +router = APIRouter(prefix="/api/connections", tags=["connections"]) + + +@router.get("/by-project/{project_space_uuid}", response_model=list[ConnectionOut]) +async def list_connections( + project_space_uuid: uuid.UUID, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> list[ConnectionOut]: + """List DB connections for a project.""" + await require_project_member(session, project_space_uuid, user.user_account_uuid) + rows = await session.execute( + select(DbConnection) + .where(DbConnection.project_space_uuid == project_space_uuid) + 
.order_by(DbConnection.created_at.desc()) + ) + cons = rows.scalars().all() + return [ + ConnectionOut(db_connection_uuid=c.db_connection_uuid, conn_name=c.conn_name) + for c in cons + ] + + +@router.post("/by-project/{project_space_uuid}", response_model=ConnectionOut) +async def create_connection( + project_space_uuid: uuid.UUID, + body: ConnectionCreateIn, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +) -> ConnectionOut: + """Create a DB connection for a project (encrypt DSN at rest).""" + await require_project_member( + session, project_space_uuid, user.user_account_uuid, minimum_role="editor" + ) + encrypted = encrypt_text(str(sanitize_for_storage(body.dsn))) + c = DbConnection( + db_connection_uuid=uuid.uuid4(), + project_space_uuid=project_space_uuid, + conn_name=str(sanitize_for_storage(body.conn_name)), + dsn_ciphertext=encrypted.ciphertext, + dsn_nonce=encrypted.nonce, + created_at=dt.datetime.now(dt.timezone.utc), + updated_at=dt.datetime.now(dt.timezone.utc), + ) + session.add(c) + await session.commit() + return ConnectionOut(db_connection_uuid=c.db_connection_uuid, conn_name=c.conn_name) diff --git a/backend/build/lib/app/api/me.py b/backend/build/lib/app/api/me.py new file mode 100644 index 00000000..45810a8b --- /dev/null +++ b/backend/build/lib/app/api/me.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends + +from app.auth import CurrentUser, get_current_user +from app.schemas import MeOut + +router = APIRouter(prefix="/api", tags=["me"]) + + +@router.get("/me", response_model=MeOut) +async def get_me(user: CurrentUser = Depends(get_current_user)) -> MeOut: + """Return the current user's identity.""" + return MeOut( + user_account_uuid=user.user_account_uuid, + subject=user.subject, + display_name=user.display_name, + ) diff --git a/backend/build/lib/app/api/projects.py b/backend/build/lib/app/api/projects.py new file mode 100644 index 
00000000..4d2a1c90 --- /dev/null +++ b/backend/build/lib/app/api/projects.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import datetime as dt +import uuid + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import CurrentUser, get_current_user +from app.db import get_read_session, get_session +from app.permissions import require_project_member +from app.models import ProjectMember, ProjectSpace, UserAccount +from app.schemas import ( + ProjectCreateIn, + ProjectMemberAddIn, + ProjectMemberOut, + ProjectOut, +) +from app.sanitize import sanitize_for_storage + +router = APIRouter(prefix="/api/projects", tags=["projects"]) + + +@router.get("", response_model=list[ProjectOut]) +async def list_projects( + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> list[ProjectOut]: + """List projects that the current user is a member of.""" + rows = await session.execute( + select(ProjectSpace) + .join( + ProjectMember, + ProjectMember.project_space_uuid == ProjectSpace.project_space_uuid, + ) + .where(ProjectMember.user_account_uuid == user.user_account_uuid) + .order_by(ProjectSpace.created_at.desc()) + ) + projects = rows.scalars().all() + return [ + ProjectOut(project_space_uuid=p.project_space_uuid, project_name=p.project_name) + for p in projects + ] + + +@router.post("", response_model=ProjectOut) +async def create_project( + body: ProjectCreateIn, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +) -> ProjectOut: + """Create a new project and add the creator as the owner.""" + p = ProjectSpace( + project_space_uuid=uuid.uuid4(), + project_name=str(sanitize_for_storage(body.project_name)), + created_by_user_uuid=user.user_account_uuid, + created_at=dt.datetime.now(dt.timezone.utc), + ) + session.add(p) + 
+ await session.flush() # ensure project_space row exists before FK insert + + m = ProjectMember( + project_space_uuid=p.project_space_uuid, + user_account_uuid=user.user_account_uuid, + project_role="owner", + created_at=dt.datetime.now(dt.timezone.utc), + ) + session.add(m) + await session.commit() + return ProjectOut( + project_space_uuid=p.project_space_uuid, project_name=p.project_name + ) + + +@router.get("/{project_space_uuid}/members", response_model=list[ProjectMemberOut]) +async def list_project_members( + project_space_uuid: uuid.UUID, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> list[ProjectMemberOut]: + """List members of a project (restricted to owners and editors).""" + # Remediation for IDOR: Only owners or editors can view all members. + await require_project_member( + session, project_space_uuid, user.user_account_uuid, minimum_role="editor" + ) + + rows = await session.execute( + select(ProjectMember, UserAccount) + .join( + UserAccount, + UserAccount.user_account_uuid == ProjectMember.user_account_uuid, + ) + .where(ProjectMember.project_space_uuid == project_space_uuid) + .order_by(ProjectMember.created_at.asc()) + ) + out: list[ProjectMemberOut] = [] + for m, u in rows.all(): + out.append( + ProjectMemberOut( + user_account_uuid=u.user_account_uuid, + member_subject=u.oidc_subject, + project_role=m.project_role, + ) + ) + return out + + +async def _ensure_owner( + session: AsyncSession, project_space_uuid: uuid.UUID, user_account_uuid: uuid.UUID +) -> None: + row = await session.execute( + select(ProjectMember.project_role).where( + ProjectMember.project_space_uuid == project_space_uuid, + ProjectMember.user_account_uuid == user_account_uuid, + ) + ) + role = row.scalar_one_or_none() + if role != "owner": + raise HTTPException(status_code=403, detail="owner role required") + + +async def _ensure_user_exists(session: AsyncSession, subject: str) -> UserAccount: + row2 = await session.execute( 
+ select(UserAccount).where(UserAccount.oidc_subject == subject) + ) + u = row2.scalars().first() + if u is None: + u = UserAccount( + user_account_uuid=uuid.uuid4(), + oidc_subject=subject, + display_name=None, + created_at=dt.datetime.now(dt.timezone.utc), + ) + session.add(u) + await session.flush() + return u + + +async def _ensure_not_changing_owner_role( + session: AsyncSession, project_space_uuid: uuid.UUID, user_account_uuid: uuid.UUID +) -> None: + row3 = await session.execute( + select(ProjectMember.project_role).where( + ProjectMember.project_space_uuid == project_space_uuid, + ProjectMember.user_account_uuid == user_account_uuid, + ) + ) + existing_role = row3.scalar_one_or_none() + if existing_role == "owner": + raise HTTPException( + status_code=400, + detail="cannot change owner role via invite endpoint", + ) + + +async def _upsert_project_member( + session: AsyncSession, + project_space_uuid: uuid.UUID, + user_account_uuid: uuid.UUID, + project_role: str, +) -> str: + stmt = ( + insert(ProjectMember) + .values( + project_space_uuid=project_space_uuid, + user_account_uuid=user_account_uuid, + project_role=project_role, + created_at=dt.datetime.now(dt.timezone.utc), + ) + .on_conflict_do_update( + index_elements=[ + ProjectMember.project_space_uuid, + ProjectMember.user_account_uuid, + ], + set_={"project_role": project_role}, + ) + .returning(ProjectMember.project_role) + ) + new_role = (await session.execute(stmt)).scalar_one() + await session.commit() + return str(new_role) + + +@router.post("/{project_space_uuid}/members", response_model=ProjectMemberOut) +async def add_project_member( + project_space_uuid: uuid.UUID, + body: ProjectMemberAddIn, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +) -> ProjectMemberOut: + """Invite/add a project member (owner-only). + + Uses a Postgres upsert to make the operation idempotent and race-safe. 
+ """ + await _ensure_owner(session, project_space_uuid, user.user_account_uuid) + + subject = str(sanitize_for_storage(body.member_subject)).strip() + if not subject: + raise HTTPException(status_code=400, detail="member_subject required") + + u = await _ensure_user_exists(session, subject) + await _ensure_not_changing_owner_role( + session, project_space_uuid, u.user_account_uuid + ) + + new_role = await _upsert_project_member( + session, project_space_uuid, u.user_account_uuid, body.project_role + ) + + return ProjectMemberOut( + user_account_uuid=u.user_account_uuid, + member_subject=u.oidc_subject, + project_role=new_role, + ) diff --git a/backend/build/lib/app/api/share.py b/backend/build/lib/app/api/share.py new file mode 100644 index 00000000..18685235 --- /dev/null +++ b/backend/build/lib/app/api/share.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import datetime as dt +import uuid + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import PlainTextResponse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import CurrentUser, get_current_user +from app.db import get_read_session, get_session +from app.ddl.export import snapshot_json_to_sql +from app.models import ( + ProjectMember, + SchemaSnapshot, + SchemaSnapshotData, + ShareLink, +) +from app.redact import redact_sensitive_schema_data +from app.spec.llm import ( + LlmConfigurationError, + LlmProviderError, + generate_index_design_llm_draft, + generate_reversing_llm_draft, +) +from app.spec.index_design import generate_index_design_spec +from app.spec.reversing import generate_reversing_spec + +router = APIRouter(prefix="/api", tags=["share"]) + + +@router.post("/projects/{project_space_uuid}/share-links") +async def create_share_link( + project_space_uuid: uuid.UUID, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +) -> dict: + """Create a share link for a 
project (owner-only).""" + # owner only + row = await session.execute( + select(ProjectMember.project_role).where( + ProjectMember.project_space_uuid == project_space_uuid, + ProjectMember.user_account_uuid == user.user_account_uuid, + ) + ) + if row.scalar_one_or_none() != "owner": + raise HTTPException(status_code=403, detail="owner role required") + + link = ShareLink( + share_link_uuid=uuid.uuid4(), + project_space_uuid=project_space_uuid, + created_by_user_uuid=user.user_account_uuid, + permission_kind="viewer", + expires_at=None, + created_at=dt.datetime.now(dt.timezone.utc), + ) + session.add(link) + await session.commit() + return { + "share_link_uuid": str(link.share_link_uuid), + "permission_kind": link.permission_kind, + "url_path": f"/api/share/{link.share_link_uuid}", + } + + +@router.get("/share/{share_link_uuid}") +async def get_share_link_info( + share_link_uuid: uuid.UUID, + session: AsyncSession = Depends(get_read_session), +) -> dict: + """Return share link metadata and recent snapshots.""" + link = await session.get(ShareLink, share_link_uuid) + if link is None: + raise HTTPException(status_code=404, detail="share link not found") + if link.expires_at is not None and link.expires_at <= dt.datetime.now( + dt.timezone.utc + ): + raise HTTPException(status_code=410, detail="share link expired") + + rows = await session.execute( + select(SchemaSnapshot) + .where(SchemaSnapshot.project_space_uuid == link.project_space_uuid) + .order_by(SchemaSnapshot.created_at.desc()) + .limit(20) + ) + snaps = rows.scalars().all() + return { + "project_space_uuid": str(link.project_space_uuid), + "permission_kind": link.permission_kind, + "snapshots": [ + { + "schema_snapshot_uuid": str(s.schema_snapshot_uuid), + "status": s.status, + "schema_filter": s.schema_filter, + "created_at": s.created_at.isoformat(), + } + for s in snaps + ], + } + + +@router.get("/share/{share_link_uuid}/snapshots/{schema_snapshot_uuid}") +async def get_shared_snapshot( + share_link_uuid: 
uuid.UUID, + schema_snapshot_uuid: uuid.UUID, + session: AsyncSession = Depends(get_read_session), +) -> dict: + """Return a snapshot via a share link (no auth).""" + link = await session.get(ShareLink, share_link_uuid) + if link is None: + raise HTTPException(status_code=404, detail="share link not found") + if link.expires_at is not None and link.expires_at <= dt.datetime.now( + dt.timezone.utc + ): + raise HTTPException(status_code=410, detail="share link expired") + + snap = await session.get(SchemaSnapshot, schema_snapshot_uuid) + if snap is None or snap.project_space_uuid != link.project_space_uuid: + raise HTTPException(status_code=404, detail="snapshot not found") + + data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) + return { + "schema_snapshot_uuid": str(snap.schema_snapshot_uuid), + "status": snap.status, + "schema_filter": snap.schema_filter, + "error_message": snap.error_message, + "snapshot_json": redact_sensitive_schema_data(data.snapshot_json) if data else None, + } + + +@router.get( + "/share/{share_link_uuid}/snapshots/{schema_snapshot_uuid}/export.sql", + response_class=PlainTextResponse, +) +async def export_shared_snapshot_sql( + share_link_uuid: uuid.UUID, + schema_snapshot_uuid: uuid.UUID, + dialect: str = Query("postgresql", pattern="^(postgresql|snowflake)$"), + session: AsyncSession = Depends(get_read_session), +) -> str: + """Export a shared snapshot as SQL via a share link.""" + link = await session.get(ShareLink, share_link_uuid) + if link is None: + raise HTTPException(status_code=404, detail="share link not found") + if link.expires_at is not None and link.expires_at <= dt.datetime.now( + dt.timezone.utc + ): + raise HTTPException(status_code=410, detail="share link expired") + + snap = await session.get(SchemaSnapshot, schema_snapshot_uuid) + if snap is None or snap.project_space_uuid != link.project_space_uuid: + raise HTTPException(status_code=404, detail="snapshot not found") + + data = await 
session.get(SchemaSnapshotData, schema_snapshot_uuid) + if data is None: + return "-- snapshot data not found\n" + + redacted_json = redact_sensitive_schema_data(data.snapshot_json) + return snapshot_json_to_sql(redacted_json, target_dialect=dialect) + + +@router.get( + "/share/{share_link_uuid}/snapshots/{schema_snapshot_uuid}/reversing-spec.md", + response_class=PlainTextResponse, +) +async def export_shared_snapshot_reversing_spec( + share_link_uuid: uuid.UUID, + schema_snapshot_uuid: uuid.UUID, + mode: str = Query("markdown", pattern="^(markdown|llm-prompt|llm-draft)$"), + session: AsyncSession = Depends(get_read_session), +) -> str: + """Export a shared snapshot as a DB reversing spec or LLM prompt.""" + link = await session.get(ShareLink, share_link_uuid) + if link is None: + raise HTTPException(status_code=404, detail="share link not found") + if link.expires_at is not None and link.expires_at <= dt.datetime.now( + dt.timezone.utc + ): + raise HTTPException(status_code=410, detail="share link expired") + + snap = await session.get(SchemaSnapshot, schema_snapshot_uuid) + if snap is None or snap.project_space_uuid != link.project_space_uuid: + raise HTTPException(status_code=404, detail="snapshot not found") + + data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) + if data is None: + return "# DB Reversing Specification\n\nSnapshot data not found.\n" + + redacted_json = redact_sensitive_schema_data(data.snapshot_json) + if mode == "llm-draft": + try: + return await generate_reversing_llm_draft(redacted_json) + except LlmConfigurationError as exc: + raise HTTPException( + status_code=503, detail="LLM configuration error" + ) from exc + except LlmProviderError as exc: + raise HTTPException( + status_code=502, detail="LLM provider request failed" + ) from exc + return generate_reversing_spec(redacted_json, mode=mode) + + +@router.get( + "/share/{share_link_uuid}/snapshots/{schema_snapshot_uuid}/index-design.md", + response_class=PlainTextResponse, 
+) +async def export_shared_snapshot_index_design( + share_link_uuid: uuid.UUID, + schema_snapshot_uuid: uuid.UUID, + mode: str = Query("markdown", pattern="^(markdown|llm-prompt|llm-draft)$"), + session: AsyncSession = Depends(get_read_session), +) -> str: + """Export shared table/index design guidance or an LLM prompt.""" + link = await session.get(ShareLink, share_link_uuid) + if link is None: + raise HTTPException(status_code=404, detail="share link not found") + if link.expires_at is not None and link.expires_at <= dt.datetime.now( + dt.timezone.utc + ): + raise HTTPException(status_code=410, detail="share link expired") + + snap = await session.get(SchemaSnapshot, schema_snapshot_uuid) + if snap is None or snap.project_space_uuid != link.project_space_uuid: + raise HTTPException(status_code=404, detail="snapshot not found") + + data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) + if data is None: + return "# ERD Index Design\n\nSnapshot data not found.\n" + + redacted_json = redact_sensitive_schema_data(data.snapshot_json) + if mode == "llm-draft": + try: + return await generate_index_design_llm_draft(redacted_json) + except LlmConfigurationError as exc: + raise HTTPException( + status_code=503, detail="LLM configuration error" + ) from exc + except LlmProviderError as exc: + raise HTTPException( + status_code=502, detail="LLM provider request failed" + ) from exc + return generate_index_design_spec(redacted_json, mode=mode) diff --git a/backend/build/lib/app/api/snapshots.py b/backend/build/lib/app/api/snapshots.py new file mode 100644 index 00000000..79f2515f --- /dev/null +++ b/backend/build/lib/app/api/snapshots.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import datetime as dt +import uuid + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import PlainTextResponse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth import CurrentUser, 
get_current_user +from app.db import get_read_session, get_session +from app.models import ( + DbConnection, + JobQueue, + SchemaSnapshot, + SchemaSnapshotData, +) +from app.permissions import require_project_member +from app.schemas import SnapshotCreateIn, SnapshotDetailOut, SnapshotOut +from app.ddl.export import snapshot_json_to_sql +from app.jobs.valkey_queue import enqueue_job_signal +from app.spec.llm import ( + LlmConfigurationError, + LlmProviderError, + generate_index_design_llm_draft, + generate_reversing_llm_draft, +) +from app.spec.index_design import generate_index_design_spec +from app.spec.reversing import generate_reversing_spec + +router = APIRouter(prefix="/api/snapshots", tags=["snapshots"]) + + +def _snapshot_not_found(schema_snapshot_uuid: uuid.UUID) -> SnapshotDetailOut: + """Return the uniform snapshot-not-found response.""" + + return SnapshotDetailOut( + schema_snapshot_uuid=schema_snapshot_uuid, + status="not_found", + schema_filter=None, + error_message="snapshot not found", + snapshot_json=None, + ) + + +async def _get_authorized_snapshot( + session: AsyncSession, + schema_snapshot_uuid: uuid.UUID, + user: CurrentUser, +) -> SchemaSnapshot | None: + """Fetch a snapshot only after project membership has been checked.""" + + project_space_uuid = await session.scalar( + select(SchemaSnapshot.project_space_uuid).where( + SchemaSnapshot.schema_snapshot_uuid == schema_snapshot_uuid + ) + ) + if project_space_uuid is None: + return None + + try: + await require_project_member( + session, project_space_uuid, user.user_account_uuid + ) + except HTTPException as exc: + if exc.status_code == 403: + return None + raise + + return await session.get(SchemaSnapshot, schema_snapshot_uuid) + + +@router.post("/by-project/{project_space_uuid}", response_model=SnapshotOut) +async def create_snapshot( + project_space_uuid: uuid.UUID, + body: SnapshotCreateIn, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +) 
-> SnapshotOut: + """Create a schema snapshot job for a project connection.""" + await require_project_member( + session, project_space_uuid, user.user_account_uuid, minimum_role="editor" + ) + + # Ensure connection belongs to this project + conn = await session.get(DbConnection, body.db_connection_uuid) + if conn is None or conn.project_space_uuid != project_space_uuid: + raise HTTPException(status_code=404, detail="connection not found") + + snap = SchemaSnapshot( + schema_snapshot_uuid=uuid.uuid4(), + project_space_uuid=project_space_uuid, + db_connection_uuid=body.db_connection_uuid, + status="queued", + schema_filter=body.schema_filter, + started_at=None, + finished_at=None, + error_message=None, + created_at=dt.datetime.now(dt.timezone.utc), + ) + session.add(snap) + + job = JobQueue( + job_queue_uuid=uuid.uuid4(), + job_type="snapshot", + status="queued", + payload_json={"schema_snapshot_uuid": str(snap.schema_snapshot_uuid)}, + run_after=dt.datetime.now(dt.timezone.utc), + attempt_count=0, + last_error=None, + created_at=dt.datetime.now(dt.timezone.utc), + started_at=None, + finished_at=None, + ) + session.add(job) + + await session.commit() + await enqueue_job_signal(job.job_queue_uuid, job.run_after) + return SnapshotOut( + schema_snapshot_uuid=snap.schema_snapshot_uuid, + status=snap.status, + schema_filter=snap.schema_filter, + ) + + +@router.get("/{schema_snapshot_uuid}", response_model=SnapshotDetailOut) +async def get_snapshot( + schema_snapshot_uuid: uuid.UUID, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> SnapshotDetailOut: + """Get a snapshot's status and (if present) captured JSON.""" + snap = await _get_authorized_snapshot(session, schema_snapshot_uuid, user) + if snap is None: + return _snapshot_not_found(schema_snapshot_uuid) + data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) + return SnapshotDetailOut( + schema_snapshot_uuid=snap.schema_snapshot_uuid, + 
status=snap.status, + schema_filter=snap.schema_filter, + error_message=snap.error_message, + snapshot_json=data.snapshot_json if data else None, + ) + + +@router.get("/{schema_snapshot_uuid}/export.sql", response_class=PlainTextResponse) +async def export_snapshot_sql( + schema_snapshot_uuid: uuid.UUID, + dialect: str = Query("postgresql", pattern="^(postgresql|snowflake)$"), + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> str: + """Export a snapshot as dialect-specific SQL DDL (best-effort).""" + snap = await _get_authorized_snapshot(session, schema_snapshot_uuid, user) + if snap is None: + return "-- snapshot not found\n" + data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) + if data is None: + return "-- snapshot data not found\n" + return snapshot_json_to_sql(data.snapshot_json, target_dialect=dialect) + + +@router.get( + "/{schema_snapshot_uuid}/reversing-spec.md", + response_class=PlainTextResponse, +) +async def export_snapshot_reversing_spec( + schema_snapshot_uuid: uuid.UUID, + mode: str = Query("markdown", pattern="^(markdown|llm-prompt|llm-draft)$"), + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> str: + """Export a snapshot as a DB reversing spec or LLM prompt.""" + snap = await _get_authorized_snapshot(session, schema_snapshot_uuid, user) + if snap is None: + return "# DB Reversing Specification\n\nSnapshot not found.\n" + data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) + if data is None: + return "# DB Reversing Specification\n\nSnapshot data not found.\n" + if mode == "llm-draft": + try: + return await generate_reversing_llm_draft(data.snapshot_json) + except LlmConfigurationError as exc: + raise HTTPException( + status_code=503, detail="LLM configuration error" + ) from exc + except LlmProviderError as exc: + raise HTTPException( + status_code=502, detail="LLM provider request failed" + ) from 
exc + return generate_reversing_spec(data.snapshot_json, mode=mode) + + +@router.get( + "/{schema_snapshot_uuid}/index-design.md", + response_class=PlainTextResponse, +) +async def export_snapshot_index_design( + schema_snapshot_uuid: uuid.UUID, + mode: str = Query("markdown", pattern="^(markdown|llm-prompt|llm-draft)$"), + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> str: + """Export table/index design guidance or an LLM prompt.""" + snap = await _get_authorized_snapshot(session, schema_snapshot_uuid, user) + if snap is None: + return "# ERD Index Design\n\nSnapshot not found.\n" + data = await session.get(SchemaSnapshotData, schema_snapshot_uuid) + if data is None: + return "# ERD Index Design\n\nSnapshot data not found.\n" + if mode == "llm-draft": + try: + return await generate_index_design_llm_draft(data.snapshot_json) + except LlmConfigurationError as exc: + raise HTTPException( + status_code=503, detail="LLM configuration error" + ) from exc + except LlmProviderError as exc: + raise HTTPException( + status_code=502, detail="LLM provider request failed" + ) from exc + return generate_index_design_spec(data.snapshot_json, mode=mode) + + +@router.get("/by-project/{project_space_uuid}", response_model=list[SnapshotOut]) +async def list_snapshots( + project_space_uuid: uuid.UUID, + user: CurrentUser = Depends(get_current_user), + session: AsyncSession = Depends(get_read_session), +) -> list[SnapshotOut]: + """List snapshots for a project.""" + await require_project_member(session, project_space_uuid, user.user_account_uuid) + rows = await session.execute( + select(SchemaSnapshot) + .where(SchemaSnapshot.project_space_uuid == project_space_uuid) + .order_by(SchemaSnapshot.created_at.desc()) + ) + snaps = rows.scalars().all() + return [ + SnapshotOut( + schema_snapshot_uuid=s.schema_snapshot_uuid, + status=s.status, + schema_filter=s.schema_filter, + ) + for s in snaps + ] diff --git 
a/backend/build/lib/app/auth.py b/backend/build/lib/app/auth.py new file mode 100644 index 00000000..132e57f3 --- /dev/null +++ b/backend/build/lib/app/auth.py @@ -0,0 +1,423 @@ +from __future__ import annotations + +import asyncio +import datetime as dt +import uuid +from dataclasses import dataclass +from typing import Any, cast + +import httpx +from fastapi import Depends, HTTPException, Request +from jose import jwt +from sqlalchemy import select, delete +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_session +from app.models import UserAccount +from app.settings import settings + + +def _parse_oidc_algorithms(raw: str) -> list[str]: + """Parse OIDC_ALGORITHMS into a non-empty allowlist. + + Security note: + - JWT verification must *not* trust the token header's `alg`. + - We pass an explicit allowlist to the verifier. + """ + + # Normalize and deduplicate so env values like "rs256, RS256" behave + # predictably. + normalized: list[str] = [] + seen: set[str] = set() + for part in raw.split(","): + alg = part.strip().upper() + if not alg: + continue + # Defensive: never allow unsigned tokens. + if alg == "NONE": + continue + # Defensive: never allow symmetric algorithms to prevent public key HMAC forgery. 
+ if alg.startswith("HS"): + continue + if alg in seen: + continue + seen.add(alg) + normalized.append(alg) + + return normalized or ["RS256"] + + +@dataclass(frozen=True) +class CurrentUser: + """Authenticated user identity used by API handlers.""" + + user_account_uuid: uuid.UUID + subject: str + display_name: str | None + + +@dataclass(frozen=True) +class VerifiedToken: + """Verified OIDC token details needed by auth and logout flows.""" + + subject: str + display_name: str | None + jwt_id: str + expires_at: dt.datetime + + +_oidc_config: dict[str, Any] | None = None +_oidc_jwks: dict[str, Any] | None = None +_oidc_config_expires_at: dt.datetime = dt.datetime.fromtimestamp(0, tz=dt.timezone.utc) +_oidc_jwks_expires_at: dt.datetime = dt.datetime.fromtimestamp(0, tz=dt.timezone.utc) +OIDC_ALLOWED_ALGORITHMS = tuple(_parse_oidc_algorithms(settings.oidc_algorithms)) +OIDC_CONFIG_CACHE_TTL = dt.timedelta(minutes=10) +OIDC_JWKS_CACHE_TTL = dt.timedelta(minutes=5) +OIDC_JWKS_MIN_REFRESH_INTERVAL = dt.timedelta(seconds=60) +_last_jwks_refresh_at: dt.datetime = dt.datetime.fromtimestamp(0, tz=dt.timezone.utc) +_jwks_lock = asyncio.Lock() +OIDC_JWT_LEEWAY_SECONDS = 60 +OIDC_ALLOWED_TOKEN_TYPES = {"jwt", "at+jwt"} + + +async def _get_oidc_config() -> dict: + """Fetch and cache the OIDC discovery document.""" + if not settings.oidc_issuer: + raise RuntimeError("OIDC is disabled") + + global _oidc_config, _oidc_config_expires_at + now = dt.datetime.now(dt.timezone.utc) + if _oidc_config is not None and now < _oidc_config_expires_at: + return cast(dict, _oidc_config) + + async with httpx.AsyncClient(timeout=5, follow_redirects=False) as client: + r = await client.get( + f"{settings.oidc_issuer.rstrip('/')}/.well-known/openid-configuration" + ) + if r.is_redirect: + raise RuntimeError("OIDC configuration endpoint must not redirect") + r.raise_for_status() + config = cast(dict[str, Any], r.json()) + + _oidc_config = config + _oidc_config_expires_at = now + OIDC_CONFIG_CACHE_TTL 
+ return cast(dict, config) + + +async def _get_jwks(force_refresh: bool = False) -> dict: + """Fetch and cache the OIDC JWKS (signing keys).""" + config = await _get_oidc_config() + jwks_uri = config.get("jwks_uri") + if not isinstance(jwks_uri, str): + raise RuntimeError("OIDC jwks_uri missing") + + now = dt.datetime.now(dt.timezone.utc) + + if _oidc_jwks is not None: + if not force_refresh and now < _oidc_jwks_expires_at: + return cast(dict, _oidc_jwks) + if ( + force_refresh + and now < _last_jwks_refresh_at + OIDC_JWKS_MIN_REFRESH_INTERVAL + ): + return cast(dict, _oidc_jwks) + + async with _jwks_lock: + now = dt.datetime.now(dt.timezone.utc) + if _oidc_jwks is not None: + if not force_refresh and now < _oidc_jwks_expires_at: + return cast(dict, _oidc_jwks) + if ( + force_refresh + and now < _last_jwks_refresh_at + OIDC_JWKS_MIN_REFRESH_INTERVAL + ): + return cast(dict, _oidc_jwks) + + async with httpx.AsyncClient(timeout=5, follow_redirects=False) as client: + r = await client.get(jwks_uri) + if r.is_redirect: + raise RuntimeError("OIDC JWKS endpoint must not redirect") + r.raise_for_status() + jwks = cast(dict[str, Any], r.json()) + + refreshed_at = dt.datetime.now(dt.timezone.utc) + globals()["_oidc_jwks"] = jwks + globals()["_oidc_jwks_expires_at"] = refreshed_at + OIDC_JWKS_CACHE_TTL + globals()["_last_jwks_refresh_at"] = refreshed_at + return cast(dict, jwks) + + +def _pick_jwk(jwks: dict, kid: str | None) -> dict | None: + """Pick a JWK from a JWKS set by kid (or first if kid is None).""" + keys = jwks.get("keys") + if not isinstance(keys, list): + return None + for k in keys: + if not isinstance(k, dict): + continue + if kid is None or k.get("kid") == kid: + return k + return None + + +def _jwt_expiry(claims: dict[str, Any]) -> dt.datetime: + """Return the JWT expiry as an aware UTC datetime.""" + + exp = claims.get("exp") + if not isinstance(exp, int | float): + raise HTTPException(status_code=401, detail="token missing exp") + return 
def _validate_jwt_header(header: dict[str, Any]) -> str:
    """Check JOSE header fields ahead of signature verification.

    Returns the header's ``alg`` upper-cased. Raises a 401 HTTPException when
    ``typ`` is present but not an allowed token type, when ``cty`` is present
    at all, or when ``alg`` is missing, empty, or not a string.
    """
    typ = header.get("typ")
    typ_acceptable = typ is None or (
        isinstance(typ, str) and typ.strip().lower() in OIDC_ALLOWED_TOKEN_TYPES
    )
    if not typ_acceptable:
        raise HTTPException(status_code=401, detail="unsupported token type")

    if header.get("cty") is not None:
        raise HTTPException(status_code=401, detail="unsupported token content type")

    alg = header.get("alg")
    if isinstance(alg, str) and alg:
        return alg.upper()
    raise HTTPException(status_code=401, detail="token missing alg")
async def _decode_verified_oidc_token(token: str) -> dict[str, Any]:
    """Validate the JOSE header, signing key, and claims signature.

    Steps, in order:
    1. Parse the unverified header and validate it; the header ``alg`` must
       be in ``OIDC_ALLOWED_ALGORITHMS``.
    2. Look up the signing key by ``kid``; on a miss, force one JWKS refresh
       and retry before giving up.
    3. Require the key type to agree with the algorithm family
       (RSA -> RS*/PS*, EC -> ES*); any other key type is rejected.
    4. Verify signature and registered claims with the configured allowlist.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 401 for every validation failure. The underlying
        error is chained as ``__cause__`` for server-side diagnostics; the
        client-visible detail stays generic.
    """

    try:
        header = cast(dict[str, Any], jwt.get_unverified_header(token))
    except Exception as err:  # noqa: BLE001
        # Chain the cause, matching the verification handler below.
        raise HTTPException(status_code=401, detail="invalid token header") from err

    header_alg = _validate_jwt_header(header)
    if header_alg not in OIDC_ALLOWED_ALGORITHMS:
        raise HTTPException(
            status_code=401,
            detail="unsupported token algorithm",
        )

    kid = header.get("kid")
    jwks = await _get_jwks()
    jwk = _pick_jwk(jwks, kid)
    if jwk is None:
        # The key may have been rotated since the cached JWKS was fetched.
        jwks = await _get_jwks(force_refresh=True)
        jwk = _pick_jwk(jwks, kid)
    if jwk is None:
        raise HTTPException(status_code=401, detail="unknown signing key")

    kty = jwk.get("kty")
    if not isinstance(kty, str):
        raise HTTPException(status_code=401, detail="algorithm/key type mismatch")
    jwk_kty = kty.upper()
    if jwk_kty == "RSA":
        if not (header_alg.startswith("RS") or header_alg.startswith("PS")):
            raise HTTPException(status_code=401, detail="algorithm/key type mismatch")
    elif jwk_kty == "EC":
        if not header_alg.startswith("ES"):
            raise HTTPException(status_code=401, detail="algorithm/key type mismatch")
    else:
        raise HTTPException(status_code=401, detail="algorithm/key type mismatch")

    try:
        claims = jwt.decode(
            token,
            jwk,
            # Explicit allowlist: never derive the algorithm from the token.
            algorithms=list(OIDC_ALLOWED_ALGORITHMS),
            audience=settings.oidc_audience,
            issuer=settings.oidc_issuer,
            options={
                "verify_aud": bool(settings.oidc_audience),
                "require_aud": bool(settings.oidc_audience),
                "require_iss": True,
                "require_exp": True,
                "require_jti": True,
                "leeway": OIDC_JWT_LEEWAY_SECONDS,
            },
        )
    except Exception as err:
        raise HTTPException(
            status_code=401, detail="token verification failed"
        ) from err

    return cast(dict[str, Any], claims)
async def _verified_token_from_claims(
    claims: dict[str, Any], verify_revocation: bool = True
) -> VerifiedToken:
    """Build the request identity from already-decoded token claims.

    Checks run in this order, each failing with a 401: ``sub`` must be a
    string, ``jti`` must be a non-blank string, ``exp`` must be usable (via
    ``_jwt_expiry``), and, when *verify_revocation* is true, the ``jti`` must
    not be revoked. The display name comes from ``name`` or, failing that,
    ``preferred_username``; non-string values become None.
    """

    subject = claims.get("sub")
    if not isinstance(subject, str):
        raise HTTPException(status_code=401, detail="token missing sub")

    token_id = claims.get("jti")
    if not (isinstance(token_id, str) and token_id.strip()):
        raise HTTPException(status_code=401, detail="token missing jti")

    expiry = _jwt_expiry(claims)
    if verify_revocation:
        if await is_token_jti_revoked(token_id):
            raise HTTPException(status_code=401, detail="token revoked")

    label = claims.get("name") or claims.get("preferred_username")
    return VerifiedToken(
        subject=subject,
        display_name=str(label) if isinstance(label, str) else None,
        jwt_id=token_id,
        expires_at=expiry,
    )
async def _ensure_user(
    session: AsyncSession, subject: str, display_name: str | None
) -> CurrentUser:
    """Get or create a UserAccount for the given OIDC subject.

    Results are memoized in the module-level ``_user_cache`` for
    ``USER_CACHE_TTL``; a fresh cache hit returns without touching the
    database. On a miss the account is looked up by ``oidc_subject`` and, if
    absent, inserted and flushed (not committed) on *session*.

    Args:
        session: Session used for the lookup and the optional insert.
        subject: OIDC subject identifying the user.
        display_name: Name stored only when a new account is created; an
            existing account's stored display name is returned unchanged.

    Returns:
        The CurrentUser for *subject*.
    """
    now = dt.datetime.now(dt.timezone.utc)
    cached = _user_cache.get(subject)
    if cached is not None:
        user, expires_at = cached
        if now < expires_at:
            return user
        else:
            # Entry is stale: drop it and fall through to the database.
            del _user_cache[subject]

    row = await session.execute(
        select(UserAccount).where(UserAccount.oidc_subject == subject)
    )
    existing = row.scalars().first()
    if existing is not None:
        user = CurrentUser(
            user_account_uuid=existing.user_account_uuid,
            subject=existing.oidc_subject,
            display_name=existing.display_name,
        )
    else:
        user_account = UserAccount(
            user_account_uuid=uuid.uuid4(),
            oidc_subject=subject,
            display_name=display_name,
            created_at=now,
        )
        session.add(user_account)
        # Flush only; committing is left to whoever owns the transaction.
        await session.flush()
        user = CurrentUser(
            user_account_uuid=user_account.user_account_uuid,
            subject=user_account.oidc_subject,
            display_name=user_account.display_name,
        )

    # Crude bound on memory: once the cache is full, discard everything
    # rather than evicting individual entries.
    if len(_user_cache) >= USER_CACHE_MAX_SIZE:
        _user_cache.clear()

    # NOTE(review): the entry is cached before the surrounding transaction
    # commits; if that commit fails for a newly inserted account, the cache
    # would presumably keep serving it until the TTL expires — confirm.
    _user_cache[subject] = (user, now + USER_CACHE_TTL)
    return user
CSRF_TOKEN_TTL_SECONDS = 12 * 60 * 60
CSRF_TOKEN_NONCE_BYTES = 16


def _base64url_encode(value: bytes) -> str:
    """Encode *value* as unpadded URL-safe base64 text."""
    return base64.urlsafe_b64encode(value).decode("ascii").rstrip("=")


def _base64url_decode(value: str) -> bytes:
    """Decode unpadded URL-safe base64 text, restoring the stripped padding."""
    padding = "=" * (-len(value) % 4)
    return base64.urlsafe_b64decode(f"{value}{padding}")


def _csrf_signature(secret: str, issued_at: str, nonce: str) -> bytes:
    """Return the HMAC-SHA256 of ``"<issued_at>.<nonce>"`` keyed by *secret*.

    *issued_at* and *nonce* must be ASCII; the message is encoded as ASCII.
    """
    message = f"{issued_at}.{nonce}".encode("ascii")
    return hmac.new(
        secret.encode("utf-8"),
        message,
        hashlib.sha256,
    ).digest()


def generate_csrf_token(secret: str, now: int | None = None) -> str:
    """Create a short-lived token signed with the application secret.

    The token has the form ``<issued_at>.<nonce>.<signature>`` where the
    nonce is random and the signature covers the first two fields. *now*
    overrides the current Unix time (used as the issue time).
    """
    issued_at = str(int(time.time() if now is None else now))
    nonce = _base64url_encode(secrets.token_bytes(CSRF_TOKEN_NONCE_BYTES))
    signature = _base64url_encode(_csrf_signature(secret, issued_at, nonce))
    return f"{issued_at}.{nonce}.{signature}"


def verify_csrf_token(
    token: str,
    secret: str,
    ttl_seconds: int = CSRF_TOKEN_TTL_SECONDS,
    now: int | None = None,
) -> bool:
    """Validate a server-issued CSRF token without accepting caller entropy.

    Args:
        token: Candidate ``<issued_at>.<nonce>.<signature>`` string.
        secret: Application secret the token must have been signed with.
        ttl_seconds: Maximum accepted token age in seconds.
        now: Overrides the current Unix time.

    Returns:
        True only for a well-formed, unexpired, not-future-dated token whose
        signature matches. Malformed input of any kind yields False; this
        function does not raise on bad tokens.
    """
    # Server-issued tokens are pure ASCII. Rejecting anything else up front
    # keeps _csrf_signature's ASCII encoding from raising on attacker-supplied
    # header values (int() accepts non-ASCII digits, so the later parse alone
    # would not stop them).
    if not token.isascii():
        return False

    parts = token.split(".")
    if len(parts) != 3:
        return False

    issued_at, nonce, signature = parts
    current_time = int(time.time() if now is None else now)
    try:
        issued_at_int = int(issued_at)
        signature_bytes = _base64url_decode(signature)
    except (binascii.Error, ValueError):
        return False

    # Reject both expired tokens and tokens claiming to come from the future.
    if current_time - issued_at_int > ttl_seconds or issued_at_int > current_time:
        return False

    expected = _csrf_signature(secret, issued_at, nonce)
    # Constant-time comparison to avoid leaking signature bytes via timing.
    return hmac.compare_digest(signature_bytes, expected)
def get_sync_database_url() -> str:
    """Return the configured database URL in a form Alembic can use.

    Alembic runs on a synchronous engine, so an asyncpg URL is rewritten to
    the psycopg (v3) driver; any other URL is returned untouched.
    """
    async_prefix = "postgresql+asyncpg://"
    configured = settings.database_url
    if not configured.startswith(async_prefix):
        return configured
    # Swap only the leading driver prefix; the rest of the URL is kept as-is.
    return "postgresql+psycopg://" + configured[len(async_prefix):]
async def get_pooler_detection() -> PoolerDetectionResult:
    """Return a cached best-effort pooler detection result.

    Resolution order:
    1. If ``settings.db_pooler_kind`` is set, it is returned directly and
       the cache is neither read nor written.
    2. A cached result younger than ``_POOLER_CACHE_TTL_SECONDS`` is reused.
    3. Otherwise the "pgbouncer" and "pgcat" admin consoles are probed in
       that order; the first probe returning version text wins. If neither
       responds, an UNKNOWN / not-detected result is cached.
    """

    global _pooler_cache_at
    global _pooler_cache

    # Fast path: honor explicit configuration.
    if settings.db_pooler_kind is not None:
        kind = PoolerKind(settings.db_pooler_kind)
        detected = kind is not PoolerKind.NONE
        return PoolerDetectionResult(kind=kind, detected=detected, version_text=None)

    # Lock-free check first so the common cached case never waits on the lock.
    now = time.monotonic()
    if (
        _pooler_cache is not None
        and (now - _pooler_cache_at) < _POOLER_CACHE_TTL_SECONDS
    ):
        return _pooler_cache

    async with _pooler_lock:
        # Re-check under the lock: another task may have refreshed the cache
        # while this one was waiting.
        now2 = time.monotonic()
        if (
            _pooler_cache is not None
            and (now2 - _pooler_cache_at) < _POOLER_CACHE_TTL_SECONDS
        ):
            return _pooler_cache

        # Try PgBouncer first, then PgCat.
        for admin_db in ("pgbouncer", "pgcat"):
            version_text = await _probe_pooler_admin_console(admin_db)
            if version_text:
                # The kind comes from the version text, not from which admin
                # database name answered.
                kind = classify_pooler_version_text(version_text)
                _pooler_cache = PoolerDetectionResult(
                    kind=kind, detected=True, version_text=version_text
                )
                _pooler_cache_at = time.monotonic()
                return _pooler_cache

        # Negative results are cached too, so failed probes are not repeated
        # on every call within the TTL.
        _pooler_cache = PoolerDetectionResult(
            kind=PoolerKind.UNKNOWN, detected=False, version_text=None
        )
        _pooler_cache_at = time.monotonic()
        return _pooler_cache
def detect_dsn_dialect(dsn: str) -> DatabaseDialect:
    """Map a connection string's URL scheme to a supported dialect.

    The scheme is compared case-insensitively and any ``+driver`` suffix
    (for example ``postgresql+asyncpg``) is ignored.

    Raises:
        ValueError: If the scheme is neither PostgreSQL nor Snowflake.
    """

    scheme, _, _driver = urlparse(dsn).scheme.lower().partition("+")
    if scheme == "snowflake":
        return "snowflake"
    if scheme in {"postgres", "postgresql"}:
        return "postgresql"
    raise ValueError(f"unsupported database DSN scheme: {scheme or ''}")
--git a/backend/build/lib/app/ddl/export.py b/backend/build/lib/app/ddl/export.py new file mode 100644 index 00000000..11fdb993 --- /dev/null +++ b/backend/build/lib/app/ddl/export.py @@ -0,0 +1,617 @@ +from __future__ import annotations + +import re +from typing import Literal + +DdlDialect = Literal["postgresql", "snowflake"] + + +def _normalize_dialect(dialect: str) -> DdlDialect: + normalized = dialect.lower().replace("_", "-") + if normalized in ("postgres", "postgresql", "pg"): + return "postgresql" + if normalized in ("snowflake", "sf"): + return "snowflake" + raise ValueError(f"unsupported DDL dialect: {dialect}") + + +def _snapshot_source_dialect(snapshot: dict) -> DdlDialect: + for key in ("source_dialect", "database_dialect", "dialect"): + value = snapshot.get(key) + if isinstance(value, str): + try: + return _normalize_dialect(value) + except ValueError: + continue + return "postgresql" + + +def _q(ident: str) -> str: + """Quote a SQL identifier.""" + + # Quote identifier with double-quotes, escaping internal quotes. 
+ return '"' + ident.replace('"', '""') + '"' + + +def _qname(schema: str, name: str) -> str: + """Quote a schema-qualified name.""" + return f"{_q(schema)}.{_q(name)}" + + +def _tablespace_clause(tablespace: object) -> str: + return f" TABLESPACE {_q(tablespace)}" if isinstance(tablespace, str) else "" + + +_CREATE_INDEX_RE = re.compile( + r"^CREATE\s+(UNIQUE\s+)?INDEX\s+(?!CONCURRENTLY\b)", + flags=re.IGNORECASE, +) + + +def _index_def_with_concurrently(index_def: str) -> str: + """Add CONCURRENTLY to a PostgreSQL CREATE INDEX statement.""" + + def replacement(match: re.Match[str]) -> str: + unique = match.group(1) or "" + return f"CREATE {unique}INDEX CONCURRENTLY " + + return _CREATE_INDEX_RE.sub(replacement, index_def, count=1) + + +def _normalize_type_text(data_type: str) -> str: + return re.sub(r"\s+", " ", data_type.strip().lower()) + + +def _postgres_type_to_snowflake(column: dict) -> str: + base_type = column.get("domain_base_type") + if isinstance(base_type, str): + return _postgres_type_to_snowflake({**column, "data_type": base_type}) + + data_type = column.get("data_type") + if not isinstance(data_type, str): + return "VARCHAR" + + normalized = _normalize_type_text(data_type) + if ( + column.get("array_dimensions") + or normalized.endswith("[]") + or normalized.startswith("array") + ): + return "ARRAY" + + exact = { + "smallint": "NUMBER(5,0)", + "integer": "NUMBER(10,0)", + "bigint": "NUMBER(19,0)", + "real": "FLOAT", + "double precision": "FLOAT", + "boolean": "BOOLEAN", + "text": "VARCHAR", + "json": "VARIANT", + "jsonb": "VARIANT", + "bytea": "BINARY", + "date": "DATE", + "uuid": "VARCHAR(36)", + "inet": "VARCHAR", + "cidr": "VARCHAR", + "macaddr": "VARCHAR", + "macaddr8": "VARCHAR", + "xml": "VARCHAR", + } + if normalized in exact: + return exact[normalized] + + numeric = re.match( + r"^(numeric|decimal)\s*(?:\((\d+)(?:\s*,\s*(\d+))?\))?$", + normalized, + ) + if numeric: + precision = numeric.group(2) + scale = numeric.group(3) + if precision and 
scale: + return f"NUMBER({precision},{scale})" + if precision: + return f"NUMBER({precision},0)" + return "NUMBER" + + varchar = re.match(r"^(character varying|varchar)\s*(?:\((\d+)\))?$", normalized) + if varchar: + return f"VARCHAR({varchar.group(2)})" if varchar.group(2) else "VARCHAR" + + char = re.match(r"^(character|char)\s*(?:\((\d+)\))?$", normalized) + if char: + return f"CHAR({char.group(2)})" if char.group(2) else "CHAR" + + if normalized.startswith("timestamp") and "with time zone" in normalized: + return "TIMESTAMP_TZ" + if normalized.startswith("timestamp"): + return "TIMESTAMP_NTZ" + if normalized.startswith("time"): + return "TIME" + if normalized.startswith("interval"): + return "VARCHAR" + + if column.get("type_kind") == "e": + return "VARCHAR" + + return "VARCHAR" + + +def _snowflake_type_to_postgres(column: dict) -> str: + data_type = column.get("data_type") + if not isinstance(data_type, str): + return "text" + + normalized = _normalize_type_text(data_type) + exact = { + "boolean": "boolean", + "bool": "boolean", + "date": "date", + "time": "time", + "float": "double precision", + "float4": "real", + "float8": "double precision", + "double": "double precision", + "double precision": "double precision", + "real": "real", + "binary": "bytea", + "varbinary": "bytea", + "variant": "jsonb", + "object": "jsonb", + "array": "jsonb", + "geography": "jsonb", + "geometry": "jsonb", + } + if normalized in exact: + return exact[normalized] + + number = re.match( + r"^(number|numeric|decimal)\s*(?:\((\d+)(?:\s*,\s*(\d+))?\))?$", + normalized, + ) + if number: + precision = number.group(2) + scale = number.group(3) + if precision and scale: + return f"numeric({precision},{scale})" + if precision: + return f"numeric({precision},0)" + return "numeric" + + varchar = re.match(r"^(varchar|string|text)\s*(?:\((\d+)\))?$", normalized) + if varchar: + return f"character varying({varchar.group(2)})" if varchar.group(2) else "text" + + char = 
re.match(r"^(char|character)\s*(?:\((\d+)\))?$", normalized) + if char: + return f"character({char.group(2)})" if char.group(2) else "character" + + if normalized.startswith("timestamp_tz") or normalized.startswith("timestamp_ltz"): + return "timestamp with time zone" + if normalized.startswith("timestamp_ntz") or normalized.startswith("timestamp"): + return "timestamp without time zone" + + return "text" + + +def _mapped_data_type(column: dict, source: DdlDialect, target: DdlDialect) -> str: + data_type = column.get("data_type") + if not isinstance(data_type, str): + return "text" if target == "postgresql" else "VARCHAR" + if source == target: + return data_type + if source == "postgresql" and target == "snowflake": + return _postgres_type_to_snowflake(column) + if source == "snowflake" and target == "postgresql": + return _snowflake_type_to_postgres(column) + return data_type + + +def _column_default_clause(default_expr: object, target: DdlDialect) -> str | None: + if not isinstance(default_expr, str): + return None + expr = default_expr.strip() + if not expr: + return None + if target == "postgresql": + return f"DEFAULT {expr}" + + upper = expr.upper() + if "::" in expr or "NEXTVAL(" in upper: + return None + if re.fullmatch(r"[-+]?\d+(?:\.\d+)?", expr): + return f"DEFAULT {expr}" + if re.fullmatch(r"'(?:''|[^'])*'", expr): + return f"DEFAULT {expr}" + if upper in ("TRUE", "FALSE"): + return f"DEFAULT {upper}" + if upper in ("NOW()", "CURRENT_TIMESTAMP", "CURRENT_TIMESTAMP()"): + return "DEFAULT CURRENT_TIMESTAMP()" + if upper in ("CURRENT_DATE", "CURRENT_DATE()"): + return "DEFAULT CURRENT_DATE()" + if upper in ("CURRENT_TIME", "CURRENT_TIME()"): + return "DEFAULT CURRENT_TIME()" + if upper in ("GEN_RANDOM_UUID()", "UUID_GENERATE_V4()"): + return "DEFAULT UUID_STRING()" + return None + + +def _snapshot_tables(snapshot: dict) -> list[dict]: + relations = snapshot.get("relations", []) + return [ + r + for r in relations + if isinstance(r, dict) and 
r.get("relation_kind") in ("r", "p") + ] + + +def _group_by_relation(rows: object) -> dict[int, list[dict]]: + grouped: dict[int, list[dict]] = {} + if not isinstance(rows, list): + return grouped + for row in rows: + if not isinstance(row, dict): + continue + oid = row.get("relation_oid") + if isinstance(oid, int): + grouped.setdefault(oid, []).append(row) + return grouped + + +def _constraint_column_names( + constraint: dict, cols_by_oid: dict[int, list[dict]] +) -> list[str]: + oid = constraint.get("relation_oid") + attnums = constraint.get("constrained_attnums") + if not isinstance(oid, int) or not isinstance(attnums, list): + return [] + + cols_by_attnum: dict[int, str] = {} + for col in cols_by_oid.get(oid, []): + position = col.get("column_position") + name = col.get("column_name") + if isinstance(position, int) and isinstance(name, str): + cols_by_attnum[position] = name + + names: list[str] = [] + for attnum in attnums: + if not isinstance(attnum, int): + return [] + name = cols_by_attnum.get(attnum) + if name is None: + return [] + names.append(name) + return names + + +def _render_schemas(tables: list[dict], lines: list[str]) -> None: + schemas: set[str] = set() + for table in tables: + schema_name = table.get("schema_name") + if isinstance(schema_name, str): + schemas.add(schema_name) + + for s in sorted(schemas): + lines.append(f"CREATE SCHEMA IF NOT EXISTS {_q(s)};") + if schemas: + lines.append("") + + +def _render_foreign_keys(constraints: list[dict], lines: list[str]) -> None: + fk_cons = [c for c in constraints if c.get("constraint_type") == "f"] + if fk_cons: + lines.append("-- Foreign keys") + for con in fk_cons: + schema = con.get("schema_name") + table = con.get("relation_name") + cname = con.get("constraint_name") + cdef = con.get("constraint_def") + if not ( + isinstance(schema, str) + and isinstance(table, str) + and isinstance(cname, str) + and isinstance(cdef, str) + ): + continue + lines.append( + f"ALTER TABLE {_qname(schema, table)} 
ADD CONSTRAINT {_q(cname)} {cdef};" + ) + if fk_cons: + lines.append("") + + +def _render_indexes_pg(indexes: list[dict], lines: list[str]) -> None: + if indexes: + lines.append("-- Indexes (CONCURRENTLY; run outside a transaction)") + for ix in indexes: + ix_def = ix.get("index_def") + if not isinstance(ix_def, str): + continue + ix_def = ix_def.strip().rstrip(";") + ix_def = _index_def_with_concurrently(ix_def) + table_options = _tablespace_clause(ix.get("index_tablespace_name")) + if table_options and " TABLESPACE " not in ix_def.upper(): + ix_def = f"{ix_def}{table_options}" + lines.append(ix_def + ";") + + +def _render_table_columns_pg( + oid: int, + cols_by_oid: dict[int, list[dict]], + source_dialect: DdlDialect, +) -> list[str]: + col_defs: list[str] = [] + for c in sorted( + cols_by_oid.get(oid, []), + key=lambda x: int(x.get("column_position") or 0), + ): + col_name = c.get("column_name") + if not isinstance(col_name, str): + continue + data_type = _mapped_data_type(c, source_dialect, "postgresql") + parts = [f"{_q(col_name)} {data_type}"] + if c.get("has_default"): + default_clause = _column_default_clause(c.get("default_expr"), "postgresql") + if default_clause: + parts.append(default_clause) + if c.get("is_not_null") is True: + parts.append("NOT NULL") + col_defs.append(" ".join(parts)) + return col_defs + + +def _render_table_constraints_pg( + oid: int, + constraints_by_oid: dict[int, list[dict]], +) -> list[str]: + table_cons: list[str] = [] + for con in constraints_by_oid.get(oid, []): + ctype = con.get("constraint_type") + if ctype not in ("p", "u", "c"): + continue + cname = con.get("constraint_name") + cdef = con.get("constraint_def") + if isinstance(cname, str) and isinstance(cdef, str): + table_cons.append(f"CONSTRAINT {_q(cname)} {cdef}") + return table_cons + + +def snapshot_json_to_sql(snapshot: dict, target_dialect: str = "postgresql") -> str: + """Render a captured schema snapshot as SQL for the requested dialect.""" + target = 
_normalize_dialect(target_dialect) + if target == "snowflake": + return _snapshot_json_to_snowflake_sql(snapshot) + return _snapshot_json_to_postgresql_sql(snapshot) + + +def _render_table_pg( + t: dict, + cols_by_oid: dict[int, list[dict]], + constraints_by_oid: dict[int, list[dict]], + source_dialect: DdlDialect, +) -> list[str]: + schema = t.get("schema_name") + name = t.get("relation_name") + oid = t.get("relation_oid") + kind = t.get("relation_kind") + tablespace = t.get("tablespace_name") + partition_key = t.get("partition_key") + partition_bound = t.get("partition_bound") + partition_parent_schema = t.get("partition_parent_schema") + partition_parent_name = t.get("partition_parent_name") + is_partition = t.get("is_partition") is True + if not (isinstance(schema, str) and isinstance(name, str) and isinstance(oid, int)): + return [] + + lines: list[str] = [] + table_options = _tablespace_clause(tablespace) + if ( + is_partition + and isinstance(partition_parent_schema, str) + and isinstance(partition_parent_name, str) + and isinstance(partition_bound, str) + ): + partition_clause = ( + f" PARTITION BY {partition_key}" if isinstance(partition_key, str) else "" + ) + lines.append( + f"CREATE TABLE IF NOT EXISTS {_qname(schema, name)} PARTITION OF {_qname(partition_parent_schema, partition_parent_name)} {partition_bound}{partition_clause}{table_options};" + ) + lines.append("") + return lines + + if kind == "p" and not isinstance(partition_key, str): + lines.append( + f"-- NOTE: {_qname(schema, name)} is partitioned; partition definition not included in MVP export" + ) + + col_defs = _render_table_columns_pg(oid, cols_by_oid, source_dialect) + table_cons = _render_table_constraints_pg(oid, constraints_by_oid) + + all_defs = col_defs + table_cons + lines.append(f"CREATE TABLE IF NOT EXISTS {_qname(schema, name)} (") + for i, d in enumerate(all_defs): + comma = "," if i < len(all_defs) - 1 else "" + lines.append(f" {d}{comma}") + partition_clause = ( + f" PARTITION 
BY {partition_key}" + if kind == "p" and isinstance(partition_key, str) + else "" + ) + lines.append(f"){partition_clause}{table_options};") + lines.append("") + return lines + + +def _snapshot_json_to_postgresql_sql(snapshot: dict) -> str: + """Generate PostgreSQL DDL from a captured snapshot. + + This is MVP-grade forward engineering (export): + - Creates schemas and tables (columns only) + - Adds PK/UNIQUE/CHECK inside CREATE TABLE + - Adds FKs after all tables (order-safe) + - Adds indexes using saved pg_get_indexdef output + + Limitations (intentional, MVP): partitioning clauses and some table options are not reconstructed. + """ + + relations = snapshot.get("relations", []) + columns = snapshot.get("columns", []) + constraints = snapshot.get("constraints", []) + indexes = snapshot.get("indexes", []) + source_dialect = _snapshot_source_dialect(snapshot) + + tables = [r for r in relations if r.get("relation_kind") in ("r", "p")] + + cols_by_oid = _group_by_relation(columns) + constraints_by_oid = _group_by_relation(constraints) + + lines: list[str] = [] + lines.append("-- Generated by pg-erd-cloud (MVP)\n") + _render_schemas(tables, lines) + + # CREATE TABLE + inline constraints (PK/UNIQUE/CHECK) + for t in tables: + lines.extend( + _render_table_pg(t, cols_by_oid, constraints_by_oid, source_dialect) + ) + + # FKs after tables + _render_foreign_keys(constraints, lines) + + # Indexes (use saved pg_get_indexdef output) + _render_indexes_pg(indexes, lines) + + lines.append("") + return "\n".join(lines) + + +def _render_table_columns_snowflake( + oid: int, + cols_by_oid: dict[int, list[dict]], + source_dialect: DdlDialect, +) -> list[str]: + col_defs: list[str] = [] + for c in sorted( + cols_by_oid.get(oid, []), + key=lambda x: int(x.get("column_position") or 0), + ): + col_name = c.get("column_name") + if not isinstance(col_name, str): + continue + parts = [f"{_q(col_name)} {_mapped_data_type(c, source_dialect, 'snowflake')}"] + if c.get("has_default"): + 
default_clause = _column_default_clause(c.get("default_expr"), "snowflake") + if default_clause: + parts.append(default_clause) + if c.get("is_not_null") is True: + parts.append("NOT NULL") + col_defs.append(" ".join(parts)) + return col_defs + + +def _render_table_constraints_snowflake( + oid: int, + constraints_by_oid: dict[int, list[dict]], + cols_by_oid: dict[int, list[dict]], +) -> tuple[list[str], list[str]]: + table_cons: list[str] = [] + skipped_checks: list[str] = [] + for con in constraints_by_oid.get(oid, []): + ctype = con.get("constraint_type") + cname = con.get("constraint_name") + cdef = con.get("constraint_def") + if not (isinstance(cname, str) and isinstance(cdef, str)): + continue + if ctype in ("p", "u"): + col_names = _constraint_column_names(con, cols_by_oid) + if col_names: + keyword = "PRIMARY KEY" if ctype == "p" else "UNIQUE" + quoted_cols = ", ".join(_q(name) for name in col_names) + table_cons.append(f"CONSTRAINT {_q(cname)} {keyword} ({quoted_cols})") + else: + table_cons.append(f"CONSTRAINT {_q(cname)} {cdef}") + elif ctype == "c": + skipped_checks.append(cname) + return table_cons, skipped_checks + + +def _render_indexes_snowflake(indexes: list[dict], lines: list[str]) -> None: + if indexes: + lines.append("-- Indexes") + for ix in indexes: + if not isinstance(ix, dict): + continue + ix_name = ix.get("index_name") + table_schema = ix.get("table_schema_name") + table_name = ix.get("table_name") + if ( + isinstance(ix_name, str) + and isinstance(table_schema, str) + and isinstance(table_name, str) + ): + lines.append( + f"-- NOTE: PostgreSQL index {_q(ix_name)} on {_qname(table_schema, table_name)} is not emitted for Snowflake; consider clustering/search optimization as needed." + ) + else: + lines.append( + "-- NOTE: PostgreSQL index metadata is not emitted for Snowflake." 
+ ) + + +def _snapshot_json_to_snowflake_sql(snapshot: dict) -> str: + """Generate Snowflake DDL from a captured PostgreSQL/Snowflake snapshot.""" + + source_dialect = _snapshot_source_dialect(snapshot) + columns = snapshot.get("columns", []) + constraints = snapshot.get("constraints", []) + indexes = snapshot.get("indexes", []) + + tables = _snapshot_tables(snapshot) + cols_by_oid = _group_by_relation(columns) + constraints_by_oid = _group_by_relation(constraints) + + lines: list[str] = [] + lines.append("-- Generated by pg-erd-cloud (MVP) for Snowflake\n") + _render_schemas(tables, lines) + + for t in tables: + schema = t.get("schema_name") + name = t.get("relation_name") + oid = t.get("relation_oid") + if not ( + isinstance(schema, str) and isinstance(name, str) and isinstance(oid, int) + ): + continue + + col_defs = _render_table_columns_snowflake(oid, cols_by_oid, source_dialect) + table_cons, skipped_checks = _render_table_constraints_snowflake( + oid, constraints_by_oid, cols_by_oid + ) + + all_defs = col_defs + table_cons + lines.append(f"CREATE TABLE IF NOT EXISTS {_qname(schema, name)} (") + for i, d in enumerate(all_defs): + comma = "," if i < len(all_defs) - 1 else "" + lines.append(f" {d}{comma}") + lines.append(");") + for cname in skipped_checks: + lines.append( + f"-- NOTE: skipped PostgreSQL CHECK constraint {_q(cname)} on {_qname(schema, name)} for Snowflake export." + ) + if isinstance(t.get("tablespace_name"), str): + lines.append( + f"-- NOTE: skipped PostgreSQL TABLESPACE {_q(t['tablespace_name'])} on {_qname(schema, name)} for Snowflake export." + ) + if t.get("relation_kind") == "p" or t.get("is_partition") is True: + lines.append( + f"-- NOTE: skipped PostgreSQL partition metadata on {_qname(schema, name)} for Snowflake export." 
+ ) + lines.append("") + + _render_foreign_keys(constraints, lines) + + _render_indexes_snowflake(indexes, lines) + + lines.append("") + return "\n".join(lines) diff --git a/backend/build/lib/app/dsn_redaction.py b/backend/build/lib/app/dsn_redaction.py new file mode 100644 index 00000000..3342c3ae --- /dev/null +++ b/backend/build/lib/app/dsn_redaction.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import re +from urllib.parse import quote, quote_plus, unquote, unquote_plus, urlsplit + +_SECRET_KEY_PATTERN = re.compile( + r"(?:pass(?:word|wd)?|pwd|token|secret|private[_-]?key|api[_-]?key|" + r"access[_-]?key|auth(?:entication)?)", + re.IGNORECASE, +) +_SECRET_ASSIGNMENT_PATTERN = re.compile( + r"(?P\b[\w.-]*(?:pass(?:word|wd)?|pwd|token|secret|private[_-]?key|" + r"api[_-]?key|access[_-]?key|auth(?:entication)?)[\w.-]*\s*[:=]\s*)" + r"(?P[^&\s,;\"'<>]+)", + re.IGNORECASE, +) + + +def _password_candidates_from_dsn(dsn: str) -> set[str]: + candidates: set[str] = set() + parsed = urlsplit(dsn) + + if parsed.password: + candidates.add(parsed.password) + candidates.add(quote(parsed.password, safe="")) + + if "@" in parsed.netloc: + userinfo = parsed.netloc.rsplit("@", 1)[0] + if ":" in userinfo: + raw_password = userinfo.split(":", 1)[1] + candidates.add(raw_password) + candidates.add(unquote(raw_password)) + + for part in parsed.query.split("&"): + key, sep, raw_value = part.partition("=") + if not sep: + continue + if not _SECRET_KEY_PATTERN.search(unquote_plus(key)): + continue + decoded_value = unquote_plus(raw_value) + candidates.add(raw_value) + candidates.add(decoded_value) + candidates.add(quote(decoded_value, safe="")) + candidates.add(quote_plus(decoded_value, safe="")) + + return {candidate for candidate in candidates if candidate} + + +def redact_dsn_error_message(error_message: str, dsn: str) -> str: + """Redact DSN-derived secrets from a driver error message.""" + + redacted = error_message + for secret in 
sorted(_password_candidates_from_dsn(dsn), key=len, reverse=True): + redacted = redacted.replace(secret, "***") + return _SECRET_ASSIGNMENT_PATTERN.sub(r"\g***", redacted) diff --git a/backend/build/lib/app/hypercorn_config.py b/backend/build/lib/app/hypercorn_config.py new file mode 100644 index 00000000..02106b0a --- /dev/null +++ b/backend/build/lib/app/hypercorn_config.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import logging +import os + +logger = logging.getLogger(__name__) + + +def _int_env(name: str, default: str) -> int: + """Parse an integer environment variable with a safe fallback. + + Hypercorn config modules are imported at server startup. A bad env value + should not crash the process with an unhelpful ValueError. + """ + + raw = os.getenv(name) + value = (raw if raw is not None else default).strip() + try: + parsed = int(value) + except ValueError as exc: + try: + parsed = int(str(default).strip()) + except ValueError as default_exc: + parsed = 1 + logger.warning( + "Invalid %s=%r and default=%r; falling back to 1 (%s)", + name, + raw, + default, + default_exc, + ) + else: + logger.warning( + "Invalid %s=%r; falling back to default=%r (%s)", + name, + raw, + default, + exc, + ) + + return max(1, parsed) + + +# Prefer explicit hypercorn knob; fall back to a common convention. 
+workers = _int_env("HYPERCORN_WORKERS", os.getenv("WEB_CONCURRENCY", "1")) diff --git a/backend/build/lib/app/jobs/__init__.py b/backend/build/lib/app/jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/build/lib/app/jobs/snapshot_job.py b/backend/build/lib/app/jobs/snapshot_job.py new file mode 100644 index 00000000..0aa512bb --- /dev/null +++ b/backend/build/lib/app/jobs/snapshot_job.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import datetime as dt +import uuid +from collections.abc import Callable + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.dsn_redaction import redact_dsn_error_message +from app.models import ( + DbConnection, + JobQueue, + SchemaSnapshot, + SchemaSnapshotData, +) +from app.db_introspect import introspect_database +from app.security import decrypt_text + +_redact_snapshot_error_message = redact_dsn_error_message + + +async def handle_snapshot_job( + session_factory: Callable[[], AsyncSession], + job: JobQueue, +) -> None: + """Run a schema snapshot job and persist the resulting JSON.""" + payload = job.payload_json + snapshot_id = uuid.UUID(payload["schema_snapshot_uuid"]) + async with session_factory() as session: + async with session.begin(): + snapshot = await session.get(SchemaSnapshot, snapshot_id) + if snapshot is None: + raise RuntimeError("snapshot not found") + snapshot.status = "running" + snapshot.started_at = dt.datetime.now(dt.timezone.utc) + + conn = await session.get(DbConnection, snapshot.db_connection_uuid) + if conn is None: + raise RuntimeError("db connection not found") + + dsn = decrypt_text(conn.dsn_ciphertext, conn.dsn_nonce) + schema_filter = snapshot.schema_filter + + # Long-running IO: do it outside a DB transaction. 
+ try: + data = await introspect_database(dsn, schema_filter) + except Exception as e: # noqa: BLE001 + error_message = _redact_snapshot_error_message(str(e), dsn) + async with session_factory() as session: + async with session.begin(): + snapshot = await session.get(SchemaSnapshot, snapshot_id) + if snapshot is None: + raise RuntimeError(error_message) from None + snapshot.status = "failed" + snapshot.error_message = error_message + snapshot.finished_at = dt.datetime.now(dt.timezone.utc) + raise RuntimeError(error_message) from None + + async with session_factory() as session: + async with session.begin(): + snapshot = await session.get(SchemaSnapshot, snapshot_id) + if snapshot is None: + raise RuntimeError("snapshot not found") + + snapshot.status = "succeeded" + snapshot.finished_at = dt.datetime.now(dt.timezone.utc) + snapshot.error_message = None + + existing = await session.get(SchemaSnapshotData, snapshot_id) + if existing is None: + session.add( + SchemaSnapshotData( + schema_snapshot_uuid=snapshot_id, + snapshot_json=data, + created_at=dt.datetime.now(dt.timezone.utc), + ) + ) + else: + existing.snapshot_json = data diff --git a/backend/build/lib/app/jobs/valkey_queue.py b/backend/build/lib/app/jobs/valkey_queue.py new file mode 100644 index 00000000..2f2eb4f5 --- /dev/null +++ b/backend/build/lib/app/jobs/valkey_queue.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import datetime as dt +import importlib +import logging +import uuid +from collections.abc import Iterable +from typing import Any + +from app.settings import settings + +_logger = logging.getLogger(__name__) + +_POP_DUE_JOB_SCRIPT = """ +local ids = redis.call('ZRANGEBYSCORE', KEYS[1], '-inf', ARGV[1], 'LIMIT', 0, 1) +if #ids == 0 then + return nil +end +if redis.call('ZREM', KEYS[1], ids[1]) == 1 then + return ids[1] +end +return nil +""" + + +class ValkeyQueueUnavailable(RuntimeError): + """Raised when Valkey is selected but the Python client is unavailable.""" + + +def 
_parse_sentinel_hosts(raw: str | None) -> list[tuple[str, int]]: + """Parse VALKEY_SENTINEL_HOSTS as comma-separated host:port entries.""" + + if not raw: + return [] + + hosts: list[tuple[str, int]] = [] + for part in raw.split(","): + item = part.strip() + if not item: + continue + host, sep, port_text = item.rpartition(":") + if not sep or not host: + raise ValueError("VALKEY_SENTINEL_HOSTS entries must be host:port") + port = int(port_text) + if port <= 0 or port > 65535: + raise ValueError("VALKEY_SENTINEL_HOSTS port out of range") + hosts.append((host, port)) + return hosts + + +def valkey_queue_enabled() -> bool: + """Return whether workers should use Valkey as a queue signal path.""" + + if settings.job_queue_backend != "valkey": + return False + return bool(settings.valkey_url or settings.valkey_sentinel_hosts) + + +def valkey_queue_mode() -> str: + """Return the configured Valkey connection mode for diagnostics.""" + + if settings.valkey_sentinel_hosts: + return "sentinel" + if settings.valkey_url: + return "url" + return "disabled" + + +def valkey_queue_config_summary() -> dict[str, object]: + """Expose non-secret Valkey queue configuration for reports/tests.""" + + sentinel_hosts = _parse_sentinel_hosts(settings.valkey_sentinel_hosts) + return { + "enabled": valkey_queue_enabled(), + "mode": valkey_queue_mode(), + "queue_key": settings.valkey_queue_key, + "sentinel_master": settings.valkey_sentinel_master, + "sentinel_count": len(sentinel_hosts), + "lock_ttl_seconds": settings.valkey_lock_ttl_seconds, + } + + +def _load_redis_module() -> Any: + try: + return importlib.import_module("redis.asyncio") + except ModuleNotFoundError as exc: + raise ValkeyQueueUnavailable( + "Valkey queue backend requires redis-py with asyncio support" + ) from exc + + +async def _client() -> Any: + redis_asyncio = _load_redis_module() + sentinel_hosts = _parse_sentinel_hosts(settings.valkey_sentinel_hosts) + if sentinel_hosts: + if not settings.valkey_sentinel_master: + raise 
ValueError( + "VALKEY_SENTINEL_MASTER is required with VALKEY_SENTINEL_HOSTS" + ) + sentinel_mod = importlib.import_module("redis.asyncio.sentinel") + sentinel = sentinel_mod.Sentinel(sentinel_hosts) + return sentinel.master_for(settings.valkey_sentinel_master) + if not settings.valkey_url: + raise ValueError("VALKEY_URL is required when job_queue_backend=valkey") + return redis_asyncio.from_url(settings.valkey_url) + + +async def _close_client(client: Any) -> None: + close = getattr(client, "aclose", None) or getattr(client, "close", None) + if close is None: + return None + result = close() + if hasattr(result, "__await__"): + _ = await result + return None + + +async def enqueue_job_signal( + job_queue_uuid: uuid.UUID, + run_after: dt.datetime, +) -> bool: + """Best-effort signal that a DB-backed job is due through Valkey.""" + + if not valkey_queue_enabled(): + return False + + client: Any | None = None + try: + client = await _client() + await client.zadd( + settings.valkey_queue_key, + {str(job_queue_uuid): run_after.timestamp()}, + ) + return True + except Exception: # noqa: BLE001 + _logger.warning("Valkey job enqueue signal failed", exc_info=True) + return False + finally: + if client is not None: + await _close_client(client) + + +async def pop_due_job_signal( + now: dt.datetime | None = None, +) -> uuid.UUID | None: + """Pop one due job ID from Valkey, if the optional backend is configured.""" + + if not valkey_queue_enabled(): + return None + + current = now or dt.datetime.now(dt.timezone.utc) + client: Any | None = None + try: + client = await _client() + value = await client.eval( + _POP_DUE_JOB_SCRIPT, + 1, + settings.valkey_queue_key, + current.timestamp(), + ) + except Exception: # noqa: BLE001 + _logger.warning("Valkey job pop signal failed", exc_info=True) + return None + finally: + if client is not None: + await _close_client(client) + + if value is None: + return None + if isinstance(value, bytes): + value = value.decode("utf-8") + try: + 
return uuid.UUID(str(value)) + except ValueError: + _logger.warning("Valkey queue returned an invalid job UUID: %r", value) + return None + + +def format_sentinel_hosts(hosts: Iterable[tuple[str, int]]) -> str: + """Format sentinel hosts without exposing credentials.""" + + return ",".join(f"{host}:{port}" for host, port in hosts) diff --git a/backend/build/lib/app/jobs/worker.py b/backend/build/lib/app/jobs/worker.py new file mode 100644 index 00000000..29eec7e3 --- /dev/null +++ b/backend/build/lib/app/jobs/worker.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import asyncio +import datetime as dt +import logging +import time +from collections.abc import Awaitable, Callable +from collections.abc import Mapping +from typing import TypeAlias + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import JobQueue +from app.jobs.valkey_queue import ( + pop_due_job_signal, + valkey_queue_enabled, +) +from app.metrics import ( + JOB_QUEUE_JOBS_TOTAL, + JOB_QUEUE_PROCESSING_SECONDS, + JOB_QUEUE_WAIT_SECONDS, +) +from app.settings import settings + +_logger = logging.getLogger(__name__) + +Handler: TypeAlias = Callable[[Callable[[], AsyncSession], JobQueue], Awaitable[None]] + + +def _mark_job_running(job: JobQueue) -> JobQueue: + job.status = "running" + job.started_at = dt.datetime.now(dt.timezone.utc) + job.attempt_count = int(job.attempt_count) + 1 + + if settings.observability_metrics_enabled: + try: + wait_s = (job.started_at - job.run_after).total_seconds() + if wait_s >= 0: + JOB_QUEUE_WAIT_SECONDS.labels(job_type=job.job_type).observe(wait_s) + except Exception: # noqa: BLE001 + # Never fail job claiming due to metrics. 
+ _logger.debug("job queue wait metric observation failed", exc_info=True) + return job + + +async def _claim_job_by_id( + session: AsyncSession, + job_queue_uuid: object, +) -> JobQueue | None: + """Claim a specific queued job if it is due and not already locked.""" + + row = await session.execute( + text(""" + SELECT job_queue_uuid + FROM job_queue + WHERE + job_queue_uuid = :job_queue_uuid + AND status = 'queued' + AND run_after <= now() + FOR UPDATE SKIP LOCKED + LIMIT 1 + """), + {"job_queue_uuid": job_queue_uuid}, + ) + job_id = row.scalar_one_or_none() + if job_id is None: + return None + + job = await session.get(JobQueue, job_id) + if job is None: + return None + return _mark_job_running(job) + + +async def claim_one_job(session: AsyncSession) -> JobQueue | None: + """Claim one queued job using FOR UPDATE SKIP LOCKED.""" + + if valkey_queue_enabled(): + signaled_job_id = await pop_due_job_signal() + if signaled_job_id is not None: + job = await _claim_job_by_id(session, signaled_job_id) + if job is not None: + return job + + # Transaction: claim a queued job using SKIP LOCKED (non-blocking) + # We use raw SQL to leverage FOR UPDATE SKIP LOCKED reliably. 
+ row = await session.execute( + text(""" + SELECT job_queue_uuid + FROM job_queue + WHERE status = 'queued' AND run_after <= now() + ORDER BY run_after ASC + FOR UPDATE SKIP LOCKED + LIMIT 1 + """) + ) + job_id = row.scalar_one_or_none() + if job_id is None: + return None + + job = await session.get(JobQueue, job_id) + if job is None: + return None + return _mark_job_running(job) + + +def _publish_job_metrics( + *, + job_type: str, + outcome: str, + duration_s: float | None, +) -> None: + """Publish job metrics (best-effort) when metrics are enabled.""" + if not settings.observability_metrics_enabled: + return + + JOB_QUEUE_JOBS_TOTAL.labels(job_type=job_type, outcome=outcome).inc() + if duration_s is not None: + JOB_QUEUE_PROCESSING_SECONDS.labels( + job_type=job_type, + outcome=outcome, + ).observe(duration_s) + + +async def run_worker_forever( + session_factory: Callable[[], AsyncSession], + handlers: Mapping[str, Handler], + poll_interval_s: float = 1.0, +) -> None: + """Continuously poll the queue and dispatch jobs to handlers.""" + while True: + async with session_factory() as session: + async with session.begin(): + job = await claim_one_job(session) + + if job is None: + await asyncio.sleep(poll_interval_s) + continue + + handler = handlers.get(job.job_type) + if handler is None: + async with session.begin(): + job.status = "failed" + job.last_error = f"Unknown job_type: {job.job_type}" + job.finished_at = dt.datetime.now(dt.timezone.utc) + continue + + started = time.perf_counter() + try: + await handler(session_factory, job) + duration_s = time.perf_counter() - started + async with session.begin(): + job.status = "succeeded" + job.last_error = None + job.finished_at = dt.datetime.now(dt.timezone.utc) + + _publish_job_metrics( + job_type=job.job_type, + outcome="succeeded", + duration_s=duration_s, + ) + except Exception as e: # noqa: BLE001 + duration_s = time.perf_counter() - started + async with session.begin(): + job.status = "failed" + job.last_error 
= str(e) + job.finished_at = dt.datetime.now(dt.timezone.utc) + + _publish_job_metrics( + job_type=job.job_type, + outcome="failed", + duration_s=duration_s, + ) diff --git a/backend/build/lib/app/main.py b/backend/build/lib/app/main.py new file mode 100644 index 00000000..41f802c7 --- /dev/null +++ b/backend/build/lib/app/main.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import asyncio +import logging +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app.api.connections import router as connections_router +from app.api.auth_routes import router as auth_router +from app.api.me import router as me_router +from app.api.projects import router as projects_router +from app.api.share import router as share_router +from app.api.snapshots import router as snapshots_router +from app.auth import try_get_subject_for_rate_limit +from app.csrf import CSRF_HEADER_NAME, generate_csrf_token, make_csrf_middleware +from app.db import SessionLocal, get_pooler_detection +from app.jobs.snapshot_job import handle_snapshot_job +from app.jobs.worker import run_worker_forever +from app.observability import setup_observability +from app.rate_limit import ( + InMemoryFixedWindowRateLimiter, + RateLimitPolicy, + make_rate_limit_middleware, +) +from app.security_headers import make_security_headers_middleware +from app.settings import settings + + +@asynccontextmanager +async def lifespan(_: FastAPI) -> AsyncIterator[None]: + """Run application startup/shutdown hooks. + + Starts a background job worker on startup and ensures it is cancelled and + awaited on shutdown. + """ + + handlers = {"snapshot": handle_snapshot_job} + task = asyncio.create_task(run_worker_forever(SessionLocal, handlers)) + try: + # Best-effort pooler detection (log once for ops visibility). 
+ try: + detection = await get_pooler_detection() + logging.getLogger(__name__).info( + "db_pooler_detection: kind=%s detected=%s", + detection.kind.value, + detection.detected, + ) + except Exception: + logging.getLogger(__name__).exception("db_pooler_detection failed") + yield + finally: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +app = FastAPI(title="pg-erd-cloud backend", lifespan=lifespan) + +CORS_ALLOW_HEADERS = [ + "Authorization", + "Content-Type", + CSRF_HEADER_NAME, +] + +_rate_limiter = InMemoryFixedWindowRateLimiter( + max_keys=settings.api_rate_limit_max_keys +) +_rate_limit_policy = RateLimitPolicy( + enabled=settings.api_rate_limit_enabled, + requests=settings.api_rate_limit_requests, + window_seconds=settings.api_rate_limit_window_seconds, + route_prefix="/api", + trust_x_forwarded_for=settings.api_rate_limit_trust_x_forwarded_for, +) +_share_link_rate_limiter = InMemoryFixedWindowRateLimiter( + max_keys=settings.share_link_rate_limit_max_keys +) +_share_link_rate_limit_policy = RateLimitPolicy( + enabled=settings.share_link_rate_limit_enabled, + requests=settings.share_link_rate_limit_requests, + window_seconds=settings.share_link_rate_limit_window_seconds, + route_prefix="/api/share", + trust_x_forwarded_for=settings.api_rate_limit_trust_x_forwarded_for, +) +_revoke_rate_limiter = InMemoryFixedWindowRateLimiter( + max_keys=settings.api_rate_limit_max_keys +) +_revoke_rate_limit_policy = RateLimitPolicy( + enabled=settings.api_rate_limit_enabled, + requests=10, + window_seconds=60, + route_prefix="/api/auth/logout", + trust_x_forwarded_for=settings.api_rate_limit_trust_x_forwarded_for, +) + +app.middleware("http")( + make_rate_limit_middleware( + limiter=_rate_limiter, + policy=_rate_limit_policy, + get_subject=try_get_subject_for_rate_limit, + ) +) +app.middleware("http")( + make_rate_limit_middleware( + limiter=_share_link_rate_limiter, + policy=_share_link_rate_limit_policy, + ) +) +app.middleware("http")( + 
make_rate_limit_middleware( + limiter=_revoke_rate_limiter, + policy=_revoke_rate_limit_policy, + get_subject=try_get_subject_for_rate_limit, + ) +) + +app.middleware("http")(make_csrf_middleware()) + +app.add_middleware( + CORSMiddleware, + allow_origins=[o.strip() for o in settings.cors_origins.split(",") if o.strip()], + # Default to the strictest safe setting. Enable credentials only when you + # actually need cookie-based auth. + allow_credentials=False, + # Explicit allowlist (avoid "*") so CORS behavior is reviewable. + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=CORS_ALLOW_HEADERS, +) + +# Observability should be registered after other middleware so it can capture +# early returns (e.g. 429, CORS preflight). +# +# Note: security headers are registered last (outermost) so headers are attached +# even when another middleware returns early. +setup_observability(app) + +# Apply response security headers. +# +# Starlette middleware order: the **last** registered middleware wraps earlier +# ones (i.e., it becomes the outermost). +# +# We register security headers last so headers are attached even when another +# middleware returns early (e.g., CORS preflight, 429 rate-limit responses). 
+# See: backend/tests/test_security_headers.py +app.middleware("http")(make_security_headers_middleware()) + + +@app.get("/healthz") +async def healthz() -> dict: + """Simple health-check endpoint.""" + return {"ok": True} + + +@app.get("/api/csrf-token") +async def csrf_token() -> dict[str, str]: + """Issue a signed token for unsafe API requests.""" + return {"csrf_token": generate_csrf_token(settings.app_secret)} + + +app.include_router(projects_router) +app.include_router(connections_router) +app.include_router(snapshots_router) +app.include_router(me_router) +app.include_router(share_router) +app.include_router(auth_router) diff --git a/backend/build/lib/app/metrics.py b/backend/build/lib/app/metrics.py new file mode 100644 index 00000000..282ab011 --- /dev/null +++ b/backend/build/lib/app/metrics.py @@ -0,0 +1,84 @@ +"""Prometheus metrics for pg-erd-cloud. + +This module intentionally keeps instrumentation lightweight and dependency +minimal. Metrics are exposed via /metrics (when enabled). +""" + +from __future__ import annotations + +from prometheus_client import ( + CONTENT_TYPE_LATEST, + Counter, + Histogram, + generate_latest, +) + + +def normalize_route_label(route: str) -> str: + """Normalize a route label to avoid high-cardinality metrics. + + This codebase intentionally records *route templates* (e.g. `/api/x/{id}`) + rather than raw request paths. + """ + if not route: + return "unmatched" + if route == "unmatched": + return route + if not route.startswith("/"): + return "unmatched" + return route + + +def prime_http_metrics(*, route_methods: dict[str, set[str]]) -> None: + """Create common label series so metrics show up before first traffic.""" + for route, methods in sorted(route_methods.items()): + normalized = normalize_route_label(route) + for method in sorted(methods): + # Counters need an explicit sample. 
+ HTTP_REQUESTS_TOTAL.labels( + method=method, + route=normalized, + status="200", + ).inc(0) + # Histograms can be created without observing. + HTTP_REQUEST_DURATION_SECONDS.labels(method=method, route=normalized) + + +HTTP_REQUESTS_TOTAL = Counter( + "http_requests_total", + "Total number of HTTP responses by method/route/status.", + ["method", "route", "status"], +) + + +HTTP_REQUEST_DURATION_SECONDS = Histogram( + "http_request_duration_seconds", + "HTTP request duration in seconds by method/route.", + ["method", "route"], +) + + +JOB_QUEUE_JOBS_TOTAL = Counter( + "job_queue_jobs_total", + "Total number of job queue executions by type/outcome.", + ["job_type", "outcome"], +) + + +JOB_QUEUE_WAIT_SECONDS = Histogram( + "job_queue_wait_seconds", + "Time spent waiting in queue (started_at - run_after) in seconds.", + ["job_type"], +) + + +JOB_QUEUE_PROCESSING_SECONDS = Histogram( + "job_queue_processing_seconds", + "Time spent processing a job (handler runtime) in seconds.", + ["job_type", "outcome"], +) + + +def render_metrics() -> tuple[bytes, str]: + """Render all metrics in Prometheus exposition format.""" + return generate_latest(), CONTENT_TYPE_LATEST diff --git a/backend/build/lib/app/models.py b/backend/build/lib/app/models.py new file mode 100644 index 00000000..2a772ac3 --- /dev/null +++ b/backend/build/lib/app/models.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import datetime as dt +import uuid + +from sqlalchemy import DateTime, ForeignKey, Index, Integer, LargeBinary, Text +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + +class Base(DeclarativeBase): + """SQLAlchemy declarative base for ORM models.""" + + pass + + +def utcnow() -> dt.datetime: + """Return the current UTC timestamp (timezone-aware).""" + return dt.datetime.now(dt.timezone.utc) + + +class RevokedToken(Base): + """Persistent storage for revoked tokens to survive application restarts.""" + + 
__tablename__ = "revoked_token" + + jwt_id: Mapped[str] = mapped_column(Text(), primary_key=True) + expires_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True)) + + +class UserAccount(Base): + """User record keyed by a UUID and identified by OIDC subject.""" + + __tablename__ = "user_account" + + user_account_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + oidc_subject: Mapped[str] = mapped_column(Text(), unique=True) + display_name: Mapped[str | None] = mapped_column(Text(), nullable=True) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + + +class ProjectSpace(Base): + """Project container that groups connections and snapshots.""" + + __tablename__ = "project_space" + + project_space_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + project_name: Mapped[str] = mapped_column(Text()) + created_by_user_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("user_account.user_account_uuid") + ) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + + +class ProjectMember(Base): + """Membership mapping between users and projects with a role.""" + + __tablename__ = "project_member" + + project_space_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("project_space.project_space_uuid", ondelete="CASCADE"), + primary_key=True, + ) + user_account_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("user_account.user_account_uuid", ondelete="CASCADE"), + primary_key=True, + ) + project_role: Mapped[str] = mapped_column(Text()) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + + __table_args__ = ( + Index("ix_project_member__user_account_uuid", "user_account_uuid"), + ) + + +class DbConnection(Base): + """Encrypted database DSN belonging to a 
project.""" + + __tablename__ = "db_connection" + + db_connection_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + project_space_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("project_space.project_space_uuid", ondelete="CASCADE"), + index=True, + ) + conn_name: Mapped[str] = mapped_column(Text()) + dsn_ciphertext: Mapped[bytes] = mapped_column(LargeBinary()) + dsn_nonce: Mapped[bytes] = mapped_column(LargeBinary()) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + updated_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + + +class SchemaSnapshot(Base): + """Snapshot job record for a database introspection run.""" + + __tablename__ = "schema_snapshot" + + schema_snapshot_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + project_space_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("project_space.project_space_uuid", ondelete="CASCADE"), + index=True, + ) + db_connection_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("db_connection.db_connection_uuid", ondelete="CASCADE"), + ) + status: Mapped[str] = mapped_column(Text()) + schema_filter: Mapped[str | None] = mapped_column(Text(), nullable=True) + started_at: Mapped[dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + finished_at: Mapped[dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + error_message: Mapped[str | None] = mapped_column(Text(), nullable=True) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + + +class SchemaSnapshotData(Base): + """Captured schema snapshot JSON payload.""" + + __tablename__ = "schema_snapshot_data" + + schema_snapshot_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + 
ForeignKey("schema_snapshot.schema_snapshot_uuid", ondelete="CASCADE"), + primary_key=True, + ) + snapshot_json: Mapped[dict] = mapped_column(JSONB()) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + + +class JobQueue(Base): + """Lightweight DB-backed job queue (MVP).""" + + __tablename__ = "job_queue" + + job_queue_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + job_type: Mapped[str] = mapped_column(Text()) + status: Mapped[str] = mapped_column(Text()) + payload_json: Mapped[dict] = mapped_column(JSONB()) + run_after: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + attempt_count: Mapped[int] = mapped_column(Integer(), default=0) + last_error: Mapped[str | None] = mapped_column(Text(), nullable=True) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) + started_at: Mapped[dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + finished_at: Mapped[dt.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + __table_args__ = (Index("ix_job_queue__status_run_after", "status", "run_after"),) + + +class ShareLink(Base): + """Public share link granting read access to a project's snapshots.""" + + __tablename__ = "share_link" + + share_link_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + project_space_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("project_space.project_space_uuid", ondelete="CASCADE"), + index=True, + ) + created_by_user_uuid: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("user_account.user_account_uuid", ondelete="CASCADE"), + ) + permission_kind: Mapped[str] = mapped_column(Text()) # viewer/editor (MVP: viewer) + expires_at: Mapped[dt.datetime | None] = mapped_column( + DateTime(timezone=True), 
nullable=True + ) + created_at: Mapped[dt.datetime] = mapped_column( + DateTime(timezone=True), default=utcnow + ) diff --git a/backend/build/lib/app/observability.py b/backend/build/lib/app/observability.py new file mode 100644 index 00000000..fcc189b3 --- /dev/null +++ b/backend/build/lib/app/observability.py @@ -0,0 +1,207 @@ +"""Observability: structured logs + metrics. + +MVP goals (issue #49): +- Emit JSON structured request logs suitable for central ingestion. +- Expose basic Prometheus metrics (/metrics) for API + job queue. +""" + +from __future__ import annotations + +import datetime as dt +import json +import logging +import re +import secrets +import time +import uuid +from collections.abc import Awaitable, Callable + +from fastapi import FastAPI, Response +from starlette.requests import Request + +from app.metrics import ( + HTTP_REQUEST_DURATION_SECONDS, + HTTP_REQUESTS_TOTAL, + normalize_route_label, + prime_http_metrics, + render_metrics, +) +from app.settings import settings + +_logger = logging.getLogger("app.observability") + +_SAFE_REQUEST_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") + + +def _utc_now_iso() -> str: + return dt.datetime.now(dt.timezone.utc).isoformat() + + +def _get_route_template(request: Request) -> str: + route = request.scope.get("route") + path = getattr(route, "path", None) + if isinstance(path, str) and path: + return path + # Avoid high-cardinality labels for unmatched routes (e.g., 404 paths). 
+ return "unmatched" + + +def _get_client_ip(request: Request) -> str: + if settings.api_rate_limit_trust_x_forwarded_for: + xff = request.headers.get("X-Forwarded-For") + if xff: + ip = xff.split(",")[-1].strip() + if ip: + return ip + + client = request.client + if client is None: + return "unknown" + return client.host or "unknown" + + +def _log_json(event: str, fields: dict[str, object], *, level: int) -> None: + payload: dict[str, object] = { + "ts": _utc_now_iso(), + "event": event, + **fields, + } + _logger.log(level, json.dumps(payload, ensure_ascii=False, separators=(",", ":"))) + + +def _record_metrics_and_logs( + request: Request, + request_id: str, + status: int, + duration_s: float, + is_metrics_path: bool, +) -> None: + route = normalize_route_label(_get_route_template(request)) + + if settings.observability_metrics_enabled and not is_metrics_path: + HTTP_REQUESTS_TOTAL.labels( + method=request.method, + route=route, + status=str(status), + ).inc() + HTTP_REQUEST_DURATION_SECONDS.labels( + method=request.method, + route=route, + ).observe(duration_s) + + if settings.observability_request_logging_enabled and not is_metrics_path: + level = logging.INFO + if status >= 500: + level = logging.ERROR + elif status >= 400: + level = logging.WARNING + + _log_json( + "http_request", + { + "request_id": request_id, + "method": request.method, + "route": route, + "status": status, + "duration_ms": round(duration_s * 1000.0, 3), + "client_ip": _get_client_ip(request), + }, + level=level, + ) + + +def make_request_observability_middleware() -> Callable[ + [Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response] +]: + """Create request logging, metrics, and request-id middleware.""" + + async def middleware( + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: + """Record one request and attach an X-Request-Id response header.""" + # Avoid recursive noise in logs/metrics. 
+ is_metrics_path = request.url.path == "/metrics" + + raw_request_id = (request.headers.get("X-Request-Id") or "").strip() + if raw_request_id and _SAFE_REQUEST_ID_RE.fullmatch(raw_request_id): + request_id = raw_request_id + else: + request_id = str(uuid.uuid4()) + start = time.perf_counter() + + try: + response = await call_next(request) + except Exception: + duration_s = time.perf_counter() - start + _record_metrics_and_logs( + request, request_id, 500, duration_s, is_metrics_path + ) + + # Ensure the request id is visible to clients even on 500. + _logger.exception("unhandled request exception") + error_response = Response(status_code=500) + error_response.headers["X-Request-Id"] = request_id + return error_response + + duration_s = time.perf_counter() - start + status = int(response.status_code) + + response.headers["X-Request-Id"] = request_id + + _record_metrics_and_logs( + request, request_id, status, duration_s, is_metrics_path + ) + return response + + return middleware + + +def setup_observability(app: FastAPI) -> None: + """Register observability hooks on the given FastAPI app.""" + app.middleware("http")(make_request_observability_middleware()) + + if not settings.observability_metrics_enabled: + return + + token = (settings.observability_metrics_token or "").strip() + if not token: + _logger.warning( + "observability_metrics_enabled=true but token missing; " + "skipping /metrics route registration" + ) + return + + @app.get("/metrics", include_in_schema=False) + async def metrics(request: Request) -> Response: + """Return Prometheus metrics when the request presents the metrics token.""" + provided = request.headers.get("X-Metrics-Token") or "" + if not secrets.compare_digest(provided, token): + return Response(status_code=403) + content, content_type = render_metrics() + return Response(content=content, media_type=content_type) + + def _prime_metrics_on_startup() -> None: + route_methods: dict[str, set[str]] = {} + for r in app.routes: + path = 
getattr(r, "path", None) + if not isinstance(path, str) or not path or path == "/metrics": + continue + + r_methods = getattr(r, "methods", None) + if isinstance(r_methods, set): + methods = {m for m in r_methods if m not in {"HEAD"}} + else: + methods = set() + + if not methods: + methods.add("GET") + + if path in route_methods: + route_methods[path].update(methods) + else: + route_methods[path] = methods + + prime_http_metrics(route_methods=route_methods) + + app.router.add_event_handler("startup", _prime_metrics_on_startup) diff --git a/backend/build/lib/app/permissions.py b/backend/build/lib/app/permissions.py new file mode 100644 index 00000000..cd738894 --- /dev/null +++ b/backend/build/lib/app/permissions.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import uuid + +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import ProjectMember + +_ROLE_RANK = {"viewer": 0, "editor": 1, "owner": 2} + + +async def require_project_member( + session: AsyncSession, + project_space_uuid: uuid.UUID, + user_account_uuid: uuid.UUID, + minimum_role: str | None = None, +) -> str: + """Ensure the user is a project member with the required role.""" + row = await session.execute( + select(ProjectMember.project_role).where( + ProjectMember.project_space_uuid == project_space_uuid, + ProjectMember.user_account_uuid == user_account_uuid, + ) + ) + role = row.scalar_one_or_none() + if role is None: + raise HTTPException(status_code=403, detail="project access denied") + role_text = str(role) + if minimum_role is not None: + if _ROLE_RANK.get(role_text, -1) < _ROLE_RANK.get(minimum_role, 999): + raise HTTPException(status_code=403, detail="insufficient project role") + return role_text diff --git a/backend/build/lib/app/pg_introspect/__init__.py b/backend/build/lib/app/pg_introspect/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/backend/build/lib/app/pg_introspect/column_examples.py b/backend/build/lib/app/pg_introspect/column_examples.py new file mode 100644 index 00000000..2f916ca6 --- /dev/null +++ b/backend/build/lib/app/pg_introspect/column_examples.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import re + + +def _text(value: object) -> str: + return value if isinstance(value, str) else "" + + +def _normalized(value: object) -> str: + return _text(value).strip().lower() + + +def _has_name(column_name: str, *patterns: str) -> bool: + return any(pattern in column_name for pattern in patterns) + + +def _matches_name(column_name: str, pattern: str) -> bool: + return re.search(pattern, column_name) is not None + + +def infer_column_example(column: dict) -> str: + """Infer a synthetic example value from column metadata. + + The value is intentionally generated from names and types only. Do not read + live table data here; snapshots are shareable and should not capture PII. + """ + + column_name = _normalized(column.get("column_name")) + data_type = _normalized(column.get("data_type")) + type_name = _normalized(column.get("type_name")) + type_category = _normalized(column.get("type_category")) + type_kind = _normalized(column.get("type_kind")) + combined_type = f"{data_type} {type_name}" + + if _has_name(column_name, "email", "e_mail") or column_name == "mail": + return "user@example.com" + if _has_name(column_name, "phone", "mobile", "tel_no", "telephone"): + return "+82-10-1234-5678" + if _has_name(column_name, "url", "uri", "website", "homepage"): + return "https://example.com/resource" + if _has_name(column_name, "currency"): + return "KRW" + if _has_name(column_name, "country"): + return "KR" + if _has_name(column_name, "locale", "language"): + return "ko-KR" + if _has_name(column_name, "status", "state"): + return "active" + if _has_name(column_name, "category", "kind", "type"): + return "standard" + if _has_name(column_name, "code"): + return "EXAMPLE_CODE" + if 
_has_name(column_name, "city"): + return "Seoul" + if _has_name(column_name, "address"): + return "123 Example St" + if _has_name(column_name, "name", "title"): + return "Example Name" + if _has_name(column_name, "description", "comment", "memo", "note"): + return "Example description" + + if "uuid" in combined_type or _matches_name(column_name, r"(^uuid$|_uuid$)"): + return "550e8400-e29b-41d4-a716-446655440000" + if _matches_name(column_name, r"(^id$|_id$|_uuid$)"): + if "char" in combined_type or "text" in combined_type: + return "ID-1001" + return "1001" + + if "bool" in combined_type or type_category == "b": + return "true" + if "timestamp" in combined_type or _has_name( + column_name, "created_at", "updated_at", "deleted_at" + ): + return "2026-01-15T09:30:00Z" + if "date" in combined_type: + return "2026-01-15" + if "time" in combined_type: + return "09:30:00" + if "json" in combined_type: + return '{"key":"value"}' + if "bytea" in combined_type: + return "base64:ZXhhbXBsZQ==" + if any(network_type in combined_type for network_type in ("inet", "cidr")): + return "192.0.2.10" + if "macaddr" in combined_type: + return "00:00:5e:00:53:01" + if "interval" in combined_type: + return "P1D" + if type_category == "a" or data_type.endswith("[]"): + return '["example"]' + if type_kind == "e": + return "example_value" + if any( + numeric_type in combined_type + for numeric_type in ( + "int", + "numeric", + "decimal", + "real", + "double", + "money", + ) + ): + if _has_name(column_name, "amount", "price", "cost", "total"): + return "123.45" + if _has_name(column_name, "rate", "ratio", "percent"): + return "0.15" + if _has_name(column_name, "count", "qty", "quantity"): + return "10" + return "123" + if any(text_type in combined_type for text_type in ("char", "text", "citext")): + return "example text" + + return "example" + + +def add_column_examples(columns: list[dict]) -> list[dict]: + # ⚡ Bolt: Mutate dictionaries in-place to avoid allocating new dicts for each 
column. + # Snapshot introspectors pass freshly built payload dictionaries, so callers do not share these objects. + for col in columns: + if "example_value" not in col: + col["example_value"] = infer_column_example(col) + if "example_value_source" not in col: + col["example_value_source"] = "generated" + return columns diff --git a/backend/build/lib/app/pg_introspect/dsn_guard.py b/backend/build/lib/app/pg_introspect/dsn_guard.py new file mode 100644 index 00000000..86865351 --- /dev/null +++ b/backend/build/lib/app/pg_introspect/dsn_guard.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +import ipaddress +import socket +from urllib.parse import parse_qsl, urlparse + +from app.settings import settings + +POSTGRES_SCHEMES = {"postgres", "postgresql"} + + +class DsnTargetError(ValueError): + """Raised when a PostgreSQL DSN points at a disallowed network target.""" + + +@dataclass(frozen=True) +class ValidatedDsnTarget: + """Connection target values that were checked for restricted IP ranges.""" + + hosts: tuple[str, ...] 
+ port: int | None + + +def _configured_allowed_hosts() -> tuple[str, ...]: + return tuple( + item.strip().lower().rstrip(".") + for item in settings.db_introspection_allowed_hosts.split(",") + if item.strip() + ) + + +def _host_matches_allowed_entry(host: str, entry: str) -> bool: + if entry.startswith("*."): + suffix = entry[1:] + return host.endswith(suffix) and host != suffix.lstrip(".") + return host == entry + + +def _validate_allowed_host(host: str) -> None: + allowed_hosts = _configured_allowed_hosts() + if not allowed_hosts: + raise DsnTargetError("database host allowlist is not configured") + if any(_host_matches_allowed_entry(host, entry) for entry in allowed_hosts): + return + raise DsnTargetError("database host is not in the introspection allowlist") + + +def _is_restricted_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None: + ip = ip.ipv4_mapped + + return ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_multicast + or ip.is_reserved + or ip.is_unspecified + ) + + +def _parse_ip_literal( + host: str, +) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None: + try: + return ipaddress.ip_address(host.strip("[]")) + except ValueError: + return None + + +def _connection_host_for_ip( + ip: ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> str: + if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None: + return str(ip.ipv4_mapped) + return str(ip) + + +async def _resolved_ips( + host: str, port: int | None +) -> set[ipaddress.IPv4Address | ipaddress.IPv6Address]: + try: + loop = asyncio.get_running_loop() + addrinfo = await loop.getaddrinfo(host, port or 5432, type=socket.SOCK_STREAM) + except socket.gaierror as err: + raise DsnTargetError("database host could not be resolved") from err + + resolved: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set() + for entry in addrinfo: + sockaddr = entry[4] + if not sockaddr: + 
continue + resolved.add(ipaddress.ip_address(sockaddr[0])) + if not resolved: + raise DsnTargetError("database host did not resolve to an IP address") + return resolved + + +def _parse_query_params(query: str) -> dict[str, list[str]]: + params: dict[str, list[str]] = {} + for key, value in parse_qsl(query, keep_blank_values=True): + params.setdefault(key.lower(), []).append(value) + return params + + +def _split_query_host_values(values: list[str], parameter: str) -> list[str]: + hosts: list[str] = [] + for value in values: + for host in value.split(","): + normalized = host.strip() + if not normalized: + raise DsnTargetError(f"database DSN query {parameter} is invalid") + hosts.append(normalized) + return hosts + + +def _validate_query_ports(values: list[str]) -> int | None: + if not values: + return None + + first_port: int | None = None + for value in values: + for port_value in value.split(","): + normalized = port_value.strip() + if not normalized: + raise DsnTargetError("database DSN query port is invalid") + try: + port = int(normalized) + except ValueError as err: + raise DsnTargetError("database DSN query port is invalid") from err + if port < 1 or port > 65535: + raise DsnTargetError("database DSN query port is invalid") + if first_port is None: + first_port = port + return first_port + + +def _unique_hosts(hosts: list[str]) -> tuple[str, ...]: + unique: list[str] = [] + seen: set[str] = set() + for host in hosts: + if host in seen: + continue + seen.add(host) + unique.append(host) + return tuple(unique) + + +async def _validated_ip_hosts( + h: str, is_hostaddr: bool, port: int | None +) -> tuple[str, ...]: + normalized = h.lower().rstrip(".") + + if normalized == "localhost" or normalized.endswith(".localhost"): + raise DsnTargetError("database host must not be localhost") + + _validate_allowed_host(normalized) + + literal_ip = _parse_ip_literal(normalized) + if literal_ip is not None: + if _is_restricted_ip(literal_ip): + raise DsnTargetError("database 
host resolves to a restricted IP range") + return (_connection_host_for_ip(literal_ip),) + + if is_hostaddr: + raise DsnTargetError("database DSN query hostaddr is invalid") + + resolved_ips = await _resolved_ips(normalized, port) + for ip in resolved_ips: + if _is_restricted_ip(ip): + raise DsnTargetError("database host resolves to a restricted IP range") + return tuple(_connection_host_for_ip(ip) for ip in sorted(resolved_ips, key=str)) + + +async def validate_postgres_dsn_target(dsn: str) -> ValidatedDsnTarget: + """Reject PostgreSQL DSNs that could target internal network resources.""" + + parsed = urlparse(dsn) + if parsed.scheme.lower() not in POSTGRES_SCHEMES: + raise DsnTargetError("database DSN must use postgres or postgresql scheme") + + host = parsed.hostname + if not host: + raise DsnTargetError("database DSN must include a host") + try: + port = parsed.port + except ValueError as err: + raise DsnTargetError("database DSN port is invalid") from err + query = _parse_query_params(parsed.query) + + port_override = _validate_query_ports(query.get("port", [])) + if port_override is not None: + port = port_override + + primary_hosts = await _validated_ip_hosts(host, False, port) + query_hosts = [] + for query_host in _split_query_host_values(query.get("host", []), "host"): + query_hosts.append(await _validated_ip_hosts(query_host, False, port)) + query_hostaddrs = [] + for query_hostaddr in _split_query_host_values( + query.get("hostaddr", []), "hostaddr" + ): + query_hostaddrs.append(await _validated_ip_hosts(query_hostaddr, True, port)) + + connection_host_groups = query_hosts + query_hostaddrs + if not connection_host_groups: + connection_host_groups = [primary_hosts] + + return ValidatedDsnTarget( + hosts=_unique_hosts( + [host for group in connection_host_groups for host in group] + ), + port=port, + ) diff --git a/backend/build/lib/app/pg_introspect/introspect.py b/backend/build/lib/app/pg_introspect/introspect.py new file mode 100644 index 
00000000..4fb1ebc4 --- /dev/null +++ b/backend/build/lib/app/pg_introspect/introspect.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import datetime as dt + +import asyncpg + +from app.pg_introspect import queries +from app.pg_introspect.column_examples import add_column_examples +from app.pg_introspect.dsn_guard import validate_postgres_dsn_target +from app.sanitize import sanitize_for_storage + + +async def introspect_postgres(dsn: str, schema_filter: str | None) -> dict: + """Introspect a PostgreSQL database and return a snapshot JSON.""" + + # Note: avoid logging DSN. + target = await validate_postgres_dsn_target(dsn) + connect_host: str | list[str] = ( + target.hosts[0] if len(target.hosts) == 1 else list(target.hosts) + ) + if target.port is not None: + conn = await asyncpg.connect( + dsn, host=connect_host, port=target.port, timeout=10 + ) + else: + conn = await asyncpg.connect(dsn, host=connect_host, timeout=10) + try: + version = await conn.fetchval("SHOW server_version") + schema_name = schema_filter + include_system = False + + schemas = await conn.fetch(queries.SCHEMAS_SQL, schema_name, include_system) + relations = await conn.fetch(queries.RELATIONS_SQL, schema_name, include_system) + columns = await conn.fetch(queries.COLUMNS_SQL, schema_name, include_system) + constraints = await conn.fetch( + queries.CONSTRAINTS_SQL, schema_name, include_system + ) + indexes = await conn.fetch(queries.INDEXES_SQL, schema_name, include_system) + pk_columns = await conn.fetch( + queries.PK_COLUMNS_SQL, schema_name, include_system + ) + fk_edges = await conn.fetch(queries.FK_EDGES_SQL, schema_name, include_system) + citus_distributed_tables = [] + has_citus = await conn.fetchval( + "SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_extension WHERE extname = 'citus')" + ) + if has_citus: + try: + citus_distributed_tables = await conn.fetch( + queries.CITUS_DISTRIBUTED_TABLES_SQL, + schema_name, + include_system, + ) + except asyncpg.UndefinedTableError: + 
citus_distributed_tables = [] + + snapshot = { + "captured_at": dt.datetime.now(dt.timezone.utc).isoformat(), + "server_version": str(version), + "schema_filter": schema_filter, + "schemas": [dict(r) for r in schemas], + "relations": [dict(r) for r in relations], + "columns": add_column_examples([dict(r) for r in columns]), + "constraints": [dict(r) for r in constraints], + "indexes": [dict(r) for r in indexes], + "pk_columns": [dict(r) for r in pk_columns], + "fk_edges": [dict(r) for r in fk_edges], + "citus_distributed_tables": [dict(r) for r in citus_distributed_tables], + } + + return sanitize_for_storage(snapshot) # type: ignore[return-value] + finally: + await conn.close() diff --git a/backend/build/lib/app/pg_introspect/queries.py b/backend/build/lib/app/pg_introspect/queries.py new file mode 100644 index 00000000..c0b2d15b --- /dev/null +++ b/backend/build/lib/app/pg_introspect/queries.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +SCHEMAS_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +) +SELECT + n.oid AS schema_oid, + n.nspname AS schema_name +FROM pg_catalog.pg_namespace n +CROSS JOIN params p +WHERE + (p.schema_name IS NULL OR n.nspname = p.schema_name) + AND ( + p.include_system + OR ( + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname NOT LIKE 'pg_toast%' + AND n.nspname NOT LIKE 'pg_temp_%' + AND n.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +ORDER BY n.nspname; +""" + + +RELATIONS_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +) +SELECT + n.nspname AS schema_name, + c.oid AS relation_oid, + c.relname AS relation_name, + c.relkind::text AS relation_kind, + pg_catalog.obj_description(c.oid, 'pg_class') AS relation_comment, + c.relispartition AS is_partition, + pg_catalog.pg_get_partkeydef(c.oid) AS partition_key, + pg_catalog.pg_get_expr(c.relpartbound, c.oid) AS partition_bound, + parent.oid AS 
partition_parent_oid, + parent_ns.nspname AS partition_parent_schema, + parent.relname AS partition_parent_name, + rel_ts.spcname AS tablespace_name +FROM pg_catalog.pg_class c +JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace +LEFT JOIN pg_catalog.pg_inherits inh ON inh.inhrelid = c.oid +LEFT JOIN pg_catalog.pg_class parent ON parent.oid = inh.inhparent +LEFT JOIN pg_catalog.pg_namespace parent_ns ON parent_ns.oid = parent.relnamespace +LEFT JOIN pg_catalog.pg_tablespace rel_ts ON rel_ts.oid = c.reltablespace +CROSS JOIN params p +WHERE + c.relkind IN ('r','p','v','m') + AND (p.schema_name IS NULL OR n.nspname = p.schema_name) + AND ( + p.include_system + OR ( + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname NOT LIKE 'pg_toast%' + AND n.nspname NOT LIKE 'pg_temp_%' + AND n.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +ORDER BY n.nspname, c.relkind, c.relname; +""" + + +COLUMNS_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +) +SELECT + n.nspname AS schema_name, + c.oid AS relation_oid, + c.relname AS relation_name, + c.relkind::text AS relation_kind, + a.attnum AS column_position, + a.attname AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + typ.oid AS type_oid, + typ_ns.nspname AS type_schema, + typ.typname AS type_name, + typ.typtype::text AS type_kind, + typ.typcategory::text AS type_category, + CASE + WHEN typ.typtype = 'd' + THEN pg_catalog.format_type(typ.typbasetype, typ.typtypmod) + ELSE NULL + END AS domain_base_type, + base_typ_ns.nspname AS domain_base_schema, + base_typ.typname AS domain_base_name, + CASE + WHEN typ.typcategory = 'A' + THEN pg_catalog.format_type(typ.typelem, -1) + ELSE NULL + END AS array_element_type, + elem_typ_ns.nspname AS array_element_schema, + elem_typ.typname AS array_element_name, + a.attndims AS array_dimensions, + a.attnotnull AS is_not_null, + a.atthasdef AS has_default, + 
pg_catalog.pg_get_expr(ad.adbin, ad.adrelid) AS default_expr, + pg_catalog.col_description(a.attrelid, a.attnum) AS column_comment +FROM pg_catalog.pg_attribute a +JOIN pg_catalog.pg_class c ON c.oid = a.attrelid +JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace +JOIN pg_catalog.pg_type typ ON typ.oid = a.atttypid +JOIN pg_catalog.pg_namespace typ_ns ON typ_ns.oid = typ.typnamespace +LEFT JOIN pg_catalog.pg_type base_typ + ON base_typ.oid = typ.typbasetype AND typ.typtype = 'd' +LEFT JOIN pg_catalog.pg_namespace base_typ_ns + ON base_typ_ns.oid = base_typ.typnamespace +LEFT JOIN pg_catalog.pg_type elem_typ + ON elem_typ.oid = typ.typelem AND typ.typcategory = 'A' +LEFT JOIN pg_catalog.pg_namespace elem_typ_ns + ON elem_typ_ns.oid = elem_typ.typnamespace +LEFT JOIN pg_catalog.pg_attrdef ad ON ad.adrelid = a.attrelid AND ad.adnum = a.attnum +CROSS JOIN params p +WHERE + c.relkind IN ('r','p','v','m') + AND a.attnum > 0 + AND NOT a.attisdropped + AND (p.schema_name IS NULL OR n.nspname = p.schema_name) + AND ( + p.include_system + OR ( + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname NOT LIKE 'pg_toast%' + AND n.nspname NOT LIKE 'pg_temp_%' + AND n.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +ORDER BY n.nspname, c.relname, a.attnum; +""" + + +CONSTRAINTS_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +) +SELECT + con.oid AS constraint_oid, + con.conname AS constraint_name, + con.contype::text AS constraint_type, + n.nspname AS schema_name, + rel.oid AS relation_oid, + rel.relname AS relation_name, + frel.oid AS foreign_relation_oid, + fn.nspname AS foreign_schema_name, + frel.relname AS foreign_relation_name, + con.conkey AS constrained_attnums, + con.confkey AS referenced_attnums, + con.confupdtype::text AS fk_on_update, + con.confdeltype::text AS fk_on_delete, + con.confmatchtype::text AS fk_match_type, + pg_catalog.pg_get_constraintdef(con.oid, true) AS constraint_def, + 
pg_catalog.pg_get_expr(con.conbin, con.conrelid) AS check_expr +FROM pg_catalog.pg_constraint con +JOIN pg_catalog.pg_class rel ON rel.oid = con.conrelid +JOIN pg_catalog.pg_namespace n ON n.oid = rel.relnamespace +LEFT JOIN pg_catalog.pg_class frel ON frel.oid = con.confrelid +LEFT JOIN pg_catalog.pg_namespace fn ON fn.oid = frel.relnamespace +CROSS JOIN params p +WHERE + con.contype IN ('p','f','u','c') + AND (p.schema_name IS NULL OR n.nspname = p.schema_name) + AND ( + p.include_system + OR ( + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname NOT LIKE 'pg_toast%' + AND n.nspname NOT LIKE 'pg_temp_%' + AND n.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +ORDER BY n.nspname, rel.relname, con.conname; +""" + + +INDEXES_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +) +SELECT + idx.oid AS index_oid, + idx_ns.nspname AS index_schema_name, + idx.relname AS index_name, + + tbl.oid AS relation_oid, + tbl.oid AS table_oid, + tbl_ns.nspname AS table_schema_name, + tbl.relname AS table_name, + idx_ts.spcname AS index_tablespace_name, + + am.amname AS access_method, + am_ext.extname AS access_method_extension, + COALESCE(opclasses.operator_classes, ARRAY[]::text[]) AS operator_classes, + COALESCE(opclasses.operator_class_extensions, ARRAY[]::text[]) AS operator_class_extensions, + ix.indisunique AS is_unique, + ix.indisprimary AS is_primary, + ix.indisvalid AS is_valid, + + pg_catalog.pg_get_expr(ix.indpred, ix.indrelid) AS predicate_expr, + pg_catalog.pg_get_indexdef(idx.oid) AS index_def +FROM pg_catalog.pg_index ix +JOIN pg_catalog.pg_class idx ON idx.oid = ix.indexrelid +JOIN pg_catalog.pg_namespace idx_ns ON idx_ns.oid = idx.relnamespace +JOIN pg_catalog.pg_class tbl ON tbl.oid = ix.indrelid +JOIN pg_catalog.pg_namespace tbl_ns ON tbl_ns.oid = tbl.relnamespace +JOIN pg_catalog.pg_am am ON am.oid = idx.relam +LEFT JOIN pg_catalog.pg_tablespace idx_ts ON idx_ts.oid = idx.reltablespace +LEFT 
JOIN LATERAL ( + SELECT ext.extname + FROM pg_catalog.pg_depend dep + JOIN pg_catalog.pg_extension ext ON ext.oid = dep.refobjid + WHERE + dep.classid = 'pg_catalog.pg_am'::regclass + AND dep.objid = am.oid + AND dep.refclassid = 'pg_catalog.pg_extension'::regclass + AND dep.deptype = 'e' + ORDER BY ext.extname + LIMIT 1 +) am_ext ON true +LEFT JOIN LATERAL ( + SELECT + array_agg(opc_ns.nspname || '.' || opc.opcname ORDER BY cls.ordinality) AS operator_classes, + COALESCE( + array_agg(DISTINCT ext.extname ORDER BY ext.extname) FILTER (WHERE ext.extname IS NOT NULL), + ARRAY[]::text[] + ) AS operator_class_extensions + FROM unnest(ix.indclass) WITH ORDINALITY AS cls(opclass_oid, ordinality) + JOIN pg_catalog.pg_opclass opc ON opc.oid = cls.opclass_oid + JOIN pg_catalog.pg_namespace opc_ns ON opc_ns.oid = opc.opcnamespace + LEFT JOIN pg_catalog.pg_depend dep + ON dep.classid = 'pg_catalog.pg_opclass'::regclass + AND dep.objid = opc.oid + AND dep.refclassid = 'pg_catalog.pg_extension'::regclass + AND dep.deptype = 'e' + LEFT JOIN pg_catalog.pg_extension ext ON ext.oid = dep.refobjid +) opclasses ON true +CROSS JOIN params p +WHERE + tbl.relkind IN ('r','p','m') + AND (p.schema_name IS NULL OR tbl_ns.nspname = p.schema_name) + AND ( + p.include_system + OR ( + tbl_ns.nspname NOT IN ('pg_catalog', 'information_schema') + AND tbl_ns.nspname NOT LIKE 'pg_toast%' + AND tbl_ns.nspname NOT LIKE 'pg_temp_%' + AND tbl_ns.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +ORDER BY tbl_ns.nspname, tbl.relname, idx.relname; +""" + + +CITUS_DISTRIBUTED_TABLES_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +) +SELECT + tbl.oid AS relation_oid, + ns.nspname AS schema_name, + tbl.relname AS relation_name, + p.partmethod::text AS distribution_method, + p.partkey::text AS distribution_key, + p.colocationid AS colocation_id, + p.repmodel::text AS replication_model, + co.shardcount AS configured_shard_count, + co.replicationfactor AS 
replication_factor, + COUNT(sh.shardid)::int AS shard_count +FROM pg_catalog.pg_dist_partition p +JOIN pg_catalog.pg_class tbl ON tbl.oid = p.logicalrelid +JOIN pg_catalog.pg_namespace ns ON ns.oid = tbl.relnamespace +LEFT JOIN pg_catalog.pg_dist_colocation co ON co.colocationid = p.colocationid +LEFT JOIN pg_catalog.pg_dist_shard sh ON sh.logicalrelid = p.logicalrelid +CROSS JOIN params prm +WHERE + (prm.schema_name IS NULL OR ns.nspname = prm.schema_name) + AND ( + prm.include_system + OR ( + ns.nspname NOT IN ('pg_catalog', 'information_schema') + AND ns.nspname NOT LIKE 'pg_toast%' + AND ns.nspname NOT LIKE 'pg_temp_%' + AND ns.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +GROUP BY + tbl.oid, + ns.nspname, + tbl.relname, + p.partmethod, + p.partkey, + p.colocationid, + p.repmodel, + co.shardcount, + co.replicationfactor +ORDER BY ns.nspname, tbl.relname; +""" + + +PK_COLUMNS_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +), pk AS ( + SELECT con.* + FROM pg_catalog.pg_constraint con + JOIN pg_catalog.pg_class rel ON rel.oid = con.conrelid + JOIN pg_catalog.pg_namespace n ON n.oid = rel.relnamespace + CROSS JOIN params p + WHERE + con.contype = 'p' + AND (p.schema_name IS NULL OR n.nspname = p.schema_name) + AND ( + p.include_system + OR ( + n.nspname NOT IN ('pg_catalog', 'information_schema') + AND n.nspname NOT LIKE 'pg_toast%' + AND n.nspname NOT LIKE 'pg_temp_%' + AND n.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +) +SELECT + con.oid AS constraint_oid, + con.conname AS constraint_name, + n.nspname AS schema_name, + rel.oid AS relation_oid, + rel.relname AS relation_name, + k.ordinality AS column_ordinal, + a.attname AS column_name +FROM pk con +JOIN pg_catalog.pg_class rel ON rel.oid = con.conrelid +JOIN pg_catalog.pg_namespace n ON n.oid = rel.relnamespace +JOIN LATERAL unnest(con.conkey) WITH ORDINALITY AS k(attnum, ordinality) ON true +JOIN pg_catalog.pg_attribute a ON a.attrelid = con.conrelid 
AND a.attnum = k.attnum +ORDER BY n.nspname, rel.relname, con.conname, k.ordinality; +""" + + +FK_EDGES_SQL = """ +WITH params AS ( + SELECT $1::text AS schema_name, COALESCE($2::boolean, false) AS include_system +), fk AS ( + SELECT con.* + FROM pg_catalog.pg_constraint con + JOIN pg_catalog.pg_class child_rel ON child_rel.oid = con.conrelid + JOIN pg_catalog.pg_namespace child_ns ON child_ns.oid = child_rel.relnamespace + CROSS JOIN params p + WHERE + con.contype = 'f' + AND (p.schema_name IS NULL OR child_ns.nspname = p.schema_name) + AND ( + p.include_system + OR ( + child_ns.nspname NOT IN ('pg_catalog', 'information_schema') + AND child_ns.nspname NOT LIKE 'pg_toast%' + AND child_ns.nspname NOT LIKE 'pg_temp_%' + AND child_ns.nspname NOT LIKE 'pg_toast_temp_%' + ) + ) +) +SELECT + con.oid AS fk_constraint_oid, + con.conname AS fk_constraint_name, + + child_ns.nspname AS child_schema_name, + child_rel.oid AS child_relation_oid, + child_rel.relname AS child_relation_name, + + parent_ns.nspname AS parent_schema_name, + parent_rel.oid AS parent_relation_oid, + parent_rel.relname AS parent_relation_name, + + map.ordinality AS column_ordinal, + child_att.attname AS child_column_name, + parent_att.attname AS parent_column_name, + + con.confupdtype::text AS fk_on_update, + con.confdeltype::text AS fk_on_delete, + con.confmatchtype::text AS fk_match_type +FROM fk con +JOIN pg_catalog.pg_class child_rel ON child_rel.oid = con.conrelid +JOIN pg_catalog.pg_namespace child_ns ON child_ns.oid = child_rel.relnamespace +JOIN pg_catalog.pg_class parent_rel ON parent_rel.oid = con.confrelid +JOIN pg_catalog.pg_namespace parent_ns ON parent_ns.oid = parent_rel.relnamespace +JOIN LATERAL ( + SELECT ck.attnum AS child_attnum, fk.attnum AS parent_attnum, ck.ord AS ordinality + FROM unnest(con.conkey) WITH ORDINALITY AS ck(attnum, ord) + JOIN unnest(con.confkey) WITH ORDINALITY AS fk(attnum, ord) USING (ord) +) AS map ON true +JOIN pg_catalog.pg_attribute child_att + ON 
child_att.attrelid = con.conrelid AND child_att.attnum = map.child_attnum +JOIN pg_catalog.pg_attribute parent_att + ON parent_att.attrelid = con.confrelid AND parent_att.attnum = map.parent_attnum +ORDER BY child_schema_name, child_relation_name, fk_constraint_name, column_ordinal; +""" diff --git a/backend/build/lib/app/pooler.py b/backend/build/lib/app/pooler.py new file mode 100644 index 00000000..703f2ba3 --- /dev/null +++ b/backend/build/lib/app/pooler.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Literal + +from sqlalchemy.engine import URL, make_url + + +class PoolerKind(str, Enum): + """Known PostgreSQL connection poolers.""" + + PGBOUNCER = "pgbouncer" + PGCAT = "pgcat" + UNKNOWN = "unknown" + NONE = "none" + + +ReadRoutingMode = Literal["off", "auto", "on"] + + +@dataclass(frozen=True) +class PoolerDetectionResult: + """Best-effort pooler detection output.""" + + kind: PoolerKind + detected: bool + version_text: str | None + + +def classify_pooler_version_text(version_text: str) -> PoolerKind: + """Classify a pooler based on SHOW VERSION output.""" + text = version_text.strip().lower() + if "pgbouncer" in text: + return PoolerKind.PGBOUNCER + if "pgcat" in text: + return PoolerKind.PGCAT + return PoolerKind.UNKNOWN + + +def build_admin_console_dsn( + database_url: str, admin_database: str +) -> tuple[str, str | None]: + """Build a sync DSN (without password) for pooler admin consoles. + + Pooler admin consoles are typically exposed as virtual databases such as + `pgbouncer` and `pgcat`. This helper rewrites an async SQLAlchemy URL into a + sync DSN for psycopg, and sets the database name to the admin DB. + """ + + url = make_url(database_url) + + # psycopg expects a regular PostgreSQL URL; strip any SQLAlchemy driver. + drivername = url.drivername + if drivername.startswith("postgresql+"): + drivername = "postgresql" + + # Avoid embedding credentials in DSN strings. 
Some drivers/loggers may echo + # DSNs, so keep the password separate. + password = url.password + + # NOTE: URL.set(password=None) does not clear the password; it leaves it as + # is. Construct a fresh URL to ensure the password is omitted. + safe_url = URL.create( + drivername=drivername, + username=url.username, + password=None, + host=url.host, + port=url.port, + database=admin_database, + query=url.query, + ) + + # Render as a string with password redaction enabled (defense-in-depth). + dsn = str(safe_url.render_as_string(hide_password=True)) + return dsn, password + + +def should_route_reads_to_read_only( + *, + mode: ReadRoutingMode, + read_only_url: str | None, + pooler_detected: bool, +) -> bool: + """Decide whether a read-only session should use the read-only DSN.""" + + if not read_only_url: + return False + + if mode == "off": + return False + if mode == "on": + return True + # mode == "auto" + return pooler_detected diff --git a/backend/build/lib/app/rate_limit.py b/backend/build/lib/app/rate_limit.py new file mode 100644 index 00000000..b025382c --- /dev/null +++ b/backend/build/lib/app/rate_limit.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +import asyncio +import logging +import math +import time +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +SubjectGetter = Callable[[Request], Awaitable[str | None]] + + +_logger = logging.getLogger(__name__) +_UNKNOWN_IP_LOG_THROTTLE_SECONDS = 60.0 +_last_unknown_ip_log_at = 0.0 + + +@dataclass(frozen=True) +class RateLimitPolicy: + """A small, dependency-free fixed-window rate limit policy. + + Notes: + - This is intentionally in-memory (per-process). + - In multi-worker/multi-instance deployments, limits are enforced per worker + unless an external shared store (Redis/Valkey) is introduced. 
+ """ + + enabled: bool + requests: int + window_seconds: float + route_prefix: str = "/api" + trust_x_forwarded_for: bool = False + + +def _get_client_ip(request: Request, *, trust_x_forwarded_for: bool) -> str: + global _last_unknown_ip_log_at + + if trust_x_forwarded_for: + xff = request.headers.get("X-Forwarded-For") + if xff: + # Use the right-most value (nearest trusted proxy), trimming whitespace. + # This is safer than the left-most which can be spoofed by the client. + ip = xff.split(",")[-1].strip() + if ip: + return ip + + client = request.client + if client is None: + ip = "unknown" + else: + ip = client.host or "unknown" + + # Avoid silently aggregating many callers under an "unknown" key. + if ip == "unknown": + now = time.monotonic() + if now - _last_unknown_ip_log_at >= _UNKNOWN_IP_LOG_THROTTLE_SECONDS: + _last_unknown_ip_log_at = now + _logger.warning( + "rate_limit: unable to resolve client IP; falling back to 'unknown'" + " (path=%s, x_forwarded_for_present=%s)", + request.url.path, + bool(request.headers.get("X-Forwarded-For")), + ) + return ip + + +class InMemoryFixedWindowRateLimiter: + """A minimal in-memory fixed-window rate limiter.""" + + def __init__(self, *, max_keys: int = 10_000) -> None: + if max_keys <= 0: + raise ValueError("max_keys must be positive") + self._max_keys = max_keys + self._lock = asyncio.Lock() + # key -> (window_id, count) + self._buckets: dict[str, tuple[int, int]] = {} + + async def hit(self, *, key: str, policy: RateLimitPolicy) -> tuple[bool, int]: + """Record a hit and return (allowed, retry_after_seconds).""" + if policy.window_seconds <= 0: + # Treat as disabled to avoid division by zero. + return True, 0 + if policy.requests <= 0: + # Always reject when configured to 0. 
+ return False, math.ceil(policy.window_seconds) + + now = time.monotonic() + window_id = int(now // policy.window_seconds) + retry_after = int( + max(0.0, math.ceil((window_id + 1) * policy.window_seconds - now)) + ) + + async with self._lock: + # Best-effort eviction to cap memory. + # Prefer removing expired windows first, then evict oldest entries. + if key not in self._buckets and len(self._buckets) >= self._max_keys: + expired_keys = [ + k for k, (wid, _) in self._buckets.items() if wid != window_id + ] + for k in expired_keys: + self._buckets.pop(k, None) + + while key not in self._buckets and len(self._buckets) >= self._max_keys: + oldest_key = next(iter(self._buckets)) + self._buckets.pop(oldest_key, None) + + prev = self._buckets.get(key) + if prev is None or prev[0] != window_id: + count = 1 + else: + count = prev[1] + 1 + + # Move key to the end to approximate LRU behavior. + self._buckets.pop(key, None) + self._buckets[key] = (window_id, count) + allowed = count <= policy.requests + return allowed, retry_after + + +def make_rate_limit_middleware( + *, + limiter: InMemoryFixedWindowRateLimiter, + policy: RateLimitPolicy, + get_subject: SubjectGetter | None = None, +) -> Callable[[Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response]]: + """Create a FastAPI/Starlette http middleware implementing rate limiting.""" + + async def middleware( + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: + """Apply the configured fixed-window limit to matching requests.""" + if not policy.enabled: + return await call_next(request) + + path = request.url.path + if not path.startswith(policy.route_prefix): + return await call_next(request) + + subject: str | None = None + if get_subject is not None: + try: + subject = await get_subject(request) + except Exception: # noqa: BLE001 + # Never fail requests due to key derivation. 
+ subject = None + + ip = _get_client_ip(request, trust_x_forwarded_for=policy.trust_x_forwarded_for) + key = f"ip:{ip}" + if subject: + key = f"{key}|sub:{subject}" + + allowed, retry_after = await limiter.hit(key=key, policy=policy) + if allowed: + return await call_next(request) + + return JSONResponse( + {"detail": "rate limit exceeded"}, + status_code=429, + headers={"Retry-After": str(retry_after)}, + ) + + return middleware diff --git a/backend/build/lib/app/redact.py b/backend/build/lib/app/redact.py new file mode 100644 index 00000000..ce92e419 --- /dev/null +++ b/backend/build/lib/app/redact.py @@ -0,0 +1,24 @@ +import copy +from typing import Any, Dict + + +def redact_sensitive_schema_data(snapshot_json: Dict[str, Any] | None) -> Dict[str, Any] | None: + """Redacts sensitive properties like comments or example values from snapshot data.""" + if not snapshot_json: + return snapshot_json + + redacted = copy.deepcopy(snapshot_json) + + if "tables" in redacted and isinstance(redacted["tables"], list): + for table in redacted["tables"]: + if "comment" in table and table["comment"] is not None: + table["comment"] = "[REDACTED]" + + if "columns" in table and isinstance(table["columns"], list): + for column in table["columns"]: + if "column_comment" in column and column["column_comment"] is not None: + column["column_comment"] = "[REDACTED]" + if "example_value" in column and column["example_value"] is not None: + column["example_value"] = "[REDACTED]" + + return redacted diff --git a/backend/build/lib/app/sanitize.py b/backend/build/lib/app/sanitize.py new file mode 100644 index 00000000..3c009b63 --- /dev/null +++ b/backend/build/lib/app/sanitize.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import base64 +from collections.abc import Mapping + + +def strip_nul(value: str) -> str: + """Remove NUL (0x00) characters from a string.""" + + # PostgreSQL text/json rejects NUL(0x00). Remove it. 
+ return value.replace("\x00", "") + + +def sanitize_for_storage(obj: object) -> object: + """Recursively sanitize strings for DB storage. + + - Removes NUL chars from *all* strings. + - Converts bytes/memoryview to safe text (best-effort UTF-8; fallback base64). + """ + + if obj is None: + return obj + if isinstance(obj, str): + return strip_nul(obj) + if isinstance(obj, memoryview): + obj = obj.tobytes() + if isinstance(obj, (bytes, bytearray)): + try: + return strip_nul(bytes(obj).decode("utf-8")) + except Exception: # noqa: BLE001 + return base64.b64encode(bytes(obj)).decode("ascii") + if isinstance(obj, list): + return [sanitize_for_storage(v) for v in obj] + if isinstance(obj, tuple): + return tuple(sanitize_for_storage(v) for v in obj) + if isinstance(obj, Mapping): + return {strip_nul(str(k)): sanitize_for_storage(v) for k, v in obj.items()} + return obj diff --git a/backend/build/lib/app/schemas.py b/backend/build/lib/app/schemas.py new file mode 100644 index 00000000..0359799c --- /dev/null +++ b/backend/build/lib/app/schemas.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import uuid +from typing import Literal + +from pydantic import BaseModel, Field + + +class ProjectCreateIn(BaseModel): + """Request body for creating a project.""" + + project_name: str = Field(min_length=1, max_length=255) + + +class ProjectOut(BaseModel): + """Project summary returned by API.""" + + project_space_uuid: uuid.UUID + project_name: str + + +class ProjectMemberAddIn(BaseModel): + """Request body for inviting/adding a project member.""" + + member_subject: str = Field( + min_length=1, + max_length=128, + pattern=r"^[^\s\x00-\x1F\x7F]+$", + description="OIDC sub, or dev: in dev mode", + ) + # MVP: restrict to non-owner roles. Owner is assigned at project creation. 
+ project_role: Literal["viewer", "editor"] = Field(default="viewer") + + +class ProjectMemberOut(BaseModel): + """Project member representation returned by API.""" + + user_account_uuid: uuid.UUID + member_subject: str + project_role: str + + +class ConnectionCreateIn(BaseModel): + """Request body for creating a DB connection.""" + + conn_name: str = Field(min_length=1, max_length=128) + dsn: str = Field( + min_length=1, + max_length=4096, + description=("PostgreSQL or Snowflake connection string. Not logged."), + ) + + +class ConnectionOut(BaseModel): + """Connection summary returned by API.""" + + db_connection_uuid: uuid.UUID + conn_name: str + + +class SnapshotCreateIn(BaseModel): + """Request body for creating a schema snapshot.""" + + db_connection_uuid: uuid.UUID + schema_filter: str | None = Field( + default=None, + description=( + "If set, only introspect this schema (unquoted database identifier)" + ), + min_length=1, + max_length=63, + pattern=r"^[A-Za-z_][A-Za-z0-9_$]{0,62}$", + ) + + +class SnapshotOut(BaseModel): + """Snapshot summary returned by API.""" + + schema_snapshot_uuid: uuid.UUID + status: str + schema_filter: str | None + + +class SnapshotDetailOut(BaseModel): + """Snapshot detail returned by API.""" + + schema_snapshot_uuid: uuid.UUID + status: str + schema_filter: str | None + error_message: str | None + snapshot_json: dict | None + + +class MeOut(BaseModel): + """Current user payload returned by /me.""" + + user_account_uuid: uuid.UUID + subject: str + display_name: str | None diff --git a/backend/build/lib/app/security.py b/backend/build/lib/app/security.py new file mode 100644 index 00000000..e81f2f59 --- /dev/null +++ b/backend/build/lib/app/security.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import hashlib +from dataclasses import dataclass + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from app.settings import settings + + +def _derive_key() -> bytes: + """Derive a stable 32-byte key from 
APP_SECRET (MVP). + + In production, prefer KMS/HKDF with rotation. + """ + + # MVP key derivation: stable 32-bytes from APP_SECRET. + # In production prefer KMS/HKDF with rotation. + return hashlib.sha256(settings.app_secret.encode("utf-8")).digest() + + +@dataclass(frozen=True) +class EncryptedBlob: + """Encrypted bytes plus nonce for AES-GCM.""" + + ciphertext: bytes + nonce: bytes + + +def encrypt_text(plaintext: str) -> EncryptedBlob: + """Encrypt a UTF-8 string using AES-256-GCM.""" + key = _derive_key() + aes = AESGCM(key) + import os + + nonce = os.urandom(12) + ciphertext = aes.encrypt(nonce, plaintext.encode("utf-8"), None) + return EncryptedBlob(ciphertext=ciphertext, nonce=nonce) + + +def decrypt_text(ciphertext: bytes, nonce: bytes) -> str: + """Decrypt a blob produced by encrypt_text.""" + key = _derive_key() + aes = AESGCM(key) + plaintext = aes.decrypt(nonce, ciphertext, None) + return plaintext.decode("utf-8") + + +def redact_dsn(dsn: str) -> str: + """Redact credentials from a DSN for safe logging.""" + + # Avoid leaking credentials in logs. + # Best-effort: remove password in typical URI formats. + # If unsure, return a constant to avoid partial leaks. + if "@" not in dsn or "://" not in dsn: + return "***" + scheme, rest = dsn.split("://", 1) + if "@" not in rest: + return "***" + _, hostpart = rest.rsplit("@", 1) + return f"{scheme}://***@{hostpart}" diff --git a/backend/build/lib/app/security_headers.py b/backend/build/lib/app/security_headers.py new file mode 100644 index 00000000..af5d3123 --- /dev/null +++ b/backend/build/lib/app/security_headers.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable + +from starlette.requests import Request +from starlette.responses import Response + +_DOCS_PREFIXES: tuple[str, ...] 
= ( + "/docs", + "/docs/oauth2-redirect", + "/redoc", + "/openapi.json", +) + + +def _is_https(request: Request) -> bool: + """Return True when the request URL scheme is HTTPS. + + Security note: we intentionally do not trust X-Forwarded-Proto here. + """ + # Do not trust X-Forwarded-Proto here. + # If you need HSTS behind a proxy, set it at the ingress/reverse-proxy. + return request.url.scheme.lower() == "https" + + +def _should_apply_csp(request: Request) -> bool: + """Return True when CSP should be applied for this request.""" + # Be robust to non-canonical paths (e.g. //docs, /DOCS). + raw_path = request.url.path + normalized = "/" + raw_path.lstrip("/") + path = normalized.lower() + return not any(path.startswith(p.lower()) for p in _DOCS_PREFIXES) + + +def apply_security_headers(request: Request, response: Response) -> None: + """Apply baseline response hardening headers. + + Notes: + - Prefer setting headers at the ingress/reverse-proxy for production. + - We also apply them in-app as a fallback so local/dev/test runs match + production expectations. + - CSP is intentionally not applied to FastAPI docs endpoints by default to + avoid breaking Swagger UI. + """ + + def _set_if_missing(name: str, value: str) -> None: + """Set a response header only when not already present.""" + if name not in response.headers: + response.headers[name] = value + + _set_if_missing("X-Content-Type-Options", "nosniff") + _set_if_missing("X-Frame-Options", "DENY") + _set_if_missing("Referrer-Policy", "no-referrer") + _set_if_missing( + "Permissions-Policy", + "geolocation=(), microphone=(), camera=()", + ) + + if _should_apply_csp(request): + # API-first default. We only need a minimal policy for JSON responses. + # (Swagger UI is excluded by _DOCS_PREFIXES.) + _set_if_missing( + "Content-Security-Policy", + "default-src 'none'; base-uri 'none'; frame-ancestors 'none'; " + "form-action 'none'", + ) + + if _is_https(request): + # HSTS is only meaningful over HTTPS. 
+ _set_if_missing( + "Strict-Transport-Security", + "max-age=31536000; includeSubDomains", + ) + + +def make_security_headers_middleware() -> Callable[ + [Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response] +]: + """Create a Starlette/FastAPI http middleware applying response headers.""" + + async def middleware( + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: + """ASGI middleware that attaches baseline response headers.""" + response = await call_next(request) + apply_security_headers(request, response) + return response + + return middleware diff --git a/backend/build/lib/app/settings.py b/backend/build/lib/app/settings.py new file mode 100644 index 00000000..0d1a7b91 --- /dev/null +++ b/backend/build/lib/app/settings.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +from pydantic import Field +from pydantic import model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + model_config = SettingsConfigDict(env_file=".env", extra="ignore") + + database_url: str + # Optional: a read-only endpoint/replica DSN. + database_read_only_url: str | None = None + + # Optional: pooler kind hint. If set, probing is skipped. + db_pooler_kind: Literal["pgbouncer", "pgcat", "unknown", "none"] | None = None + + # Read routing mode. In auto mode, read-only DSN is used only when a pooler + # is detected (or hinted via db_pooler_kind). + db_read_routing: Literal["off", "auto", "on"] = "auto" + + # Best-effort pooler probe timeout. Keep it small to avoid blocking request + # paths. + db_pooler_probe_timeout_seconds: float = Field(0.7, ge=0.0) + # Required encryption key material. + # + # Supports the Docker/Podman *_FILE pattern (e.g. /run/secrets/app_secret) + # to avoid putting secrets directly into environment variables. 
+ app_secret: str + app_secret_file: str | None = Field( + default=None, validation_alias="APP_SECRET_FILE" + ) + + @model_validator(mode="before") + @classmethod + def _load_app_secret_from_file(cls, data: object) -> object: + if not isinstance(data, dict): + return data + + secret_file = data.get("APP_SECRET_FILE") or data.get("app_secret_file") + if not secret_file: + return data + + path = Path(str(secret_file)) + try: + resolved = path.resolve(strict=True) + except FileNotFoundError as exc: + raise ValueError(f"APP_SECRET_FILE does not exist: {path}") from exc + + # Security hardening: avoid symlink tricks and restrict secret files to + # the expected secrets mount. + # + # Docker/Podman secrets are typically mounted under /run/secrets. + allowed_base = Path("/run/secrets").resolve() + if path.is_symlink(): + raise ValueError("APP_SECRET_FILE must not be a symlink") + if not resolved.is_relative_to(allowed_base): + raise ValueError( + f"APP_SECRET_FILE must be under {allowed_base}: {resolved}" + ) + if not resolved.is_file(): + raise ValueError( + f"APP_SECRET_FILE does not exist or is not a file: {resolved}" + ) + + # Important: secret files commonly include a trailing newline. + secret = resolved.read_text(encoding="utf-8").rstrip("\r\n") + if secret == "": + raise ValueError("APP_SECRET_FILE is empty") + + # If APP_SECRET_FILE is provided, prefer it deterministically. 
+ new_data = dict(data) + new_data["app_secret"] = secret + return new_data + + cors_origins: str = "http://localhost:5173" + + # Rate limiting (FastAPI app layer, applied to /api/*) + api_rate_limit_enabled: bool = True + api_rate_limit_requests: int = Field(120, ge=1) + api_rate_limit_window_seconds: float = Field(60.0, gt=0.0) + api_rate_limit_trust_x_forwarded_for: bool = False + api_rate_limit_max_keys: int = Field(10_000, ge=1) + share_link_rate_limit_enabled: bool = True + share_link_rate_limit_requests: int = Field(30, ge=1) + share_link_rate_limit_window_seconds: float = Field(60.0, gt=0.0) + share_link_rate_limit_max_keys: int = Field(10_000, ge=1) + + # Observability (MVP) + observability_request_logging_enabled: bool = True + # Metrics exposure must be opt-in. + observability_metrics_enabled: bool = False + # Optional shared token for /metrics when OIDC isn't configured. + observability_metrics_token: str | None = None + + # Optional Valkey-backed queue signal path. The relational job_queue table + # remains the source of truth; Valkey reduces polling/lock pressure by + # carrying due job IDs for workers to claim. + job_queue_backend: Literal["database", "valkey"] = "database" + valkey_url: str | None = None + valkey_sentinel_hosts: str | None = None + valkey_sentinel_master: str | None = None + valkey_queue_key: str = "pg-erd-cloud:job-queue" + valkey_lock_ttl_seconds: int = Field(300, ge=1) + + # Optional OIDC (Casdoor). If set, JWTs are verified. + oidc_issuer: str | None = None + oidc_audience: str | None = None + + # Optional allowlist for reverse-engineering database targets. + # Comma-separated exact hostnames/IPs or wildcard domains like *.example.com. + db_introspection_allowed_hosts: str = "" + + # Optional OpenAI-compatible chat-completions provider for live reversing + # spec drafts. Leave unset to keep all reversing spec generation local. 
+ llm_api_base_url: str | None = None + llm_api_key: str | None = None + llm_model: str | None = None + llm_timeout_seconds: float = Field(30.0, gt=0.0, le=120.0) + + # Allowed JWT signing algorithms for OIDC verification. + # Comma-separated string (env: OIDC_ALGORITHMS). Default is RS256. + # NOTE: Do not trust the token header's alg; only accept algorithms from + # this allowlist. + oidc_algorithms: str = "RS256" + + +settings = Settings() # type: ignore[call-arg] diff --git a/backend/build/lib/app/snowflake_introspect/__init__.py b/backend/build/lib/app/snowflake_introspect/__init__.py new file mode 100644 index 00000000..8b3fd867 --- /dev/null +++ b/backend/build/lib/app/snowflake_introspect/__init__.py @@ -0,0 +1,3 @@ +from app.snowflake_introspect.introspect import introspect_snowflake + +__all__ = ["introspect_snowflake"] diff --git a/backend/build/lib/app/snowflake_introspect/introspect.py b/backend/build/lib/app/snowflake_introspect/introspect.py new file mode 100644 index 00000000..dd0e3767 --- /dev/null +++ b/backend/build/lib/app/snowflake_introspect/introspect.py @@ -0,0 +1,711 @@ +from __future__ import annotations + +import asyncio +import datetime as dt +import importlib +import re +from collections import defaultdict +from dataclasses import dataclass +from typing import Any +from urllib.parse import parse_qsl, unquote, urlparse + +from app.pg_introspect.column_examples import add_column_examples +from app.sanitize import sanitize_for_storage +from app.pg_introspect.dsn_guard import _validated_ip_hosts + +SCHEMAS_SQL = """ +SELECT schema_name +FROM information_schema.schemata +WHERE catalog_name = CURRENT_DATABASE() + AND schema_name <> 'INFORMATION_SCHEMA' + AND (%s IS NULL OR schema_name = %s) +ORDER BY schema_name +""" + +TABLES_SQL = """ +SELECT table_schema, table_name, table_type, comment +FROM information_schema.tables +WHERE table_schema <> 'INFORMATION_SCHEMA' + AND (%s IS NULL OR table_schema = %s) +ORDER BY table_schema, table_name +""" 
+ +COLUMNS_SQL = """ +SELECT + table_schema, + table_name, + ordinal_position, + column_name, + data_type, + character_maximum_length, + numeric_precision, + numeric_scale, + datetime_precision, + is_nullable, + column_default, + comment +FROM information_schema.columns +WHERE table_schema <> 'INFORMATION_SCHEMA' + AND (%s IS NULL OR table_schema = %s) +ORDER BY table_schema, table_name, ordinal_position +""" + +CONSTRAINT_COLUMNS_SQL = """ +SELECT + tc.constraint_schema, + tc.constraint_name, + tc.constraint_type, + tc.table_schema, + tc.table_name, + kcu.column_name, + kcu.ordinal_position, + pk_kcu.table_schema AS referenced_table_schema, + pk_kcu.table_name AS referenced_table_name, + pk_kcu.column_name AS referenced_column_name +FROM information_schema.table_constraints tc +LEFT JOIN information_schema.key_column_usage kcu + ON kcu.constraint_catalog = tc.constraint_catalog + AND kcu.constraint_schema = tc.constraint_schema + AND kcu.constraint_name = tc.constraint_name +LEFT JOIN information_schema.referential_constraints rc + ON rc.constraint_catalog = tc.constraint_catalog + AND rc.constraint_schema = tc.constraint_schema + AND rc.constraint_name = tc.constraint_name +LEFT JOIN information_schema.key_column_usage pk_kcu + ON pk_kcu.constraint_catalog = rc.unique_constraint_catalog + AND pk_kcu.constraint_schema = rc.unique_constraint_schema + AND pk_kcu.constraint_name = rc.unique_constraint_name + AND pk_kcu.ordinal_position = kcu.position_in_unique_constraint +WHERE tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE', 'FOREIGN KEY') + AND (%s IS NULL OR tc.table_schema = %s) +ORDER BY + tc.table_schema, + tc.table_name, + tc.constraint_name, + kcu.ordinal_position +""" + +VERSION_SQL = "SELECT CURRENT_VERSION() AS server_version" + +SUPPORTED_QUERY_PARAMS = {"warehouse", "role", "authenticator"} + + +@dataclass(frozen=True) +class SnowflakeDsnConfig: + """Connection settings parsed from a Snowflake DSN.""" + + account: str + user: str + password: str | None + 
database: str + schema: str | None + warehouse: str | None + role: str | None + authenticator: str | None + + def connect_kwargs(self) -> dict[str, str]: + """Return non-empty keyword arguments accepted by the Snowflake connector.""" + kwargs = { + "account": self.account, + "user": self.user, + "database": self.database, + } + optional = { + "password": self.password, + "schema": self.schema, + "warehouse": self.warehouse, + "role": self.role, + "authenticator": self.authenticator, + } + for key, value in optional.items(): + if value: + kwargs[key] = value + return kwargs + + +def _int_or_none(value: object) -> int | None: + if isinstance(value, int): + return value + if isinstance(value, str) and value.isdigit(): + return int(value) + return None + + +def _str_or_none(value: object) -> str | None: + return value if isinstance(value, str) and value else None + + +async def _parse_snowflake_dsn(dsn: str) -> SnowflakeDsnConfig: + parsed = urlparse(dsn) + if parsed.scheme.lower() != "snowflake": + raise ValueError("Snowflake DSN must use the snowflake scheme") + if not parsed.hostname: + raise ValueError("Snowflake DSN must include an account identifier") + if not parsed.username: + raise ValueError("Snowflake DSN must include a user") + + path_parts = [unquote(part) for part in parsed.path.split("/") if part] + if not path_parts: + raise ValueError("Snowflake DSN must include a database path segment") + if len(path_parts) > 2: + raise ValueError("Snowflake DSN path must be /database or /database/schema") + + query: dict[str, str] = {} + for key, value in parse_qsl(parsed.query, keep_blank_values=True): + normalized = key.lower() + if normalized not in SUPPORTED_QUERY_PARAMS: + raise ValueError(f"unsupported Snowflake DSN query parameter: {key}") + if not value: + raise ValueError(f"Snowflake DSN query parameter is blank: {key}") + + if normalized == "authenticator": + # Prevent SSRF: only allow known safe authenticator values or Okta URLs + auth_lower = 
value.lower() + safe_auths = { + "snowflake", + "snowflake_jwt", + "externalbrowser", + "oauth", + "username_password_mfa", + } + if auth_lower not in safe_auths: + if not auth_lower.startswith("https://"): + raise ValueError("unsupported Snowflake authenticator value") + parsed_auth = urlparse(auth_lower) + if not parsed_auth.hostname or not re.match( + r"^([a-zA-Z0-9-]+\.)*(okta|oktapreview)\.com$", parsed_auth.hostname + ): + raise ValueError("unsupported Snowflake authenticator URL") + + query[normalized] = value + + await _validated_ip_hosts(parsed.hostname, is_hostaddr=False, port=443) + + return SnowflakeDsnConfig( + account=parsed.hostname, + user=unquote(parsed.username), + password=unquote(parsed.password) if parsed.password else None, + database=path_parts[0], + schema=path_parts[1] if len(path_parts) == 2 else None, + warehouse=query.get("warehouse"), + role=query.get("role"), + authenticator=query.get("authenticator"), + ) + + +def _connect(**kwargs: str) -> Any: + try: + connector = importlib.import_module("snowflake.connector") + except ImportError as exc: + raise RuntimeError( + "Snowflake reverse engineering requires the optional " + "snowflake-connector-python package" + ) from exc + return connector.connect(**kwargs) + + +def _fetch_dicts(cursor: Any, sql: str, params: tuple[object, ...] 
= ()) -> list[dict]: + cursor.execute(sql, params) + rows = cursor.fetchall() + if not rows: + return [] + if isinstance(rows[0], dict): + return [{str(key).lower(): value for key, value in row.items()} for row in rows] + + columns = [str(description[0]).lower() for description in cursor.description] + return [dict(zip(columns, row)) for row in rows] + + +def _snowflake_relation_kind(table_type: object) -> str: + normalized = str(table_type or "").upper() + if "VIEW" in normalized: + return "m" if "MATERIALIZED" in normalized else "v" + return "r" + + +def _format_snowflake_data_type(row: dict) -> str: + data_type = str(row.get("data_type") or "VARCHAR").upper() + precision = _int_or_none(row.get("numeric_precision")) + scale = _int_or_none(row.get("numeric_scale")) + char_length = _int_or_none(row.get("character_maximum_length")) + datetime_precision = _int_or_none(row.get("datetime_precision")) + + if data_type in ("NUMBER", "NUMERIC", "DECIMAL") and precision is not None: + return f"{data_type}({precision},{scale or 0})" + if data_type in ("VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"): + if char_length is not None and char_length > 0: + return f"{data_type}({char_length})" + return data_type + if data_type.startswith("TIMESTAMP") and datetime_precision is not None: + return f"{data_type}({datetime_precision})" + return data_type + + +def _table_key(row: dict) -> tuple[str, str]: + return (str(row.get("table_schema") or ""), str(row.get("table_name") or "")) + + +def _q(ident: str) -> str: + return '"' + ident.replace('"', '""') + '"' + + +def _constraint_type(value: object) -> str | None: + normalized = str(value or "").upper() + if normalized == "PRIMARY KEY": + return "p" + if normalized == "UNIQUE": + return "u" + if normalized == "FOREIGN KEY": + return "f" + return None + + +def _constraint_def( + constraint_type: str, + columns: list[str], + referenced_schema: str | None, + referenced_table: str | None, + referenced_columns: list[str], +) -> str: + 
quoted_cols = ", ".join(_q(col) for col in columns) + if constraint_type == "p": + return f"PRIMARY KEY ({quoted_cols})" + if constraint_type == "u": + return f"UNIQUE ({quoted_cols})" + if referenced_schema and referenced_table and referenced_columns: + quoted_ref_cols = ", ".join(_q(col) for col in referenced_columns) + return ( + f"FOREIGN KEY ({quoted_cols}) REFERENCES " + f"{_q(referenced_schema)}.{_q(referenced_table)} ({quoted_ref_cols})" + ) + return f"FOREIGN KEY ({quoted_cols})" + + +def _build_primary_key( + name: str, + schema: str, + table: str, + relation_oid: int, + columns: list[str], + constrained_attnums: list[int], + constraint_oid: int, +) -> tuple[dict, list[dict]]: + constraint = { + "constraint_oid": constraint_oid, + "constraint_name": name, + "constraint_type": "p", + "schema_name": schema, + "relation_oid": relation_oid, + "relation_name": table, + "foreign_relation_oid": None, + "foreign_schema_name": None, + "foreign_relation_name": None, + "constrained_attnums": constrained_attnums, + "referenced_attnums": [], + "constraint_def": _constraint_def("p", columns, None, None, []), + "check_expr": None, + } + + pk_columns = [] + for ordinal, column in enumerate(columns, start=1): + pk_columns.append( + { + "constraint_oid": constraint_oid, + "constraint_name": name, + "schema_name": schema, + "relation_oid": relation_oid, + "relation_name": table, + "column_ordinal": ordinal, + "column_name": column, + } + ) + return constraint, pk_columns + + +def _build_unique_constraint( + name: str, + schema: str, + table: str, + relation_oid: int, + columns: list[str], + constrained_attnums: list[int], + constraint_oid: int, +) -> dict: + return { + "constraint_oid": constraint_oid, + "constraint_name": name, + "constraint_type": "u", + "schema_name": schema, + "relation_oid": relation_oid, + "relation_name": table, + "foreign_relation_oid": None, + "foreign_schema_name": None, + "foreign_relation_name": None, + "constrained_attnums": 
constrained_attnums, + "referenced_attnums": [], + "constraint_def": _constraint_def("u", columns, None, None, []), + "check_expr": None, + } + + +def _build_foreign_key( + name: str, + schema: str, + table: str, + relation_oid: int, + columns: list[str], + constrained_attnums: list[int], + constraint_oid: int, + referenced_schema: str | None, + referenced_table: str | None, + referenced_columns: list[str], + foreign_relation_oid: int | None, + sorted_rows: list[dict], +) -> tuple[dict, list[dict]]: + constraint = { + "constraint_oid": constraint_oid, + "constraint_name": name, + "constraint_type": "f", + "schema_name": schema, + "relation_oid": relation_oid, + "relation_name": table, + "foreign_relation_oid": foreign_relation_oid, + "foreign_schema_name": referenced_schema, + "foreign_relation_name": referenced_table, + "constrained_attnums": constrained_attnums, + "referenced_attnums": [], + "constraint_def": _constraint_def( + "f", + columns, + referenced_schema, + referenced_table, + referenced_columns, + ), + "check_expr": None, + } + + fk_edges = [] + if referenced_schema and referenced_table: + for ordinal, row in enumerate(sorted_rows, start=1): + child_column = _str_or_none(row.get("column_name")) + parent_column = _str_or_none(row.get("referenced_column_name")) + if not (child_column and parent_column): + continue + fk_edges.append( + { + "fk_constraint_oid": constraint_oid, + "fk_constraint_name": name, + "child_schema_name": schema, + "child_relation_oid": relation_oid, + "child_relation_name": table, + "parent_schema_name": referenced_schema, + "parent_relation_oid": foreign_relation_oid, + "parent_relation_name": referenced_table, + "column_ordinal": ordinal, + "child_column_name": child_column, + "parent_column_name": parent_column, + "fk_on_update": None, + "fk_on_delete": None, + "fk_match_type": None, + } + ) + + return constraint, fk_edges + + +def _group_constraint_rows( + rows: list[dict], +) -> dict[tuple[str, str, str, str], list[dict]]: + 
grouped: dict[tuple[str, str, str, str], list[dict]] = defaultdict(list) + for row in rows: + ctype = _constraint_type(row.get("constraint_type")) + schema = _str_or_none(row.get("table_schema")) + table = _str_or_none(row.get("table_name")) + name = _str_or_none(row.get("constraint_name")) + constraint_schema = _str_or_none(row.get("constraint_schema")) or schema + if not (ctype and schema and table and name and constraint_schema): + continue + grouped[(constraint_schema, name, schema, table)].append(row) + return grouped + + +def _process_constraint_group( + group_rows: list[dict], + name: str, + schema: str, + table: str, + relation_ids: dict[tuple[str, str], int], + column_positions: dict[tuple[str, str], dict[str, int]], + constraint_oid: int, +) -> tuple[dict | None, list[dict], list[dict]]: + sorted_rows = sorted( + group_rows, + key=lambda row: int(row.get("ordinal_position") or 0), + ) + ctype = _constraint_type(sorted_rows[0].get("constraint_type")) + relation_oid = relation_ids.get((schema, table)) + if relation_oid is None: + return None, [], [] + + columns = [ + str(row["column_name"]) + for row in sorted_rows + if isinstance(row.get("column_name"), str) + ] + attnums = [ + column_positions.get((schema, table), {}).get(column) for column in columns + ] + constrained_attnums = [attnum for attnum in attnums if isinstance(attnum, int)] + referenced_schema = _str_or_none(sorted_rows[0].get("referenced_table_schema")) + referenced_table = _str_or_none(sorted_rows[0].get("referenced_table_name")) + referenced_columns = [ + str(row["referenced_column_name"]) + for row in sorted_rows + if isinstance(row.get("referenced_column_name"), str) + ] + foreign_relation_oid = ( + relation_ids.get((referenced_schema, referenced_table)) + if referenced_schema and referenced_table + else None + ) + + if ctype == "p": + constraint, new_pk_columns = _build_primary_key( + name, + schema, + table, + relation_oid, + columns, + constrained_attnums, + constraint_oid, + ) + 
return constraint, new_pk_columns, [] + elif ctype == "u": + constraint = _build_unique_constraint( + name, + schema, + table, + relation_oid, + columns, + constrained_attnums, + constraint_oid, + ) + return constraint, [], [] + elif ctype == "f": + constraint, new_fk_edges = _build_foreign_key( + name, + schema, + table, + relation_oid, + columns, + constrained_attnums, + constraint_oid, + referenced_schema, + referenced_table, + referenced_columns, + foreign_relation_oid, + sorted_rows, + ) + return constraint, [], new_fk_edges + + return None, [], [] + + +def _build_constraints( + rows: list[dict], + relation_ids: dict[tuple[str, str], int], + column_positions: dict[tuple[str, str], dict[str, int]], +) -> tuple[list[dict], list[dict], list[dict]]: + grouped = _group_constraint_rows(rows) + + constraints: list[dict] = [] + pk_columns: list[dict] = [] + fk_edges: list[dict] = [] + + for (_, name, schema, table), group_rows in grouped.items(): + constraint_oid = len(constraints) + 1 + constraint, new_pk_columns, new_fk_edges = _process_constraint_group( + group_rows, + name, + schema, + table, + relation_ids, + column_positions, + constraint_oid, + ) + if constraint: + constraints.append(constraint) + pk_columns.extend(new_pk_columns) + fk_edges.extend(new_fk_edges) + + return constraints, pk_columns, fk_edges + + +def _build_schemas(schema_rows: list[dict]) -> list[dict]: + return [ + { + "schema_oid": index, + "schema_name": str(row.get("schema_name")), + } + for index, row in enumerate(schema_rows, start=1) + if isinstance(row.get("schema_name"), str) + ] + + +def _build_relations( + table_rows: list[dict], + relation_keys: list[tuple[str, str]], + relation_ids: dict[tuple[str, str], int], +) -> list[dict]: + relations = [] + table_row_by_key = {_table_key(row): row for row in table_rows} + for schema, table in relation_keys: + row = table_row_by_key.get((schema, table), {}) + relations.append( + { + "schema_name": schema, + "relation_oid": relation_ids[(schema, 
table)], + "relation_name": table, + "relation_kind": _snowflake_relation_kind(row.get("table_type")), + "relation_comment": row.get("comment"), + "is_partition": False, + "partition_key": None, + "partition_bound": None, + "partition_parent_oid": None, + "partition_parent_schema": None, + "partition_parent_name": None, + "tablespace_name": None, + } + ) + return relations + + +def _build_columns( + column_rows: list[dict], + relation_ids: dict[tuple[str, str], int], + column_positions: dict[tuple[str, str], dict[str, int]], +) -> list[dict]: + columns = [] + for row in column_rows: + schema, table = _table_key(row) + relation_oid = relation_ids.get((schema, table)) + if relation_oid is None: + continue + column_name = _str_or_none(row.get("column_name")) + position = _int_or_none(row.get("ordinal_position")) + if not (column_name and position is not None): + continue + column_positions[(schema, table)][column_name] = position + data_type = _format_snowflake_data_type(row) + columns.append( + { + "schema_name": schema, + "relation_oid": relation_oid, + "relation_name": table, + "relation_kind": "r", + "column_position": position, + "column_name": column_name, + "data_type": data_type, + "type_oid": None, + "type_schema": "INFORMATION_SCHEMA", + "type_name": str(row.get("data_type") or ""), + "type_kind": None, + "type_category": None, + "domain_base_type": None, + "domain_base_schema": None, + "domain_base_name": None, + "array_element_type": None, + "array_element_schema": None, + "array_element_name": None, + "array_dimensions": 0, + "is_not_null": str(row.get("is_nullable") or "").upper() == "NO", + "has_default": row.get("column_default") is not None, + "default_expr": row.get("column_default"), + "column_comment": row.get("comment"), + } + ) + return columns + + +def _build_snapshot( + config: SnowflakeDsnConfig, + effective_schema: str | None, + version_rows: list[dict], + schema_rows: list[dict], + table_rows: list[dict], + column_rows: list[dict], + 
constraint_rows: list[dict], +) -> dict: + relation_keys = sorted({_table_key(row) for row in table_rows}) + relation_ids = {key: index for index, key in enumerate(relation_keys, start=1)} + column_positions: dict[tuple[str, str], dict[str, int]] = defaultdict(dict) + + schemas = _build_schemas(schema_rows) + relations = _build_relations(table_rows, relation_keys, relation_ids) + columns = _build_columns(column_rows, relation_ids, column_positions) + + constraints, pk_columns, fk_edges = _build_constraints( + constraint_rows, relation_ids, column_positions + ) + server_version = ( + str(version_rows[0].get("server_version")) + if version_rows and version_rows[0].get("server_version") is not None + else "snowflake" + ) + snapshot = { + "source_dialect": "snowflake", + "database_dialect": "snowflake", + "captured_at": dt.datetime.now(dt.timezone.utc).isoformat(), + "server_version": server_version, + "database_name": config.database, + "schema_filter": effective_schema, + "schemas": schemas, + "relations": relations, + "columns": add_column_examples(columns), + "constraints": constraints, + "indexes": [], + "pk_columns": pk_columns, + "fk_edges": fk_edges, + } + return sanitize_for_storage(snapshot) # type: ignore[return-value] + + +def _introspect_snowflake_sync_with_config( + config: SnowflakeDsnConfig, + schema_filter: str | None, +) -> dict: + effective_schema = schema_filter or config.schema + query_params = (effective_schema, effective_schema) + + conn = _connect(**config.connect_kwargs()) + cursor = conn.cursor() + try: + version_rows = _fetch_dicts(cursor, VERSION_SQL) + schema_rows = _fetch_dicts(cursor, SCHEMAS_SQL, query_params) + table_rows = _fetch_dicts(cursor, TABLES_SQL, query_params) + column_rows = _fetch_dicts(cursor, COLUMNS_SQL, query_params) + constraint_rows = _fetch_dicts(cursor, CONSTRAINT_COLUMNS_SQL, query_params) + finally: + try: + cursor.close() + finally: + conn.close() + + return _build_snapshot( + config, + effective_schema, + 
version_rows, + schema_rows, + table_rows, + column_rows, + constraint_rows, + ) + + +async def introspect_snowflake(dsn: str, schema_filter: str | None) -> dict: + """Introspect Snowflake metadata into the common snapshot JSON shape.""" + + config = await _parse_snowflake_dsn(dsn) + return await asyncio.to_thread( + _introspect_snowflake_sync_with_config, config, schema_filter + ) diff --git a/backend/build/lib/app/spec/__init__.py b/backend/build/lib/app/spec/__init__.py new file mode 100644 index 00000000..83871d30 --- /dev/null +++ b/backend/build/lib/app/spec/__init__.py @@ -0,0 +1 @@ +"""Specification generation helpers.""" diff --git a/backend/build/lib/app/spec/index_design.py b/backend/build/lib/app/spec/index_design.py new file mode 100644 index 00000000..a6aac55a --- /dev/null +++ b/backend/build/lib/app/spec/index_design.py @@ -0,0 +1,394 @@ +from __future__ import annotations + +import json +import re +from collections import defaultdict +from typing import Literal + +from app.jobs.valkey_queue import valkey_queue_config_summary + + +SpecMode = Literal["markdown", "llm-prompt"] +MAX_IDENTIFIER_LENGTH = 63 + + +def _text(value: object, default: str = "") -> str: + return value if isinstance(value, str) else default + + +def _rows(snapshot: dict, key: str) -> list[dict]: + value = snapshot.get(key) + if not isinstance(value, list): + return [] + return [row for row in value if isinstance(row, dict)] + + +def _q(identifier: str) -> str: + return '"' + identifier.replace('"', '""') + '"' + + +def _qname(schema: str, name: str) -> str: + return f"{_q(schema)}.{_q(name)}" + + +def _identifier_part(value: str) -> str: + chars = [ch.lower() if ch.isalnum() else "_" for ch in value] + text = "".join(chars).strip("_") + while "__" in text: + text = text.replace("__", "_") + return text or "col" + + +def _index_name(table_name: str, columns: list[str]) -> str: + raw = f"idx_{_identifier_part(table_name)}_{'_'.join(_identifier_part(c) for c in columns)}" + return 
raw[:MAX_IDENTIFIER_LENGTH] + + +def _escape_cell(value: object) -> str: + text = "-" if value is None else str(value) + return text.replace("|", "\\|").replace("\n", " ") + + +def _relation_label(row: dict) -> str: + schema = _text(row.get("schema_name"), "unknown") + name = _text(row.get("relation_name"), "unknown") + return f"{schema}.{name}" + + +def _relation_maps(snapshot: dict) -> tuple[dict[int, dict], dict[str, int]]: + by_oid: dict[int, dict] = {} + oid_by_label: dict[str, int] = {} + for relation in _rows(snapshot, "relations"): + oid = relation.get("relation_oid") + if not isinstance(oid, int): + continue + by_oid[oid] = relation + oid_by_label[_relation_label(relation)] = oid + return by_oid, oid_by_label + + +def _columns_by_relation(snapshot: dict) -> dict[int, list[dict]]: + grouped: dict[int, list[dict]] = defaultdict(list) + for column in _rows(snapshot, "columns"): + oid = column.get("relation_oid") + if isinstance(oid, int): + grouped[oid].append(column) + return grouped + + +def _indexes_by_relation(snapshot: dict) -> dict[int, list[dict]]: + grouped: dict[int, list[dict]] = defaultdict(list) + for index in _rows(snapshot, "indexes"): + oid = index.get("relation_oid") or index.get("table_oid") + if isinstance(oid, int): + grouped[oid].append(index) + return grouped + + +def _citus_by_relation(snapshot: dict) -> dict[int, dict]: + result: dict[int, dict] = {} + for row in _rows(snapshot, "citus_distributed_tables"): + oid = row.get("relation_oid") + if isinstance(oid, int): + result[oid] = row + return result + + +def _fk_groups(snapshot: dict) -> list[dict]: + groups: dict[tuple[str, int], list[dict]] = defaultdict(list) + for edge in _rows(snapshot, "fk_edges"): + constraint = _text(edge.get("fk_constraint_name"), "fk") + child_oid = edge.get("child_relation_oid") + if isinstance(child_oid, int): + groups[(constraint, child_oid)].append(edge) + + result = [] + for (constraint, child_oid), edges in groups.items(): + ordered = sorted(edges, 
key=lambda x: int(x.get("column_ordinal") or 0)) + columns = [ + _text(edge.get("child_column_name")) + for edge in ordered + if _text(edge.get("child_column_name")) + ] + if columns: + result.append( + { + "constraint": constraint, + "child_relation_oid": child_oid, + "columns": columns, + } + ) + return result + + +def _existing_index_column_text(index: dict) -> str: + index_def = index.get("index_def") + if not isinstance(index_def, str): + return "" + marker = " USING " + if marker not in index_def: + return index_def.lower() + return index_def.split(marker, 1)[1].lower() + + +def _index_mentions_column(index_text: str, column: str) -> bool: + quoted = _q(column).lower() + if quoted in index_text: + return True + return re.search(rf"\b{re.escape(column.lower())}\b", index_text) is not None + + +def _has_covering_index( + indexes: list[dict], + columns: list[str], +) -> bool: + for index in indexes: + if index.get("is_primary") is True: + continue + column_text = _existing_index_column_text(index) + if columns and all( + _index_mentions_column(column_text, column) for column in columns + ): + return True + return False + + +def _candidate_indexes(snapshot: dict) -> list[dict]: + relations_by_oid, _ = _relation_maps(snapshot) + indexes_by_oid = _indexes_by_relation(snapshot) + candidates = [] + for group in _fk_groups(snapshot): + child_oid = group["child_relation_oid"] + relation = relations_by_oid.get(child_oid) + if relation is None: + continue + columns = group["columns"] + if _has_covering_index(indexes_by_oid.get(child_oid, []), columns): + continue + schema = _text(relation.get("schema_name"), "public") + table = _text(relation.get("relation_name"), "unknown") + index_name = _index_name(table, columns) + columns_sql = ", ".join(_q(column) for column in columns) + candidates.append( + { + "index_name": index_name, + "schema_name": schema, + "table_name": table, + "columns": columns, + "reason": (f"foreign-key lookup support for {group['constraint']}"), + 
"ddl": ( + f"CREATE INDEX CONCURRENTLY {_q(index_name)} " + f"ON {_qname(schema, table)} USING btree ({columns_sql});" + ), + } + ) + return candidates + + +def _workload_observations(snapshot: dict) -> list[object]: + observations = [] + for key in ( + "explain_analyze", + "explain_results", + "query_plans", + "workload_observations", + ): + value = snapshot.get(key) + if isinstance(value, list): + observations.extend(value) + elif isinstance(value, (str, dict)): + observations.append(value) + return observations + + +def _compact_index_design_summary(snapshot: dict) -> dict: + relations_by_oid, _ = _relation_maps(snapshot) + columns_by_oid = _columns_by_relation(snapshot) + indexes_by_oid = _indexes_by_relation(snapshot) + citus_by_oid = _citus_by_relation(snapshot) + + tables = [] + for oid, relation in sorted( + relations_by_oid.items(), key=lambda item: _relation_label(item[1]) + ): + if relation.get("relation_kind") not in ("r", "p"): + continue + tables.append( + { + "name": _relation_label(relation), + "columns": [ + { + "name": column.get("column_name"), + "type": column.get("data_type"), + "not_null": column.get("is_not_null") is True, + } + for column in columns_by_oid.get(oid, []) + ], + "indexes": [ + { + "name": index.get("index_name"), + "method": index.get("access_method"), + "unique": index.get("is_unique") is True, + "primary": index.get("is_primary") is True, + "predicate": index.get("predicate_expr"), + } + for index in indexes_by_oid.get(oid, []) + ], + "citus": citus_by_oid.get(oid), + } + ) + + return { + "source_dialect": snapshot.get("source_dialect") + or snapshot.get("database_dialect") + or "postgresql", + "server_version": snapshot.get("server_version"), + "captured_at": snapshot.get("captured_at"), + "tables": tables, + "candidate_indexes": _candidate_indexes(snapshot), + "workload_observations": _workload_observations(snapshot), + "valkey_queue": valkey_queue_config_summary(), + } + + +def generate_index_design_llm_prompt(snapshot: 
dict) -> str: + """Build a prompt for LLM-assisted table and index design.""" + + summary = _compact_index_design_summary(snapshot) + payload = json.dumps(summary, ensure_ascii=False, indent=2, sort_keys=True) + return "\n".join( + [ + "# ERD Index Design Prompt", + "", + "You are a senior PostgreSQL data architect. Review the supplied", + "ERD snapshot, current indexes, optional EXPLAIN ANALYZE evidence,", + "Citus distribution metadata, and queue configuration. Propose table", + "and index changes only when supported by the snapshot or workload", + "evidence. Use CREATE INDEX CONCURRENTLY for PostgreSQL index DDL", + "and call out when Valkey queueing should be used for asynchronous", + "index work. Do not invent application behavior.", + "", + "Return Markdown with these sections:", + "- Workload assumptions", + "- Table design adjustments", + "- Index recommendations", + "- Citus placement considerations", + "- Valkey queue execution plan", + "- SQL migration draft", + "", + "Snapshot summary:", + "```json", + payload, + "```", + "", + ] + ) + + +def generate_index_design_markdown(snapshot: dict) -> str: + """Generate deterministic index-design guidance from snapshot metadata.""" + + summary = _compact_index_design_summary(snapshot) + lines: list[str] = [ + "# ERD Index Design", + "", + "## Snapshot", + f"- Source dialect: {_text(summary.get('source_dialect'), 'postgresql')}", + f"- Server version: {_text(summary.get('server_version'), '-') or '-'}", + f"- Captured at: {_text(summary.get('captured_at'), '-') or '-'}", + "", + "## Valkey Queue", + ] + + valkey = summary["valkey_queue"] + if isinstance(valkey, dict): + lines.extend( + [ + f"- Enabled: {'yes' if valkey.get('enabled') is True else 'no'}", + f"- Mode: {_escape_cell(valkey.get('mode'))}", + f"- Queue key: {_escape_cell(valkey.get('queue_key'))}", + f"- Sentinel master: {_escape_cell(valkey.get('sentinel_master'))}", + f"- Sentinel count: {_escape_cell(valkey.get('sentinel_count'))}", + ] + ) + + 
candidates = [ + item for item in summary.get("candidate_indexes", []) if isinstance(item, dict) + ] + lines.extend(["", "## Index Recommendations"]) + if not candidates: + lines.append("_No missing foreign-key helper indexes were detected._") + for candidate in candidates: + lines.extend( + [ + "", + f"### {_escape_cell(candidate.get('index_name'))}", + f"- Table: {_escape_cell(candidate.get('schema_name'))}.{_escape_cell(candidate.get('table_name'))}", + f"- Columns: {_escape_cell(', '.join(candidate.get('columns', [])))}", + f"- Reason: {_escape_cell(candidate.get('reason'))}", + "", + "```sql", + _text(candidate.get("ddl")), + "```", + ] + ) + + lines.extend(["", "## Citus Placement"]) + citus_rows = _rows(snapshot, "citus_distributed_tables") + if not citus_rows: + lines.append("_No Citus distributed table metadata was captured._") + else: + lines.extend( + [ + "| Table | Method | Key | Colocation | Shards | Replicas |", + "| --- | --- | --- | --- | --- | --- |", + ] + ) + for row in citus_rows: + lines.append( + "| " + + " | ".join( + [ + _escape_cell(_relation_label(row)), + _escape_cell(row.get("distribution_method")), + _escape_cell(row.get("distribution_key")), + _escape_cell(row.get("colocation_id")), + _escape_cell(row.get("shard_count")), + _escape_cell(row.get("replication_factor")), + ] + ) + + " |" + ) + + observations = summary.get("workload_observations", []) + lines.extend(["", "## EXPLAIN ANALYZE Evidence"]) + if not observations: + lines.append( + "_No EXPLAIN ANALYZE or workload observations were embedded in this snapshot._" + ) + else: + lines.append("```json") + lines.append(json.dumps(observations, ensure_ascii=False, indent=2)) + lines.append("```") + + lines.extend( + [ + "", + "## LLM Review Prompt", + "Use `/index-design.md?mode=llm-prompt` to generate a compact prompt", + "for an approved LLM provider. 
When a provider is configured,", + "`/index-design.md?mode=llm-draft` asks the provider to generate", + "a Markdown table/index design directly.", + "", + ] + ) + return "\n".join(lines) + + +def generate_index_design_spec(snapshot: dict, mode: str = "markdown") -> str: + """Generate an index-design document or LLM prompt from a snapshot.""" + if mode == "llm-prompt": + return generate_index_design_llm_prompt(snapshot) + if mode == "markdown": + return generate_index_design_markdown(snapshot) + raise ValueError(f"unsupported index design spec mode: {mode}") diff --git a/backend/build/lib/app/spec/llm.py b/backend/build/lib/app/spec/llm.py new file mode 100644 index 00000000..1d899de6 --- /dev/null +++ b/backend/build/lib/app/spec/llm.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import httpx + +from app.settings import settings +from app.spec.index_design import generate_index_design_llm_prompt +from app.spec.reversing import generate_reversing_llm_prompt + + +class LlmConfigurationError(RuntimeError): + """Raised when live LLM drafting is requested without provider settings.""" + + +class LlmProviderError(RuntimeError): + """Raised when the configured LLM provider does not return a usable draft.""" + + +def _required_setting(value: str | None, name: str) -> str: + if value is None or not value.strip(): + raise LlmConfigurationError(f"{name} is required for live LLM drafts") + return value.strip() + + +def _chat_completions_url(base_url: str) -> str: + return f"{base_url.rstrip('/')}/chat/completions" + + +def _extract_chat_content(payload: object) -> str: + if not isinstance(payload, dict): + raise LlmProviderError("LLM provider returned a non-object response") + choices = payload.get("choices") + if not isinstance(choices, list) or not choices: + raise LlmProviderError("LLM provider response did not include choices") + first_choice = choices[0] + if not isinstance(first_choice, dict): + raise LlmProviderError("LLM provider response choice is 
invalid") + message = first_choice.get("message") + if not isinstance(message, dict): + raise LlmProviderError("LLM provider response did not include a message") + content = message.get("content") + if not isinstance(content, str) or not content.strip(): + raise LlmProviderError("LLM provider returned an empty draft") + return content.strip() + + +async def generate_reversing_llm_draft( + snapshot: dict, client: httpx.AsyncClient | None = None +) -> str: + """Generate a DB reversing spec with a configured chat-completions provider.""" + + return await _generate_llm_draft( + prompt=generate_reversing_llm_prompt(snapshot), + system_content=( + "You write concise database reverse-engineering specifications " + "from schema metadata. Do not invent facts." + ), + client=client, + ) + + +async def generate_index_design_llm_draft( + snapshot: dict, client: httpx.AsyncClient | None = None +) -> str: + """Generate table/index design guidance with a configured LLM provider.""" + + return await _generate_llm_draft( + prompt=generate_index_design_llm_prompt(snapshot), + system_content=( + "You write concise PostgreSQL table and index design guidance " + "from schema metadata and workload evidence. Do not invent facts." 
+ ), + client=client, + ) + + +async def _generate_llm_draft( + *, + prompt: str, + system_content: str, + client: httpx.AsyncClient | None, +) -> str: + base_url = _required_setting(settings.llm_api_base_url, "LLM_API_BASE_URL") + api_key = _required_setting(settings.llm_api_key, "LLM_API_KEY") + model = _required_setting(settings.llm_model, "LLM_MODEL") + request_json = { + "model": model, + "messages": [ + {"role": "system", "content": system_content}, + {"role": "user", "content": prompt}, + ], + "temperature": 0.2, + } + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + owns_client = client is None + if client is None: + client = httpx.AsyncClient(timeout=settings.llm_timeout_seconds) + try: + try: + response = await client.post( + _chat_completions_url(base_url), + headers=headers, + json=request_json, + ) + except httpx.HTTPError as exc: + raise LlmProviderError("LLM provider request failed") from exc + if response.status_code >= 400: + raise LlmProviderError( + f"LLM provider request failed with HTTP {response.status_code}" + ) + try: + payload = response.json() + except ValueError as exc: + raise LlmProviderError("LLM provider returned invalid JSON") from exc + return _extract_chat_content(payload) + finally: + if owns_client: + await client.aclose() diff --git a/backend/build/lib/app/spec/reversing.py b/backend/build/lib/app/spec/reversing.py new file mode 100644 index 00000000..2146c618 --- /dev/null +++ b/backend/build/lib/app/spec/reversing.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import json +from collections import defaultdict +from typing import Literal + + +SpecMode = Literal["markdown", "llm-prompt"] + + +def _text(value: object, default: str = "") -> str: + return value if isinstance(value, str) else default + + +def _bool_text(value: object) -> str: + return "yes" if value is True else "no" + + +def _escape_cell(value: object) -> str: + text = _text(value, "-") or "-" + return 
text.replace("|", "\\|").replace("\n", " ") + + +def _relation_label(row: dict) -> str: + schema = _text(row.get("schema_name"), "unknown") + name = _text(row.get("relation_name"), "unknown") + return f"{schema}.{name}" + + +def _relation_kind_name(kind: object) -> str: + if not isinstance(kind, str): + return "relation" + kind_names = { + "r": "table", + "p": "partitioned table", + "v": "view", + "m": "materialized view", + } + return kind_names.get(kind, "relation") + + +def _rows(snapshot: dict, key: str) -> list[dict]: + value = snapshot.get(key) + if not isinstance(value, list): + return [] + return [row for row in value if isinstance(row, dict)] + + +def _compact_snapshot_summary(snapshot: dict) -> dict: + relations = _rows(snapshot, "relations") + columns = _rows(snapshot, "columns") + constraints = _rows(snapshot, "constraints") + indexes = _rows(snapshot, "indexes") + fk_edges = _rows(snapshot, "fk_edges") + + columns_by_oid: dict[int, list[dict]] = defaultdict(list) + for column in columns: + oid = column.get("relation_oid") + if isinstance(oid, int): + columns_by_oid[oid].append( + { + "name": column.get("column_name"), + "type": column.get("data_type"), + "not_null": column.get("is_not_null") is True, + "default": column.get("default_expr"), + "example": column.get("example_value"), + "comment": column.get("column_comment"), + } + ) + + constraints_by_oid: dict[int, list[dict]] = defaultdict(list) + for constraint in constraints: + oid = constraint.get("relation_oid") + if isinstance(oid, int): + constraints_by_oid[oid].append( + { + "name": constraint.get("constraint_name"), + "type": constraint.get("constraint_type"), + "definition": constraint.get("constraint_def"), + } + ) + + indexes_by_oid: dict[int, list[dict]] = defaultdict(list) + for index in indexes: + oid = index.get("relation_oid") + if isinstance(oid, int): + indexes_by_oid[oid].append( + { + "name": index.get("index_name"), + "unique": index.get("is_unique") is True, + "primary": 
index.get("is_primary") is True, + "method": index.get("access_method"), + "predicate": index.get("predicate_expr"), + } + ) + + objects: list[dict] = [] + for relation in relations: + oid = relation.get("relation_oid") + if not isinstance(oid, int): + continue + objects.append( + { + "name": _relation_label(relation), + "kind": _relation_kind_name(relation.get("relation_kind")), + "comment": relation.get("relation_comment"), + "columns": columns_by_oid.get(oid, []), + "constraints": constraints_by_oid.get(oid, []), + "indexes": indexes_by_oid.get(oid, []), + } + ) + + relationships = [ + { + "constraint": edge.get("fk_constraint_name"), + "from": ( + f"{edge.get('child_schema_name')}.{edge.get('child_relation_name')}" + f".{edge.get('child_column_name')}" + ), + "to": ( + f"{edge.get('parent_schema_name')}.{edge.get('parent_relation_name')}" + f".{edge.get('parent_column_name')}" + ), + } + for edge in fk_edges + ] + + return { + "source_dialect": snapshot.get("source_dialect") + or snapshot.get("database_dialect") + or "postgresql", + "server_version": snapshot.get("server_version"), + "captured_at": snapshot.get("captured_at"), + "schema_filter": snapshot.get("schema_filter"), + "objects": objects, + "relationships": relationships, + } + + +def generate_reversing_llm_prompt(snapshot: dict) -> str: + """Build a provider-neutral prompt for LLM-assisted DB reversing specs.""" + + summary = _compact_snapshot_summary(snapshot) + payload = json.dumps(summary, ensure_ascii=False, indent=2, sort_keys=True) + return "\n".join( + [ + "# DB Reversing Specification Prompt", + "", + "You are a senior data architect. Generate a concise DB reversing", + "specification from the supplied schema snapshot. Focus on business", + "entities, relationships, key constraints, indexing intent, and", + "open questions. Do not invent facts that are not supported by the", + "snapshot. 
Mark uncertain business meanings as assumptions.", + "", + "Return Markdown with these sections:", + "- Overview", + "- Entity catalog", + "- Relationship model", + "- Constraint and index interpretation", + "- Data quality and modeling questions", + "", + "Snapshot summary:", + "```json", + payload, + "```", + "", + ] + ) + + +def _generate_markdown_header(summary: dict) -> list[str]: + objects = summary.get("objects", []) + relationships = summary.get("relationships", []) + return [ + "# DB Reversing Specification", + "", + "## Snapshot", + f"- Source dialect: {_text(summary.get('source_dialect'), 'postgresql')}", + f"- Server version: {_text(summary.get('server_version'), '-') or '-'}", + f"- Captured at: {_text(summary.get('captured_at'), '-') or '-'}", + f"- Schema filter: {_text(summary.get('schema_filter'), '-') or '-'}", + "", + "## Object Inventory", + f"- Relations: {len(objects)}", + f"- Relationships: {len(relationships)}", + "", + "## Entity Catalog", + ] + + +def _generate_markdown_entity(obj: dict) -> list[str]: + lines = [ + "", + f"### {_text(obj.get('name'), 'unknown')}", + f"- Kind: {_text(obj.get('kind'), 'relation')}", + ] + comment = obj.get("comment") + if isinstance(comment, str) and comment: + lines.append(f"- Comment: {comment}") + + lines.extend( + [ + "", + "| Column | Type | Required | Default | Example | Comment |", + "| --- | --- | --- | --- | --- | --- |", + ] + ) + for column in obj.get("columns", []): + if not isinstance(column, dict): + continue + lines.append( + "| " + + " | ".join( + [ + _escape_cell(column.get("name")), + _escape_cell(column.get("type")), + _bool_text(column.get("not_null")), + _escape_cell(column.get("default")), + _escape_cell(column.get("example")), + _escape_cell(column.get("comment")), + ] + ) + + " |" + ) + + constraints = [ + item for item in obj.get("constraints", []) if isinstance(item, dict) + ] + if constraints: + lines.extend(["", "Constraints:"]) + for constraint in constraints: + lines.append( + 
"- " + f"{_escape_cell(constraint.get('name'))} " + f"({_escape_cell(constraint.get('type'))}): " + f"{_escape_cell(constraint.get('definition'))}" + ) + + indexes = [item for item in obj.get("indexes", []) if isinstance(item, dict)] + if indexes: + lines.extend(["", "Indexes:"]) + for index in indexes: + intent = [] + if index.get("primary") is True: + intent.append("primary") + if index.get("unique") is True: + intent.append("unique") + if index.get("method"): + intent.append(f"method={index.get('method')}") + if index.get("predicate"): + intent.append("partial") + suffix = f" [{', '.join(intent)}]" if intent else "" + lines.append(f"- {_escape_cell(index.get('name'))}{suffix}") + + return lines + + +def _generate_markdown_entity_catalog(objects: list[dict]) -> list[str]: + if not objects: + return ["", "_No relations were captured in this snapshot._"] + lines = [] + for obj in objects: + lines.extend(_generate_markdown_entity(obj)) + return lines + + +def _generate_markdown_relationship_model(relationships: list[dict]) -> list[str]: + lines = ["", "## Relationship Model"] + if not relationships: + lines.append("_No foreign-key relationships were captured._") + for relationship in relationships: + lines.append( + "- " + f"{_escape_cell(relationship.get('from'))} -> " + f"{_escape_cell(relationship.get('to'))} " + f"({_escape_cell(relationship.get('constraint'))})" + ) + return lines + + +def _generate_markdown_footer() -> list[str]: + return [ + "", + "## LLM Review Prompt", + "Use `/reversing-spec.md?mode=llm-prompt` to generate a compact prompt", + "for an approved LLM provider. 
The prompt includes only schema metadata", + "from this snapshot and asks the model to mark unsupported business", + "meaning as assumptions.", + "When a provider is configured, `/reversing-spec.md?mode=llm-draft`", + "asks the provider to generate a Markdown draft directly.", + "", + ] + + +def generate_reversing_markdown(snapshot: dict) -> str: + """Generate a deterministic DB reversing specification draft.""" + + summary = _compact_snapshot_summary(snapshot) + lines: list[str] = [] + lines.extend(_generate_markdown_header(summary)) + lines.extend(_generate_markdown_entity_catalog(summary["objects"])) + lines.extend(_generate_markdown_relationship_model(summary["relationships"])) + lines.extend(_generate_markdown_footer()) + return "\n".join(lines) + + +def generate_reversing_spec(snapshot: dict, mode: str = "markdown") -> str: + """Generate a reverse-engineering document or LLM prompt from a snapshot.""" + if mode == "llm-prompt": + return generate_reversing_llm_prompt(snapshot) + if mode == "markdown": + return generate_reversing_markdown(snapshot) + raise ValueError(f"unsupported reversing spec mode: {mode}") diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index e99700c5..2db044e0 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -182,17 +182,13 @@ export default function App() { if (!normalizedNodeSearch) return new Set(); const matches = new Set(); for (const node of nodes) { - const haystack = [ - node.data.title, - node.data.comment ?? "", - ...node.data.columns.flatMap((column) => [ - column.column_name, - column.data_type, - column.column_comment ?? "", - ]), - ] - .join(" ") - .toLocaleLowerCase(); + // ⚡ Bolt: Optimize node search by replacing array allocation, flatMap, and join with direct string concatenation. + // This prevents severe garbage collection pressure on high-frequency search renders. + let haystack = node.data.title + " " + (node.data.comment ?? 
""); + for (const column of node.data.columns) { + haystack += " " + column.column_name + " " + column.data_type + " " + (column.column_comment ?? ""); + } + haystack = haystack.toLocaleLowerCase(); if (haystack.includes(normalizedNodeSearch)) { matches.add(node.id); }