diff --git a/.env.example b/.env.example
index 60c3ec58a..98ac7d108 100644
--- a/.env.example
+++ b/.env.example
@@ -81,3 +81,6 @@
 CALLBACK_READ_TIMEOUT = 10
 # require as a env if you want to use doc transformation
 OPENAI_API_KEY=""
+
+KAAPI_GUARDRAILS_AUTH=""
+KAAPI_GUARDRAILS_URL=""
diff --git a/.env.test.example b/.env.test.example
index f938561d9..9b42d15d9 100644
--- a/.env.test.example
+++ b/.env.test.example
@@ -32,3 +32,8 @@ AWS_S3_BUCKET_PREFIX="bucket-prefix-name"
 # Callback Timeouts (in seconds)
 CALLBACK_CONNECT_TIMEOUT = 3
 CALLBACK_READ_TIMEOUT = 10
+
+OPENAI_API_KEY=""
+
+KAAPI_GUARDRAILS_AUTH=""
+KAAPI_GUARDRAILS_URL=""
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index 40c770541..a7cb7376a 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -49,6 +49,8 @@ class Settings(BaseSettings):
     POSTGRES_USER: str
     POSTGRES_PASSWORD: str = ""
     POSTGRES_DB: str = ""
+    KAAPI_GUARDRAILS_AUTH: str = ""
+    KAAPI_GUARDRAILS_URL: str = ""
 
     @computed_field  # type: ignore[prop-decorator]
     @property
diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py
index fc44235f9..4da892e92 100644
--- a/backend/app/models/llm/request.py
+++ b/backend/app/models/llm/request.py
@@ -208,6 +208,20 @@ class LLMCallRequest(SQLModel):
             "in production, always use the id + version."
         ),
     )
+    input_guardrails: list[dict[str, Any]] | None = Field(
+        default=None,
+        description=(
+            "Optional guardrails configuration for input validation. "
+            "If not provided, no guardrails are applied."
+        ),
+    )
+    output_guardrails: list[dict[str, Any]] | None = Field(
+        default=None,
+        description=(
+            "Optional guardrails configuration for output validation. "
+            "If not provided, no guardrails are applied."
+        ),
+    )
     callback_url: HttpUrl | None = Field(
         default=None, description="Webhook URL for async response delivery"
     )
diff --git a/backend/app/services/llm/guardrails.py b/backend/app/services/llm/guardrails.py
new file mode 100644
index 000000000..37f0d1ebf
--- /dev/null
+++ b/backend/app/services/llm/guardrails.py
@@ -0,0 +1,61 @@
+import logging
+from typing import Any
+from uuid import UUID
+
+import httpx
+
+from app.core.config import settings
+
+logger = logging.getLogger(__name__)
+
+
+def call_guardrails(
+    input_text: str, guardrail_config: list[dict[str, Any]], job_id: UUID
+) -> dict[str, Any]:
+    """
+    Call the Kaapi guardrails service to validate and process input text.
+
+    Args:
+        input_text: Text to validate and process.
+        guardrail_config: List of validator configurations to apply.
+        job_id: Unique identifier for the request.
+
+    Returns:
+        JSON response from the guardrails service, or a bypass payload with the original text if the call fails.
+    """
+    payload = {
+        "request_id": str(job_id),
+        "input": input_text,
+        "validators": guardrail_config,
+    }
+
+    headers = {
+        "accept": "application/json",
+        "Authorization": f"Bearer {settings.KAAPI_GUARDRAILS_AUTH}",
+        "Content-Type": "application/json",
+    }
+
+    try:
+        with httpx.Client(timeout=10.0) as client:
+            response = client.post(
+                settings.KAAPI_GUARDRAILS_URL,
+                json=payload,
+                headers=headers,
+            )
+
+            response.raise_for_status()
+            return response.json()
+    except Exception as e:
+        logger.warning(
+            f"[call_guardrails] Service unavailable. Bypassing guardrails. job_id={job_id}. error={e}"
error={e}" + ) + + return { + "success": False, + "bypassed": True, + "data": { + "safe_text": input_text, + "rephrase_needed": False, + }, + } diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index f4700b51b..492c1dc2c 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -13,6 +13,7 @@ from app.crud.jobs import JobCrud from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest from app.models.llm.request import ConfigBlob, LLMCallConfig, KaapiCompletionConfig +from app.services.llm.guardrails import call_guardrails from app.services.llm.providers.registry import get_llm_provider from app.services.llm.mappers import transform_kaapi_config_to_native from app.utils import APIResponse, send_callback @@ -134,6 +135,9 @@ def execute_job( # one of (id, version) or blob is guaranteed to be present due to prior validation config = request.config + input_query = request.query.input + input_guardrails = request.input_guardrails + output_guardrails = request.output_guardrails callback_response = None config_blob: ConfigBlob | None = None @@ -142,6 +146,36 @@ def execute_job( ) try: + if input_guardrails: + safe_input = call_guardrails(input_query, input_guardrails, job_id) + + logger.info( + f"[execute_job] Input guardrail validation | success={safe_input['success']}." + ) + + if safe_input.get("bypassed"): + logger.info("[execute_job] Guardrails bypassed (service unavailable)") + + elif safe_input["success"]: + request.query.input = safe_input["data"]["safe_text"] + + if safe_input["data"]["rephrase_needed"]: + callback_response = APIResponse.failure_response( + error=request.query.input, + metadata=request.request_metadata, + ) + return handle_job_error( + job_id, request.callback_url, callback_response + ) + else: + request.query.input = safe_input["error"] + + callback_response = APIResponse.failure_response( + error=safe_input["error"], + metadata=request.request_metadata, + ) + return handle_job_error(job_id, request.callback_url, callback_response) + with Session(engine) as session: # Update job status to PROCESSING job_crud = JobCrud(session=session) @@ -226,6 +260,42 @@ def execute_job( ) if response: + if output_guardrails: + output_text = response.response.output.text + safe_output = call_guardrails(output_text, output_guardrails, job_id) + + logger.info( + f"[execute_job] Output guardrail validation | success={safe_output['success']}." 
+                )
+
+                if safe_output.get("bypassed"):
+                    logger.info(
+                        "[execute_job] Guardrails bypassed (service unavailable)"
+                    )
+
+                elif safe_output["success"]:
+                    response.response.output.text = safe_output["data"]["safe_text"]
+
+                    if safe_output["data"]["rephrase_needed"]:
+                        callback_response = APIResponse.failure_response(
+                            error=safe_output["data"]["safe_text"],
+                            metadata=request.request_metadata,
+                        )
+                        return handle_job_error(
+                            job_id, request.callback_url, callback_response
+                        )
+
+                else:
+                    response.response.output.text = safe_output["error"]
+
+                    callback_response = APIResponse.failure_response(
+                        error=safe_output["error"],
+                        metadata=request.request_metadata,
+                    )
+                    return handle_job_error(
+                        job_id, request.callback_url, callback_response
+                    )
+
             callback_response = APIResponse.success_response(
                 data=response, metadata=request.request_metadata
             )
diff --git a/backend/app/tests/api/routes/test_llm.py b/backend/app/tests/api/routes/test_llm.py
index 9313750a0..279911d8b 100644
--- a/backend/app/tests/api/routes/test_llm.py
+++ b/backend/app/tests/api/routes/test_llm.py
@@ -6,7 +6,6 @@
 from app.models.llm.request import (
     QueryParams,
     LLMCallConfig,
-    CompletionConfig,
     ConfigBlob,
     KaapiLLMParams,
     KaapiCompletionConfig,
@@ -164,3 +163,112 @@ def test_llm_call_invalid_provider(
     )
 
     assert response.status_code == 422
+
+
+def test_llm_call_success_with_guardrails(
+    client: TestClient,
+    user_api_key_header: dict[str, str],
+) -> None:
+    """Test a successful LLM call when guardrails fields are set but empty (no validators)."""
+
+    with (
+        patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job,
+        patch("app.services.llm.jobs.call_guardrails") as mock_guardrails,
+    ):
+        mock_start_job.return_value = "test-task-id"
+
+        mock_guardrails.return_value = {
+            "success": True,
+            "bypassed": False,
+            "data": {
+                "safe_text": "What is the capital of France?",
+                "rephrase_needed": False,
+            },
+        }
+
+        payload = LLMCallRequest(
+            query=QueryParams(input="What is the capital of France?"),
+            config=LLMCallConfig(
+                blob=ConfigBlob(
+                    completion=NativeCompletionConfig(
+                        provider="openai-native",
+                        params={
+                            "model": "gpt-4o",
+                            "temperature": 0.7,
+                        },
+                    )
+                )
+            ),
+            input_guardrails=[],
+            output_guardrails=[],
+            callback_url="https://example.com/callback",
+        )
+
+        response = client.post(
+            "/api/v1/llm/call",
+            json=payload.model_dump(mode="json"),
+            headers=user_api_key_header,
+        )
+
+        assert response.status_code == 200
+
+        body = response.json()
+        assert body["success"] is True
+        assert "response is being generated" in body["data"]["message"]
+
+        mock_start_job.assert_called_once()
+        mock_guardrails.assert_not_called()
+
+
+def test_llm_call_guardrails_bypassed_still_succeeds(
+    client: TestClient,
+    user_api_key_header: dict[str, str],
+) -> None:
+    """If the guardrails service is unavailable (bypassed), the request should still succeed."""
+
+    with (
+        patch("app.services.llm.jobs.start_high_priority_job") as mock_start_job,
+        patch("app.services.llm.jobs.call_guardrails") as mock_guardrails,
+    ):
+        mock_start_job.return_value = "test-task-id"
+
+        mock_guardrails.return_value = {
+            "success": True,
+            "bypassed": True,
+            "data": {
+                "safe_text": "What is the capital of France?",
+                "rephrase_needed": False,
+            },
+        }
+
+        payload = LLMCallRequest(
+            query=QueryParams(input="What is the capital of France?"),
+            config=LLMCallConfig(
+                blob=ConfigBlob(
+                    completion=NativeCompletionConfig(
+                        provider="openai-native",
+                        params={
+                            "model": "gpt-4",
+                            "temperature": 0.7,
+                        },
+                    )
+                )
+            ),
+            input_guardrails=[{"type": "pii_remover"}],
"pii_remover"}], + output_guardrails=[], + callback_url="https://example.com/callback", + ) + + response = client.post( + "/api/v1/llm/call", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + + body = response.json() + assert body["success"] is True + assert "response is being generated" in body["data"]["message"] + + mock_start_job.assert_called_once() diff --git a/backend/app/tests/services/llm/test_guardrails.py b/backend/app/tests/services/llm/test_guardrails.py new file mode 100644 index 000000000..4443aecad --- /dev/null +++ b/backend/app/tests/services/llm/test_guardrails.py @@ -0,0 +1,98 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +import httpx + +from app.services.llm.guardrails import call_guardrails +from app.core.config import settings + + +TEST_JOB_ID = uuid.uuid4() +TEST_TEXT = "hello world" +TEST_CONFIG = [{"type": "pii_remover"}] + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_call_guardrails_success(mock_client_cls) -> None: + mock_response = MagicMock() + mock_response.json.return_value = {"success": True} + mock_response.raise_for_status.return_value = None + + mock_client = MagicMock() + mock_client.post.return_value = mock_response + mock_client_cls.return_value.__enter__.return_value = mock_client + + result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + + assert result == {"success": True} + mock_client.post.assert_called_once() + + args, kwargs = mock_client.post.call_args + + assert kwargs["json"]["input"] == TEST_TEXT + assert kwargs["json"]["validators"] == TEST_CONFIG + assert kwargs["json"]["request_id"] == str(TEST_JOB_ID) + + assert kwargs["headers"]["Authorization"].startswith("Bearer ") + assert kwargs["headers"]["Content-Type"] == "application/json" + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_call_guardrails_http_error_bypasses(mock_client_cls) -> None: + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "bad", request=None, response=None + ) + + mock_client = MagicMock() + mock_client.post.return_value = mock_response + mock_client_cls.return_value.__enter__.return_value = mock_client + + result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + + assert result["success"] is False + assert result["bypassed"] is True + assert result["data"]["safe_text"] == TEST_TEXT + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_call_guardrails_network_failure_bypasses(mock_client_cls) -> None: + mock_client = MagicMock() + mock_client.post.side_effect = httpx.ConnectError("failed") + mock_client_cls.return_value.__enter__.return_value = mock_client + + result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + + assert result["bypassed"] is True + assert result["data"]["safe_text"] == TEST_TEXT + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_call_guardrails_timeout_bypasses(mock_client_cls) -> None: + mock_client = MagicMock() + mock_client.post.side_effect = httpx.TimeoutException("timeout") + mock_client_cls.return_value.__enter__.return_value = mock_client + + result = call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID) + + assert result["bypassed"] is True + + +@patch("app.services.llm.guardrails.httpx.Client") +def test_call_guardrails_uses_settings(mock_client_cls) -> None: + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = {"ok": True} + + mock_client = 
+    mock_client.post.return_value = mock_response
+    mock_client_cls.return_value.__enter__.return_value = mock_client
+
+    call_guardrails(TEST_TEXT, TEST_CONFIG, TEST_JOB_ID)
+
+    _, kwargs = mock_client.post.call_args
+
+    assert (
+        kwargs["headers"]["Authorization"] == f"Bearer {settings.KAAPI_GUARDRAILS_AUTH}"
+    )
diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py
index 0aa3ad1f0..bae15c26b 100644
--- a/backend/app/tests/services/llm/test_jobs.py
+++ b/backend/app/tests/services/llm/test_jobs.py
@@ -719,6 +719,192 @@ def test_kaapi_config_warnings_merged_with_existing_metadata(
         assert "reasoning" in result["metadata"]["warnings"][0].lower()
         assert "does not support reasoning" in result["metadata"]["warnings"][0]
+
+    def test_guardrails_sanitize_input_before_provider(
+        self, db, job_env, job_for_execution
+    ):
+        """
+        Input guardrails should sanitize the text BEFORE provider.execute is called.
+        """
+
+        env = job_env
+
+        env["provider"].execute.return_value = (
+            env["mock_llm_response"],
+            None,
+        )
+
+        unsafe_input = "My credit card is 4111 1111 1111 1111"
+
+        with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails:
+            mock_guardrails.return_value = {
+                "success": True,
+                "bypassed": False,
+                "data": {
+                    "safe_text": "My credit card is [REDACTED]",
+                    "rephrase_needed": False,
+                },
+            }
+
+            request_data = {
+                "query": {"input": unsafe_input},
+                "config": {
+                    "blob": {
+                        "completion": {
+                            "provider": "openai-native",
+                            "params": {"model": "gpt-4"},
+                        }
+                    }
+                },
+                "input_guardrails": [{"type": "pii_remover"}],
+                "output_guardrails": [],
+                "include_provider_raw_response": False,
+                "callback_url": None,
+            }
+
+            result = self._execute_job(job_for_execution, db, request_data)
+
+            provider_query = env["provider"].execute.call_args[0][1]
+            assert "[REDACTED]" in provider_query.input
+            assert "4111" not in provider_query.input
+
+            assert result["success"]
+
+    def test_guardrails_sanitize_output_after_provider(
+        self, db, job_env, job_for_execution
+    ):
+        env = job_env
+
+        env["mock_llm_response"].response.output.text = "Aadhaar no 123-45-6789"
+        env["provider"].execute.return_value = (env["mock_llm_response"], None)
+
+        with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails:
+            mock_guardrails.return_value = {
+                "success": True,
+                "bypassed": False,
+                "data": {
+                    "safe_text": "Aadhaar [REDACTED]",
+                    "rephrase_needed": False,
+                },
+            }
+
+            request_data = {
+                "query": {"input": "hello"},
+                "config": {
+                    "blob": {
+                        "completion": {
+                            "provider": "openai-native",
+                            "params": {"model": "gpt-4"},
+                        }
+                    }
+                },
+                "input_guardrails": [],
+                "output_guardrails": [{"type": "pii_remover"}],
+            }
+
+            result = self._execute_job(job_for_execution, db, request_data)
+
+            assert "REDACTED" in result["data"]["response"]["output"]["text"]
+
+    def test_guardrails_bypass_does_not_modify_input(
+        self, db, job_env, job_for_execution
+    ):
+        env = job_env
+
+        env["provider"].execute.return_value = (env["mock_llm_response"], None)
+
+        unsafe_input = "4111 1111 1111 1111"
+
+        with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails:
+            mock_guardrails.return_value = {
+                "success": True,
+                "bypassed": True,
+                "data": {
+                    "safe_text": unsafe_input,
+                    "rephrase_needed": False,
+                },
+            }
+
+            request_data = {
+                "query": {"input": unsafe_input},
+                "config": {
+                    "blob": {
+                        "completion": {
+                            "provider": "openai-native",
+                            "params": {"model": "gpt-4"},
+                        }
+                    }
+                },
+                "input_guardrails": [{"type": "pii_remover"}],
+            }
+
+            self._execute_job(job_for_execution, db, request_data)
+
+            provider_query = env["provider"].execute.call_args[0][1]
+            assert provider_query.input == unsafe_input
+
+    def test_guardrails_validation_failure_blocks_job(
+        self, db, job_env, job_for_execution
+    ):
+        env = job_env
+
+        with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails:
+            mock_guardrails.return_value = {
+                "success": False,
+                "error": "Unsafe content detected",
+            }
+
+            request_data = {
+                "query": {"input": "bad input"},
+                "config": {
+                    "blob": {
+                        "completion": {
+                            "provider": "openai-native",
+                            "params": {"model": "gpt-4"},
+                        }
+                    }
+                },
+                "input_guardrails": [{"type": "uli_slur_match"}],
+            }
+
+            result = self._execute_job(job_for_execution, db, request_data)
+
+            assert not result["success"]
+            assert "Unsafe content" in result["error"]
+            env["provider"].execute.assert_not_called()
+
+    def test_guardrails_rephrase_needed_blocks_job(
+        self, db, job_env, job_for_execution
+    ):
+        env = job_env
+
+        with patch("app.services.llm.jobs.call_guardrails") as mock_guardrails:
+            mock_guardrails.return_value = {
+                "success": True,
+                "bypassed": False,
+                "data": {
+                    "safe_text": "Rephrased text",
+                    "rephrase_needed": True,
+                },
+            }
+
+            request_data = {
+                "query": {"input": "unsafe text"},
+                "config": {
+                    "blob": {
+                        "completion": {
+                            "provider": "openai-native",
+                            "params": {"model": "gpt-4"},
+                        }
+                    }
+                },
+                "input_guardrails": [{"type": "policy"}],
+            }
+
+            result = self._execute_job(job_for_execution, db, request_data)
+
+            assert not result["success"]
+            env["provider"].execute.assert_not_called()
+
 
 class TestResolveConfigBlob:
     """Test suite for resolve_config_blob function."""