3 changes: 3 additions & 0 deletions .env.example
@@ -81,3 +81,6 @@ CALLBACK_READ_TIMEOUT = 10

# Required as an env var if you want to use doc transformation
OPENAI_API_KEY=""

KAAPI_GUARDRAILS_AUTH=""
KAAPI_GUARDRAILS_URL=""
5 changes: 5 additions & 0 deletions .env.test.example
@@ -32,3 +32,8 @@ AWS_S3_BUCKET_PREFIX="bucket-prefix-name"
# Callback Timeouts (in seconds)
CALLBACK_CONNECT_TIMEOUT = 3
CALLBACK_READ_TIMEOUT = 10

OPENAI_API_KEY=""

KAAPI_GUARDRAILS_AUTH=""
KAAPI_GUARDRAILS_URL=""
2 changes: 2 additions & 0 deletions backend/app/core/config.py
@@ -49,6 +49,8 @@ class Settings(BaseSettings):
POSTGRES_USER: str
POSTGRES_PASSWORD: str = ""
POSTGRES_DB: str = ""
KAAPI_GUARDRAILS_AUTH: str = ""
KAAPI_GUARDRAILS_URL: str = ""

@computed_field # type: ignore[prop-decorator]
@property
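Both new settings default to empty strings, so existing deployments keep booting without them; if KAAPI_GUARDRAILS_URL is unset, the request in call_guardrails simply fails and falls through to its bypass path. A minimal standalone sketch of the loading behaviour, assuming the same pydantic-settings mechanism the existing Settings class already uses (the class and values below are illustrative, not part of this diff):

import os

from pydantic_settings import BaseSettings


class GuardrailsSettings(BaseSettings):
    # Hypothetical stand-in mirroring the two fields added to app.core.config.Settings.
    KAAPI_GUARDRAILS_AUTH: str = ""
    KAAPI_GUARDRAILS_URL: str = ""


# Hypothetical value; in the real service this comes from .env / the environment.
os.environ["KAAPI_GUARDRAILS_URL"] = "https://guardrails.example.com/api/validate"
print(GuardrailsSettings().KAAPI_GUARDRAILS_URL)  # value from the environment, else ""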
14 changes: 14 additions & 0 deletions backend/app/models/llm/request.py
@@ -208,6 +208,20 @@ class LLMCallRequest(SQLModel):
"in production, always use the id + version."
),
)
input_guardrails: list[dict[str, Any]] | None = Field(
default=None,
description=(
"Optional guardrails configuration for input validation. "
"If not provided, no guardrails will be applied."
),
)
output_guardrails: list[dict[str, Any]] | None = Field(
default=None,
description=(
"Optional guardrails configuration for output validation. "
"If not provided, no guardrails will be applied."
),
)
callback_url: HttpUrl | None = Field(
default=None, description="Webhook URL for async response delivery"
)
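For illustration, a sketch of a request body for /api/v1/llm/call using the two new fields. The validator entries are free-form dicts; the "pii_remover" name mirrors the tests below and is a hypothetical example, not a documented validator catalogue:

# Sketch only: the shape the new fields accept (list[dict[str, Any]] | None).
payload = {
    "query": {"input": "Summarise this customer email ..."},
    "config": {
        "blob": {
            "completion": {
                "provider": "openai-native",
                "params": {"model": "gpt-4o", "temperature": 0.7},
            }
        }
    },
    "input_guardrails": [{"type": "pii_remover"}],  # applied to query.input before the LLM call
    "output_guardrails": [],                        # empty or omitted: no output validation
    "callback_url": "https://example.com/callback",
}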
60 changes: 60 additions & 0 deletions backend/app/services/llm/guardrails.py
@@ -0,0 +1,60 @@
import logging
from typing import Any
from uuid import UUID

import httpx

from app.core.config import settings

logger = logging.getLogger(__name__)


def call_guardrails(
input_text: str, guardrail_config: list[dict[str, Any]], job_id: UUID
) -> dict[str, Any]:
"""
Call the Kaapi guardrails service to validate and process input text.

Args:
input_text: Text to validate and process.
guardrail_config: List of validator configurations to apply.
job_id: Unique identifier for the request.

Returns:
JSON response from the guardrails service with validation results.
"""
payload = {
"request_id": str(job_id),
"input": input_text,
"validators": guardrail_config,
}

headers = {
"accept": "application/json",
"Authorization": f"Bearer {settings.KAAPI_GUARDRAILS_AUTH}",
"Content-Type": "application/json",
}

try:
with httpx.Client(timeout=10.0) as client:
response = client.post(
settings.KAAPI_GUARDRAILS_URL,
json=payload,
headers=headers,
)

response.raise_for_status()
return response.json()
except Exception as e:
logger.warning(
f"[call_guardrails] Service unavailable. Bypassing guardrails. job_id={job_id}. error={e}"
)

return {
"success": False,
"bypassed": True,
"data": {
"safe_text": input_text,
"rephrase_needed": False,
},
}
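A usage sketch for the helper above (it assumes the configured backend environment; the validator entry is hypothetical). The fields read here (success, bypassed, data.safe_text, data.rephrase_needed, error) are the ones execute_job consumes; a non-bypassed failure response is assumed to carry an "error" field:

from uuid import uuid4

from app.services.llm.guardrails import call_guardrails

result = call_guardrails(
    input_text="My card number is 4111 1111 1111 1111",
    guardrail_config=[{"type": "pii_remover"}],  # hypothetical validator entry
    job_id=uuid4(),
)

if result.get("bypassed"):
    pass  # service unreachable: the caller keeps the original text unchanged
elif result["success"]:
    safe_text = result["data"]["safe_text"]            # sanitised / rephrased text
    needs_rephrase = result["data"]["rephrase_needed"]
else:
    error = result["error"]                            # validation failed; the caller fails the job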
70 changes: 70 additions & 0 deletions backend/app/services/llm/jobs.py
@@ -13,6 +13,7 @@
from app.crud.jobs import JobCrud
from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest
from app.models.llm.request import ConfigBlob, LLMCallConfig, KaapiCompletionConfig
from app.services.llm.guardrails import call_guardrails
from app.services.llm.providers.registry import get_llm_provider
from app.services.llm.mappers import transform_kaapi_config_to_native
from app.utils import APIResponse, send_callback
@@ -134,6 +135,9 @@ def execute_job(

# one of (id, version) or blob is guaranteed to be present due to prior validation
config = request.config
input_query = request.query.input
input_guardrails = request.input_guardrails
output_guardrails = request.output_guardrails
callback_response = None
config_blob: ConfigBlob | None = None

@@ -142,6 +146,36 @@
)

try:
if input_guardrails:
safe_input = call_guardrails(input_query, input_guardrails, job_id)

logger.info(
f"[execute_job] Input guardrail validation | success={safe_input['success']}."
)

if safe_input.get("bypassed"):
logger.info("[execute_job] Guardrails bypassed (service unavailable)")

elif safe_input["success"]:
request.query.input = safe_input["data"]["safe_text"]

if safe_input["data"]["rephrase_needed"]:
callback_response = APIResponse.failure_response(
error=request.query.input,
metadata=request.request_metadata,
)
return handle_job_error(
job_id, request.callback_url, callback_response
)
else:
request.query.input = safe_input["error"]

callback_response = APIResponse.failure_response(
error=safe_input["error"],
metadata=request.request_metadata,
)
return handle_job_error(job_id, request.callback_url, callback_response)

with Session(engine) as session:
# Update job status to PROCESSING
job_crud = JobCrud(session=session)
@@ -226,6 +260,42 @@
)

if response:
if output_guardrails:
output_text = response.response.output.text
safe_output = call_guardrails(output_text, output_guardrails, job_id)

logger.info(
f"[execute_job] Output guardrail validation | success={safe_output['success']}."
)

if safe_output.get("bypassed"):
logger.info(
"[execute_job] Guardrails bypassed (service unavailable)"
)

elif safe_output["success"]:
response.response.output.text = safe_output["data"]["safe_text"]

if safe_output["data"]["rephrase_needed"]:
callback_response = APIResponse.failure_response(
error=response.response.output.text,
metadata=request.request_metadata,
)
return handle_job_error(
job_id, request.callback_url, callback_response
)

else:
response.response.output.text = safe_output["error"]

callback_response = APIResponse.failure_response(
error=safe_output["error"],
metadata=request.request_metadata,
)
return handle_job_error(
job_id, request.callback_url, callback_response
)

callback_response = APIResponse.success_response(
data=response, metadata=request.request_metadata
)
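The same decision is applied twice in execute_job, once to request.query.input and once to response.response.output.text. Condensed below into a standalone sketch to make the branches easier to follow (this helper does not exist in the diff; it only restates the logic above):

def apply_guardrail_result(result: dict, original_text: str) -> tuple[str, str | None]:
    """Return (text_to_use, error); error is None when the job may proceed."""
    if result.get("bypassed"):
        return original_text, None              # guardrails unreachable: continue unchanged
    if result["success"]:
        safe_text = result["data"]["safe_text"]
        if result["data"]["rephrase_needed"]:
            return safe_text, safe_text         # fail the job, surfacing the safe text
        return safe_text, None                  # continue with the sanitised text
    return original_text, result["error"]       # validation error: fail the job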
110 changes: 109 additions & 1 deletion backend/app/tests/api/routes/test_llm.py
@@ -6,7 +6,6 @@
from app.models.llm.request import (
QueryParams,
LLMCallConfig,
CompletionConfig,
ConfigBlob,
KaapiLLMParams,
KaapiCompletionConfig,
@@ -164,3 +163,112 @@ def test_llm_call_invalid_provider(
)

assert response.status_code == 422


def test_llm_call_success_with_guardrails(
client: TestClient,
user_api_key_header: dict[str, str],
) -> None:
"""Test successful LLM call when empty guardrail lists are provided (no validators)."""

with (
patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job,
patch("app.services.llm.guardrails.call_guardrails") as mock_guardrails,
):
mock_start_job.return_value = "test-task-id"

mock_guardrails.return_value = {
"success": True,
"bypassed": False,
"data": {
"safe_text": "What is the capital of France?",
"rephrase_needed": False,
},
}

payload = LLMCallRequest(
query=QueryParams(input="What is the capital of France?"),
config=LLMCallConfig(
blob=ConfigBlob(
completion=NativeCompletionConfig(
provider="openai-native",
params={
"model": "gpt-4o",
"temperature": 0.7,
},
)
)
),
input_guardrails=[],
output_guardrails=[],
callback_url="https://example.com/callback",
)

response = client.post(
"/api/v1/llm/call",
json=payload.model_dump(mode="json"),
headers=user_api_key_header,
)

assert response.status_code == 200

body = response.json()
assert body["success"] is True
assert "response is being generated" in body["data"]["message"]

mock_start_job.assert_called_once()
mock_guardrails.assert_not_called()


def test_llm_call_guardrails_bypassed_still_succeeds(
client: TestClient,
user_api_key_header: dict[str, str],
) -> None:
"""If guardrails service is unavailable (bypassed), request should still succeed."""

with (
patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job,
patch("app.services.llm.guardrails.call_guardrails") as mock_guardrails,
):
mock_start_job.return_value = "test-task-id"

mock_guardrails.return_value = {
"success": False,
"bypassed": True,
"data": {
"safe_text": "What is the capital of France?",
"rephrase_needed": False,
},
}

payload = LLMCallRequest(
query=QueryParams(input="What is the capital of France?"),
config=LLMCallConfig(
blob=ConfigBlob(
completion=NativeCompletionConfig(
provider="openai-native",
params={
"model": "gpt-4",
"temperature": 0.7,
},
)
)
),
input_guardrails=[{"type": "pii_remover"}],
output_guardrails=[],
callback_url="https://example.com/callback",
)

response = client.post(
"/api/v1/llm/call",
json=payload.model_dump(mode="json"),
headers=user_api_key_header,
)

assert response.status_code == 200

body = response.json()
assert body["success"] is True
assert "response is being generated" in body["data"]["message"]

mock_start_job.assert_called_once()