From dc091e61b3454d28446a1227484a32ff506df478 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 11 Aug 2025 16:06:51 +0000 Subject: [PATCH] feat: Add JWT authentication to the agent API This commit introduces JWT-based authentication to the agent application and updates the frontend application to use it. Key changes: - A new `/token` endpoint is added to the agent application to issue JWTs based on username/password credentials. - The `/sessions` and `/sessions/{session_id}/message` endpoints are now protected and require a valid JWT. - The `frontend/main.py` application is updated to fetch a JWT from the agent and include it in all subsequent API calls. - Token handling in the frontend is optimized with in-memory caching and a robust retry mechanism for handling token expiration. - `docker-compose.yml` is updated to provide the necessary API credentials to the frontend service via environment variables. --- docker-compose.yml | 3 ++ frontend/main.py | 102 ++++++++++++++++++++++++++++++++-------- vitra_ai/main.py | 85 +++++++++++++++++++++++++++++++-- vitra_ai/pyproject.toml | 2 + 4 files changed, 169 insertions(+), 23 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 8bf205b..55adf43 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -38,6 +38,9 @@ services: - app ports: - "8080:8080" + environment: + - AGENT_API_USERNAME=vitra_agent_user + - AGENT_API_PASSWORD=vitra_agent_password env_file: - .env diff --git a/frontend/main.py b/frontend/main.py index 4e613a6..8ed024b 100644 --- a/frontend/main.py +++ b/frontend/main.py @@ -1,13 +1,26 @@ import os -from fastapi import FastAPI, Request +import httpx +import asyncio +from fastapi import FastAPI, Request, HTTPException from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +# --- App Initialization --- app = FastAPI() -# Allow all origins for CORS, so the chat can be embedded anywhere. +# --- Configuration --- +AGENT_API_URL = "http://app:8000" +AGENT_API_USERNAME = os.environ.get("AGENT_API_USERNAME", "vitra_agent_user") +AGENT_API_PASSWORD = os.environ.get("AGENT_API_PASSWORD", "vitra_agent_password") + +# --- In-memory cache for the agent API token --- +agent_api_token: str | None = None +token_lock = asyncio.Lock() + +# --- Middleware --- app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -16,40 +29,89 @@ allow_headers=["*"], ) +# --- Static files and templates --- current_dir = os.path.dirname(os.path.realpath(__file__)) static_path = os.path.join(current_dir, "static") templates_path = os.path.join(current_dir, "templates") - app.mount("/static", StaticFiles(directory=static_path), name="static") templates = Jinja2Templates(directory=templates_path) +# --- Pydantic Models --- +class Message(BaseModel): + content: str + +# --- Authentication with Agent App --- +async def force_get_new_agent_api_token() -> str: + """Fetches a new token from the agent service.""" + global agent_api_token + try: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{AGENT_API_URL}/token", + data={"username": AGENT_API_USERNAME, "password": AGENT_API_PASSWORD} + ) + response.raise_for_status() + token_data = response.json() + agent_api_token = token_data["access_token"] + return agent_api_token + except httpx.HTTPStatusError as e: + print(f"Error getting token: {e.response.status_code} - {e.response.text}") + raise HTTPException(status_code=500, detail="Could not authenticate with the agent service.") + except Exception as e: + print(f"An unexpected error occurred while getting token: {e}") + raise HTTPException(status_code=500, detail="An error occurred while communicating with the agent service.") + +async def get_agent_api_token() -> str: + """ + Retrieves the agent API token from cache, fetching a new one if necessary. + This function uses a lock to prevent race conditions. + """ + global agent_api_token + async with token_lock: + if agent_api_token is None: + return await force_get_new_agent_api_token() + return agent_api_token +async def invalidate_token(): + """Invalidates the cached token.""" + global agent_api_token + async with token_lock: + agent_api_token = None + +# --- API Endpoints --- @app.get("/", response_class=HTMLResponse) async def read_root(request: Request): return templates.TemplateResponse("index.html", {"request": request}) +async def make_agent_request(method: str, url: str, **kwargs): + """Makes an authenticated request to the agent, with retry logic.""" + token = await get_agent_api_token() + headers = {"Authorization": f"Bearer {token}", **kwargs.pop("headers", {})} -import httpx -from pydantic import BaseModel - -class Message(BaseModel): - content: str + async with httpx.AsyncClient() as client: + try: + response = await client.request(method, url, headers=headers, **kwargs) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + await invalidate_token() + token = await get_agent_api_token() + headers["Authorization"] = f"Bearer {token}" -AGENT_API_URL = "http://app:8000" + response = await client.request(method, url, headers=headers, **kwargs) + response.raise_for_status() + return response.json() + raise HTTPException(status_code=e.response.status_code, detail=e.response.text) @app.post("/api/sessions") async def create_session(): - async with httpx.AsyncClient() as client: - response = await client.post(f"{AGENT_API_URL}/sessions") - response.raise_for_status() - return response.json() + return await make_agent_request("POST", f"{AGENT_API_URL}/sessions") @app.post("/api/sessions/{session_id}/message") async def post_message(session_id: str, message: Message): - async with httpx.AsyncClient() as client: - response = await client.post( - f"{AGENT_API_URL}/sessions/{session_id}/message", - json={"content": message.content} - ) - response.raise_for_status() - return response.json() + return await make_agent_request( + "POST", + f"{AGENT_API_URL}/sessions/{session_id}/message", + json={"content": message.content} + ) diff --git a/vitra_ai/main.py b/vitra_ai/main.py index 43dd8ff..2201dc8 100644 --- a/vitra_ai/main.py +++ b/vitra_ai/main.py @@ -1,29 +1,108 @@ import os -from fastapi import FastAPI, HTTPException +from datetime import datetime, timedelta +from fastapi import FastAPI, HTTPException, Depends, status +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from pydantic import BaseModel from google.adk.runtime import SessionManager from vitra_ai.agent import root_agent from dotenv import load_dotenv +from jose import JWTError, jwt +from passlib.context import CryptContext load_dotenv() +# Security settings +SECRET_KEY = os.environ.get("SECRET_KEY", "a_very_secret_key") +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + app = FastAPI() +# Dummy user database +# In a real application, this would be a real database. +FAKE_USER_DB = { + "vitra_agent_user": { + "username": "vitra_agent_user", + "hashed_password": pwd_context.hash("vitra_agent_password"), + } +} + +class Token(BaseModel): + access_token: str + token_type: str + +class TokenData(BaseModel): + username: str | None = None + class Message(BaseModel): content: str +def verify_password(plain_password, hashed_password): + return pwd_context.verify(plain_password, hashed_password) + +def get_user(username: str): + if username in FAKE_USER_DB: + return FAKE_USER_DB[username] + +def create_access_token(data: dict, expires_delta: timedelta | None = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + +async def get_current_user(token: str = Depends(oauth2_scheme)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except JWTError: + raise credentials_exception + user = get_user(token_data.username) + if user is None: + raise credentials_exception + return user + session_manager = SessionManager( engine=os.environ.get("DATABASE_URL") ) +@app.post("/token", response_model=Token) +async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): + user = get_user(form_data.username) + if not user or not verify_password(form_data.password, user["hashed_password"]): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token( + data={"sub": user["username"]}, expires_delta=access_token_expires + ) + return {"access_token": access_token, "token_type": "bearer"} + @app.post("/sessions") -async def create_session(): +async def create_session(current_user: dict = Depends(get_current_user)): session = session_manager.create_session() return {"session_id": session.session_id} @app.post("/sessions/{session_id}/message") -async def post_message(session_id: str, message: Message): +async def post_message(session_id: str, message: Message, current_user: dict = Depends(get_current_user)): session = session_manager.get_session(session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") diff --git a/vitra_ai/pyproject.toml b/vitra_ai/pyproject.toml index 3f96e5e..1ac3e11 100644 --- a/vitra_ai/pyproject.toml +++ b/vitra_ai/pyproject.toml @@ -15,4 +15,6 @@ dependencies = [ "fastapi>=0.111.0", "uvicorn>=0.30.1", "psycopg2-binary>=2.9.9", + "python-jose[cryptography]>=3.3.0", + "passlib[bcrypt]>=1.7.4", ] \ No newline at end of file