Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ services:
- app
ports:
- "8080:8080"
environment:
- AGENT_API_USERNAME=vitra_agent_user
- AGENT_API_PASSWORD=vitra_agent_password
env_file:
- .env

Expand Down
102 changes: 82 additions & 20 deletions frontend/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
import os
from fastapi import FastAPI, Request
import httpx
import asyncio
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# --- App Initialization ---
app = FastAPI()

# Allow all origins for CORS, so the chat can be embedded anywhere.
# --- Configuration ---
AGENT_API_URL = "http://app:8000"
AGENT_API_USERNAME = os.environ.get("AGENT_API_USERNAME", "vitra_agent_user")
AGENT_API_PASSWORD = os.environ.get("AGENT_API_PASSWORD", "vitra_agent_password")

# --- In-memory cache for the agent API token ---
agent_api_token: str | None = None
token_lock = asyncio.Lock()

# --- Middleware ---
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand All @@ -16,40 +29,89 @@
allow_headers=["*"],
)

# --- Static files and templates ---
current_dir = os.path.dirname(os.path.realpath(__file__))
static_path = os.path.join(current_dir, "static")
templates_path = os.path.join(current_dir, "templates")

app.mount("/static", StaticFiles(directory=static_path), name="static")
templates = Jinja2Templates(directory=templates_path)

# --- Pydantic Models ---
class Message(BaseModel):
content: str

# --- Authentication with Agent App ---
async def force_get_new_agent_api_token() -> str:
"""Fetches a new token from the agent service."""
global agent_api_token
try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{AGENT_API_URL}/token",
data={"username": AGENT_API_USERNAME, "password": AGENT_API_PASSWORD}
)
response.raise_for_status()
token_data = response.json()
agent_api_token = token_data["access_token"]
return agent_api_token
except httpx.HTTPStatusError as e:
print(f"Error getting token: {e.response.status_code} - {e.response.text}")
raise HTTPException(status_code=500, detail="Could not authenticate with the agent service.")
except Exception as e:
print(f"An unexpected error occurred while getting token: {e}")
raise HTTPException(status_code=500, detail="An error occurred while communicating with the agent service.")

async def get_agent_api_token() -> str:
"""
Retrieves the agent API token from cache, fetching a new one if necessary.
This function uses a lock to prevent race conditions.
"""
global agent_api_token
async with token_lock:
if agent_api_token is None:
return await force_get_new_agent_api_token()
return agent_api_token

async def invalidate_token():
"""Invalidates the cached token."""
global agent_api_token
async with token_lock:
agent_api_token = None

# --- API Endpoints ---
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
return templates.TemplateResponse("index.html", {"request": request})

async def make_agent_request(method: str, url: str, **kwargs):
"""Makes an authenticated request to the agent, with retry logic."""
token = await get_agent_api_token()
headers = {"Authorization": f"Bearer {token}", **kwargs.pop("headers", {})}

import httpx
from pydantic import BaseModel

class Message(BaseModel):
content: str
async with httpx.AsyncClient() as client:
try:
response = await client.request(method, url, headers=headers, **kwargs)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
await invalidate_token()
token = await get_agent_api_token()
headers["Authorization"] = f"Bearer {token}"

AGENT_API_URL = "http://app:8000"
response = await client.request(method, url, headers=headers, **kwargs)
response.raise_for_status()
return response.json()
raise HTTPException(status_code=e.response.status_code, detail=e.response.text)

@app.post("/api/sessions")
async def create_session():
async with httpx.AsyncClient() as client:
response = await client.post(f"{AGENT_API_URL}/sessions")
response.raise_for_status()
return response.json()
return await make_agent_request("POST", f"{AGENT_API_URL}/sessions")

@app.post("/api/sessions/{session_id}/message")
async def post_message(session_id: str, message: Message):
async with httpx.AsyncClient() as client:
response = await client.post(
f"{AGENT_API_URL}/sessions/{session_id}/message",
json={"content": message.content}
)
response.raise_for_status()
return response.json()
return await make_agent_request(
"POST",
f"{AGENT_API_URL}/sessions/{session_id}/message",
json={"content": message.content}
)
85 changes: 82 additions & 3 deletions vitra_ai/main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,108 @@
import os
from fastapi import FastAPI, HTTPException
from datetime import datetime, timedelta
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from google.adk.runtime import SessionManager
from vitra_ai.agent import root_agent
from dotenv import load_dotenv
from jose import JWTError, jwt
from passlib.context import CryptContext

load_dotenv()

# Security settings
SECRET_KEY = os.environ.get("SECRET_KEY", "a_very_secret_key")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

app = FastAPI()

# Dummy user database
# In a real application, this would be a real database.
FAKE_USER_DB = {
"vitra_agent_user": {
"username": "vitra_agent_user",
"hashed_password": pwd_context.hash("vitra_agent_password"),
}
}

class Token(BaseModel):
access_token: str
token_type: str

class TokenData(BaseModel):
username: str | None = None

class Message(BaseModel):
content: str

def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)

def get_user(username: str):
if username in FAKE_USER_DB:
return FAKE_USER_DB[username]

def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except JWTError:
raise credentials_exception
user = get_user(token_data.username)
if user is None:
raise credentials_exception
return user

session_manager = SessionManager(
engine=os.environ.get("DATABASE_URL")
)

@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
user = get_user(form_data.username)
if not user or not verify_password(form_data.password, user["hashed_password"]):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user["username"]}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}


@app.post("/sessions")
async def create_session():
async def create_session(current_user: dict = Depends(get_current_user)):
session = session_manager.create_session()
return {"session_id": session.session_id}

@app.post("/sessions/{session_id}/message")
async def post_message(session_id: str, message: Message):
async def post_message(session_id: str, message: Message, current_user: dict = Depends(get_current_user)):
session = session_manager.get_session(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
Expand Down
2 changes: 2 additions & 0 deletions vitra_ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ dependencies = [
"fastapi>=0.111.0",
"uvicorn>=0.30.1",
"psycopg2-binary>=2.9.9",
"python-jose[cryptography]>=3.3.0",
"passlib[bcrypt]>=1.7.4",
]