diff --git a/aragora/server/fastapi/routes/security.py b/aragora/server/fastapi/routes/security.py index 179b3c8f4c..a33ba25aaa 100644 --- a/aragora/server/fastapi/routes/security.py +++ b/aragora/server/fastapi/routes/security.py @@ -24,12 +24,68 @@ router = APIRouter(prefix="/api/v2", tags=["Security"]) +_HTTP_METHODS = frozenset({"get", "put", "post", "delete", "options", "head", "patch", "trace"}) +_UNPROTECTED_API_PATHS = frozenset( + {"/api/v2/health", "/api/v2/health/ready", "/api/v2/health/live"} +) + def _reject_unexpected_query_params(request: Request) -> None: if request.query_params: raise HTTPException(status_code=400, detail="Invalid query") +def _openapi_operation_paths(app: Any) -> list[str]: + """Return one path entry per OpenAPI operation exposed by a FastAPI app.""" + openapi = getattr(app, "openapi", None) + if not callable(openapi): + return [] + + spec = openapi() + if not isinstance(spec, dict): + return [] + + paths = spec.get("paths", {}) + if not isinstance(paths, dict): + return [] + + operation_paths: list[str] = [] + for path, path_item in paths.items(): + if not isinstance(path, str) or not isinstance(path_item, dict): + continue + for method, operation in path_item.items(): + if ( + isinstance(method, str) + and method.lower() in _HTTP_METHODS + and isinstance(operation, dict) + ): + operation_paths.append(path) + return operation_paths + + +def _legacy_flat_route_paths(app: Any) -> list[str]: + """Best-effort fallback for older Starlette-like test doubles.""" + route_paths: list[str] = [] + for route in getattr(app, "routes", ()) or (): + path = getattr(route, "path", None) + if isinstance(path, str) and hasattr(route, "methods"): + route_paths.append(path) + return route_paths + + +def _rbac_coverage_route_paths(app: Any) -> list[str]: + operation_paths = _openapi_operation_paths(app) + if operation_paths: + return operation_paths + return _legacy_flat_route_paths(app) + + +def _is_unprotected_endpoint_path(path: str) -> bool: + if not path.startswith("/api/"): + return True + return path in _UNPROTECTED_API_PATHS + + # ============================================================================= # RBAC Coverage # ============================================================================= @@ -78,32 +134,19 @@ async def get_rbac_coverage( logger.debug("Live RBAC assignments unavailable: %s", exc) # ----- Endpoint coverage ----- - # Count total registered routes on the FastAPI app. - # Route counts are more stable than method counts across FastAPI/Starlette - # versions, and the dashboard field is endpoint-oriented. - total_endpoints = 0 + # FastAPI 0.137 preserves included routers as a tree, so app.routes no + # longer reliably exposes every included API route. Count public OpenAPI + # operations first, with a guarded route-list fallback for older test doubles. + endpoint_paths: list[str] = [] try: - for route in request.app.routes: - if hasattr(route, "methods"): - total_endpoints += 1 + endpoint_paths = _rbac_coverage_route_paths(request.app) except (RuntimeError, TypeError, AttributeError): - pass + endpoint_paths = [] # Heuristic: endpoints behind RBAC middleware are "protected". - # The RBAC middleware protects all /api/v2/* routes except health, - # so the unprotected set is small (health + docs + openapi.json). - unprotected = 0 - try: - for route in request.app.routes: - path = getattr(route, "path", "") - if path and not path.startswith("/api/"): - if hasattr(route, "methods"): - unprotected += 1 - elif path in ("/api/v2/health", "/api/v2/health/ready", "/api/v2/health/live"): - if hasattr(route, "methods"): - unprotected += 1 - except (RuntimeError, TypeError, AttributeError): - pass + # The RBAC middleware protects all /api/v2/* routes except health. + total_endpoints = len(endpoint_paths) + unprotected = sum(1 for path in endpoint_paths if _is_unprotected_endpoint_path(path)) if total_endpoints == 0: total_endpoints = 1 # prevent division by zero diff --git a/tests/server/fastapi/test_security_routes.py b/tests/server/fastapi/test_security_routes.py index d934d074f2..fd621c9711 100644 --- a/tests/server/fastapi/test_security_routes.py +++ b/tests/server/fastapi/test_security_routes.py @@ -2,6 +2,13 @@ from unittest.mock import MagicMock +from fastapi import APIRouter, FastAPI +from fastapi.testclient import TestClient + +from aragora.rbac.models import AuthorizationContext +from aragora.server.fastapi.dependencies.auth import require_authenticated +from aragora.server.fastapi.routes import security + def test_security_routes_require_auth(fastapi_client): response = fastapi_client.get("/api/v2/security/rbac-coverage") @@ -50,6 +57,39 @@ def test_rbac_coverage_maps_assignment_failures_to_safe_summary( assert response.json()["data"]["assignments_active"] == 0 +def test_rbac_coverage_counts_openapi_operations_for_included_router_tree(): + app = FastAPI() + parent = APIRouter(prefix="/api/v2") + nested = APIRouter(prefix="/protected") + + @nested.get("/covered") + async def covered_route() -> dict[str, bool]: + return {"ok": True} + + parent.include_router(nested) + app.include_router(parent) + app.include_router(security.router) + checker = MagicMock() + checker.list_assignments.return_value = [] + app.state.context = {"rbac_checker": checker} + app.dependency_overrides[require_authenticated] = lambda: AuthorizationContext( + user_id="user-1", + org_id="org-1", + workspace_id="ws-1", + roles={"admin"}, + permissions={"*"}, + ) + + with TestClient(app, raise_server_exceptions=False) as client: + response = client.get("/api/v2/security/rbac-coverage") + + assert response.status_code == 200 + data = response.json()["data"] + assert data["total_endpoints"] == 3 + assert data["unprotected_endpoints"] == 0 + assert data["coverage_percent"] == 100.0 + + def test_encryption_status_maps_tls_failures_to_degraded( fastapi_client, override_auth, monkeypatch ):