diff --git a/.env.test b/.env.test deleted file mode 100644 index 40e9ad8..0000000 --- a/.env.test +++ /dev/null @@ -1,31 +0,0 @@ -DOMAIN=localhost - -ENVIRONMENT=testing - -PROJECT_NAME="Kaapi-Guardrails" -STACK_NAME=Kaapi-Guardrails - -# API Base URL for cron scripts (defaults to http://localhost:8000 if not set) -API_BASE_URL=http://localhost:8000 - -# Postgres -POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails -POSTGRES_USER=postgres -POSTGRES_PASSWORD=postgres - -SENTRY_DSN= - -# Configure these with your own Docker registry images - -DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend - -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - -# require as a env if you want to use doc transformation -OPENAI_API_KEY="" -GUARDRAILS_HUB_API_KEY="" -AUTH_TOKEN="" \ No newline at end of file diff --git a/.env.test.example b/.env.test.example index 40e9ad8..368b7d7 100644 --- a/.env.test.example +++ b/.env.test.example @@ -10,8 +10,8 @@ API_BASE_URL=http://localhost:8000 # Postgres POSTGRES_SERVER=localhost +POSTGRES_DB=kaapi_guardrails_testing POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres diff --git a/backend/app/alembic/versions/001_added_request_log.py b/backend/app/alembic/versions/001_added_request_log.py index 706d504..5f2e404 100644 --- a/backend/app/alembic/versions/001_added_request_log.py +++ b/backend/app/alembic/versions/001_added_request_log.py @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. 
revision: str = '001' -down_revision: Union[str, Sequence[str], None] = None +down_revision: str | None = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/app/alembic/versions/002_added_validator_log.py b/backend/app/alembic/versions/002_added_validator_log.py index d46513f..6b50229 100644 --- a/backend/app/alembic/versions/002_added_validator_log.py +++ b/backend/app/alembic/versions/002_added_validator_log.py @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. revision: str = '002' -down_revision: Union[str, Sequence[str], None] = '001' +down_revision: str = '001' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py new file mode 100644 index 0000000..72889b0 --- /dev/null +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -0,0 +1,50 @@ +"""Added validator_config table + +Revision ID: 003 +Revises: 002 +Create Date: 2026-02-05 09:42:54.128852 + +""" +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision: str = '003' +down_revision: str = '002' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table('validator_config', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('organization_id', sa.Integer(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('type', sa.String(), nullable=False), + sa.Column('stage', sa.String(), nullable=False), + sa.Column('on_fail_action', sa.String(), nullable=False), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'{}'::jsonb"), + ), + sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('organization_id', 'project_id', 'type', 'stage', name='uq_validator_identity') + ) + + op.create_index("idx_validator_organization", "validator_config", ["organization_id"]) + op.create_index("idx_validator_project", "validator_config", ["project_id"]) + op.create_index("idx_validator_type", "validator_config", ["type"]) + op.create_index("idx_validator_stage", "validator_config", ["stage"]) + + +def downgrade() -> None: + op.drop_table('validator_config') diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 0fc8026..bf78ade 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,10 +1,11 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails +from app.api.routes import utils, guardrails, validator_configs api_router = APIRouter() api_router.include_router(utils.router) api_router.include_router(guardrails.router) +api_router.include_router(validator_configs.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git 
a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 750ac71..90f7b3c 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -8,11 +8,11 @@ from app.api.deps import AuthDep, SessionDep from app.core.constants import REPHRASE_ON_FAIL_PREFIX from app.core.guardrail_controller import build_guard, get_validator_config_models -from app.crud.request_log import RequestLogCrud -from app.crud.validator_log import ValidatorLogCrud -from app.models.guardrail_config import GuardrailRequest, GuardrailResponse -from app.models.logging.request import RequestLogUpdate, RequestStatus -from app.models.logging.validator import ValidatorLog, ValidatorOutcome +from app.crud.request_log_repo import RequestLogCrud +from app.crud.validator_log_repo import ValidatorLogCrud +from app.models.logging.request_log_table import RequestLogUpdate, RequestStatus +from app.models.logging.validator_log_table import ValidatorLog, ValidatorOutcome +from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.utils import APIResponse router = APIRouter(prefix="/guardrails", tags=["guardrails"]) diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py new file mode 100644 index 0000000..feedad9 --- /dev/null +++ b/backend/app/api/routes/validator_configs.py @@ -0,0 +1,90 @@ +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter + +from app.api.deps import AuthDep, SessionDep +from app.core.enum import Stage, ValidatorType +from app.schemas.validator_config import ValidatorCreate, ValidatorResponse, ValidatorUpdate +from app.crud.validator_config_crud import validator_config_crud + + +router = APIRouter( + prefix="/guardrails/validators/configs", + tags=["validator configs"], +) + + +@router.post( + "/", + response_model=ValidatorResponse + ) +async def create_validator( + payload: ValidatorCreate, + session: SessionDep, + 
organization_id: int, + project_id: int, + _: AuthDep, +): + return validator_config_crud.create(session, organization_id, project_id, payload) +@router.get( + "/", + response_model=list[ValidatorResponse] + ) +async def list_validators( + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, +): + return validator_config_crud.list(session, organization_id, project_id, stage, type) + + +@router.get( + "/{id}", + response_model=ValidatorResponse + ) +async def get_validator( + id: UUID, + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = validator_config_crud.get_or_404(session, id, organization_id, project_id) + return validator_config_crud.flatten(obj) + + +@router.patch( + "/{id}", + response_model=ValidatorResponse + ) +async def update_validator( + id: UUID, + organization_id: int, + project_id: int, + payload: ValidatorUpdate, + session: SessionDep, + _: AuthDep, +): + obj = validator_config_crud.get_or_404(session, id, organization_id, project_id) + return validator_config_crud.update( + session, + obj, + payload.model_dump(exclude_unset=True), + ) + + +@router.delete("/{id}") +async def delete_validator( + id: UUID, + organization_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = validator_config_crud.get_or_404(session, id, organization_id, project_id) + validator_config_crud.delete(session, obj) + return {"success": True} diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index d6e3a7a..6c3825d 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -6,3 +6,12 @@ SCORE = "score" REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content." 
+ +VALIDATOR_CONFIG_SYSTEM_FIELDS = { + "organization_id", + "project_id", + "type", + "stage", + "on_fail_action", + "is_enabled", +} diff --git a/backend/app/core/enum.py b/backend/app/core/enum.py index 38418e9..6b8351f 100644 --- a/backend/app/core/enum.py +++ b/backend/app/core/enum.py @@ -15,4 +15,14 @@ class BiasCategories(Enum): class GuardrailOnFail(Enum): Exception = "exception" Fix = "fix" - Rephrase = "rephrase" \ No newline at end of file + Rephrase = "rephrase" + +class Stage(Enum): + Input = "input" + Output = "output" + +class ValidatorType(Enum): + LexicalSlur = "uli_slur_match" + PIIRemover = "pii_remover" + GenderAssumptionBias = "gender_assumption_bias" + BanList = "ban_list" diff --git a/backend/app/core/guardrail_controller.py b/backend/app/core/guardrail_controller.py index a935636..c4578e0 100644 --- a/backend/app/core/guardrail_controller.py +++ b/backend/app/core/guardrail_controller.py @@ -2,7 +2,7 @@ from guardrails import Guard -from app.models.guardrail_config import ValidatorConfigItem +from app.schemas.guardrail_config import ValidatorConfigItem def build_guard(validator_items): validators = [v_item.build() for v_item in validator_items] diff --git a/backend/app/core/validators/__init__.py b/backend/app/core/validators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/validators/ban_list_safety_validator_config.py b/backend/app/core/validators/config/ban_list_safety_validator_config.py similarity index 81% rename from backend/app/models/validators/ban_list_safety_validator_config.py rename to backend/app/core/validators/config/ban_list_safety_validator_config.py index 4a853f0..260399e 100644 --- a/backend/app/models/validators/ban_list_safety_validator_config.py +++ b/backend/app/core/validators/config/ban_list_safety_validator_config.py @@ -2,7 +2,7 @@ from guardrails.hub import BanList -from app.models.base_validator_config import BaseValidatorConfig +from 
app.core.validators.config.base_validator_config import BaseValidatorConfig class BanListSafetyValidatorConfig(BaseValidatorConfig): type: Literal["ban_list"] diff --git a/backend/app/models/base_validator_config.py b/backend/app/core/validators/config/base_validator_config.py similarity index 100% rename from backend/app/models/base_validator_config.py rename to backend/app/core/validators/config/base_validator_config.py diff --git a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py similarity index 86% rename from backend/app/models/validators/gender_assumption_bias_safety_validator_config.py rename to backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py index 116c281..7cd3687 100644 --- a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py +++ b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py @@ -1,8 +1,8 @@ from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.enum import BiasCategories from app.core.validators.gender_assumption_bias import GenderAssumptionBias +from app.core.validators.config.base_validator_config import BaseValidatorConfig class GenderAssumptionBiasSafetyValidatorConfig(BaseValidatorConfig): type: Literal["gender_assumption_bias"] diff --git a/backend/app/models/validators/lexical_slur_safety_validator_config.py b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py similarity index 88% rename from backend/app/models/validators/lexical_slur_safety_validator_config.py rename to backend/app/core/validators/config/lexical_slur_safety_validator_config.py index 6378182..d86c0d6 100644 --- a/backend/app/models/validators/lexical_slur_safety_validator_config.py +++ b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py @@ -2,7 
+2,7 @@ from app.core.enum import SlurSeverity from app.core.validators.lexical_slur import LexicalSlur -from app.models.base_validator_config import BaseValidatorConfig +from app.core.validators.config.base_validator_config import BaseValidatorConfig class LexicalSlurSafetyValidatorConfig(BaseValidatorConfig): type: Literal["uli_slur_match"] diff --git a/backend/app/models/validators/pii_remover_safety_validator_config.py b/backend/app/core/validators/config/pii_remover_safety_validator_config.py similarity index 87% rename from backend/app/models/validators/pii_remover_safety_validator_config.py rename to backend/app/core/validators/config/pii_remover_safety_validator_config.py index d8d3a18..fc18fa5 100644 --- a/backend/app/models/validators/pii_remover_safety_validator_config.py +++ b/backend/app/core/validators/config/pii_remover_safety_validator_config.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.validators.pii_remover import PIIRemover +from app.core.validators.config.base_validator_config import BaseValidatorConfig class PIIRemoverSafetyValidatorConfig(BaseValidatorConfig): type: Literal["pii_remover"] diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index c955a67..a6ddc7b 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -1 +1 @@ -from app.crud.request_log import RequestLogCrud \ No newline at end of file +from app.crud.request_log_repo import RequestLogCrud diff --git a/backend/app/crud/request_log.py b/backend/app/crud/request_log_repo.py similarity index 93% rename from backend/app/crud/request_log.py rename to backend/app/crud/request_log_repo.py index 74d5ece..9d3b1e5 100644 --- a/backend/app/crud/request_log.py +++ b/backend/app/crud/request_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.request import RequestLog, RequestLogUpdate, 
RequestStatus +from app.models.logging.request_log_table import RequestLog, RequestLogUpdate, RequestStatus from app.utils import now class RequestLogCrud: diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py new file mode 100644 index 0000000..8b7b4ad --- /dev/null +++ b/backend/app/crud/validator_config_crud.py @@ -0,0 +1,111 @@ +from typing import Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.core.enum import Stage, ValidatorType +from app.models.config.validator_config_table import ValidatorConfig +from app.schemas.validator_config import ValidatorCreate +from app.utils import now, split_validator_payload + + +class ValidatorConfigCrud: + def create( + self, + session: Session, + organization_id: int, + project_id: int, + payload: ValidatorCreate + ): + data = payload.model_dump() + model_fields, config_fields = split_validator_payload(data) + + obj = ValidatorConfig( + organization_id=organization_id, + project_id=project_id, + config=config_fields, + **model_fields, + ) + + session.add(obj) + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Validator already exists for this type and stage", + ) + + session.refresh(obj) + return self.flatten(obj) + + def list( + self, + session: Session, + organization_id: int, + project_id: int, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, + ) -> list[dict]: + query = select(ValidatorConfig).where( + ValidatorConfig.organization_id == organization_id, + ValidatorConfig.project_id == project_id, + ) + + if stage: + query = query.where(ValidatorConfig.stage == stage) + + if type: + query = query.where(ValidatorConfig.type == type) + + rows = session.exec(query).all() + return [self.flatten(r) for r in rows] + + def get_or_404( + self, + session: Session, + id: UUID, + organization_id: 
int, + project_id: int, + ) -> ValidatorConfig: + obj = session.get(ValidatorConfig, id) + + if not obj or obj.organization_id != organization_id or obj.project_id != project_id: + raise HTTPException(404, "Validator not found") + + return obj + + def update( + self, + session: Session, + obj: ValidatorConfig, + update_data: dict + ): + model_fields, config_fields = split_validator_payload(update_data) + + for k, v in model_fields.items(): + setattr(obj, k, v) + + if config_fields: + obj.config = {**(obj.config or {}), **config_fields} + + obj.updated_at = now() + session.commit() + session.refresh(obj) + + return self.flatten(obj) + + def delete(self, session: Session, obj: ValidatorConfig): + session.delete(obj) + session.commit() + + def flatten(self, row: ValidatorConfig) -> dict: + base = row.model_dump(exclude={"config"}) + return {**base, **(row.config or {})} + + +validator_config_crud = ValidatorConfigCrud() diff --git a/backend/app/crud/validator_log.py b/backend/app/crud/validator_log_repo.py similarity index 85% rename from backend/app/crud/validator_log.py rename to backend/app/crud/validator_log_repo.py index 6eb1c1a..649d6aa 100644 --- a/backend/app/crud/validator_log.py +++ b/backend/app/crud/validator_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.validator import ValidatorLog +from app.models.logging.validator_log_table import ValidatorLog from app.utils import now class ValidatorLogCrud: diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 5672003..2ba735f 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,4 +1,4 @@ from sqlmodel import SQLModel -from app.models.logging.request import RequestLog -from app.models.logging.validator import ValidatorLog +from app.models.logging.request_log_table import RequestLog +from app.models.logging.validator_log_table import ValidatorLog diff --git a/backend/app/models/config/validator_config_table.py 
b/backend/app/models/config/validator_config_table.py new file mode 100644 index 0000000..5c4a7c1 --- /dev/null +++ b/backend/app/models/config/validator_config_table.py @@ -0,0 +1,85 @@ +from datetime import datetime +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy import Column, UniqueConstraint +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel, Field + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.utils import now + +class ValidatorConfig(SQLModel, table=True): + __tablename__ = "validator_config" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the validator configuration"}, + ) + + organization_id: int = Field( + index=True, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + type: ValidatorType = Field( + nullable=False, + sa_column_kwargs={"comment": "Type of the validator"}, + ) + + stage: Stage = Field( + nullable=False, + sa_column_kwargs={"comment": "Stage at which the validator is applied"}, + ) + + on_fail_action: GuardrailOnFail = Field( + default=GuardrailOnFail.Fix, + nullable=False, + sa_column_kwargs={"comment": "Action to take when the validator fails"}, + ) + + config: dict[str, Any] = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=False, + comment="Configuration for the validator", + ), + description=( + "Configuration for the validator" + ), + ) + + is_enabled: bool = Field( + default=True, + sa_column_kwargs={"comment": "Indicates if the validator is enabled"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the validator config was inserted"}, + ) + + updated_at: datetime = Field( + default_factory=now, + 
nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the validator config was last updated", + "onupdate": now, + }, + ) + + __table_args__ = ( + UniqueConstraint( + "organization_id", "project_id", "type", "stage", + name="uq_validator_identity" + ), + ) diff --git a/backend/app/models/logging/request.py b/backend/app/models/logging/request_log_table.py similarity index 100% rename from backend/app/models/logging/request.py rename to backend/app/models/logging/request_log_table.py diff --git a/backend/app/models/logging/validator.py b/backend/app/models/logging/validator_log_table.py similarity index 100% rename from backend/app/models/logging/validator.py rename to backend/app/models/logging/validator_log_table.py diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/guardrail_config.py b/backend/app/schemas/guardrail_config.py similarity index 64% rename from backend/app/models/guardrail_config.py rename to backend/app/schemas/guardrail_config.py index bfe36a6..6dfde4c 100644 --- a/backend/app/models/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -5,10 +5,10 @@ # todo this could be improved by having some auto-discovery mechanism inside # validators. We'll not have to list every new validator like this. 
-from app.models.validators.ban_list_safety_validator_config import BanListSafetyValidatorConfig -from app.models.validators.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig -from app.models.validators.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig -from app.models.validators.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig +from app.core.validators.config.ban_list_safety_validator_config import BanListSafetyValidatorConfig +from app.core.validators.config.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig +from app.core.validators.config.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig +from app.core.validators.config.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig ValidatorConfigItem = Annotated[ # future validators will come here diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py new file mode 100644 index 0000000..7e693b5 --- /dev/null +++ b/backend/app/schemas/validator_config.py @@ -0,0 +1,38 @@ +from datetime import datetime +from typing import Optional +from uuid import UUID + +from sqlmodel import SQLModel + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType + + +class ValidatorBase(SQLModel): + model_config = {"extra": "allow"} + + type: ValidatorType + stage: Stage + on_fail_action: GuardrailOnFail + is_enabled: bool = True + + +class ValidatorCreate(ValidatorBase): + pass + + +class ValidatorUpdate(SQLModel): + # also allow extras for partial updates + model_config = {"extra": "allow"} + + type: Optional[ValidatorType] = None + stage: Optional[Stage] = None + on_fail_action: Optional[GuardrailOnFail] = None + is_enabled: Optional[bool] = None + + +class ValidatorResponse(ValidatorBase): + id: UUID + organization_id: int + project_id: int + created_at: datetime + updated_at: datetime diff --git 
a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index f76b33d..d595b97 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -1,50 +1,43 @@ +# conftest.py import os -from unittest.mock import MagicMock +os.environ["ENVIRONMENT"] = "testing" import pytest from fastapi.testclient import TestClient +from sqlmodel import Session, create_engine, SQLModel -# MUST be set before app import -os.environ["ENVIRONMENT"] = "testing" - -from app.api.deps import SessionDep, verify_bearer_token -from app.api.routes import guardrails from app.main import app +from app.api.deps import SessionDep, verify_bearer_token +from app.core.config import settings + +test_engine = create_engine( + str(settings.SQLALCHEMY_DATABASE_URI), + echo=False, + pool_pre_ping=True, +) + +def override_session(): + with Session(test_engine) as session: + yield session + +@pytest.fixture(scope="session", autouse=True) +def setup_test_db(): + SQLModel.metadata.create_all(test_engine) + yield + SQLModel.metadata.drop_all(test_engine) + +@pytest.fixture(scope="function", autouse=True) +def clean_db(): + with Session(test_engine) as session: + for table in reversed(SQLModel.metadata.sorted_tables): + session.execute(table.delete()) + session.commit() @pytest.fixture(scope="function", autouse=True) -def override_dependencies(monkeypatch): - """ - Override ALL external dependencies: - - Auth - - DB session - - CRUDs - """ - - # ---- Auth override ---- +def override_dependencies(): app.dependency_overrides[verify_bearer_token] = lambda: True - # ---- DB session override ---- - mock_session = MagicMock() - app.dependency_overrides[SessionDep] = lambda: mock_session - - # ---- CRUD override ---- - mock_request_log_crud = MagicMock() - mock_request_log_crud.create.return_value = MagicMock(id=1) - mock_request_log_crud.update.return_value = None - - mock_validator_log_crud = MagicMock() - mock_validator_log_crud.create.return_value = None - - monkeypatch.setattr( - 
guardrails, - "RequestLogCrud", - lambda session: mock_request_log_crud, - ) - monkeypatch.setattr( - guardrails, - "ValidatorLogCrud", - lambda session: mock_validator_log_crud, - ) + app.dependency_overrides[SessionDep] = override_session yield diff --git a/backend/app/tests/test_guardrails_api_integration.py b/backend/app/tests/test_guardrails_api_integration.py index 48ca1e8..485b532 100644 --- a/backend/app/tests/test_guardrails_api_integration.py +++ b/backend/app/tests/test_guardrails_api_integration.py @@ -70,7 +70,7 @@ def test_input_guardrails_with_lexical_slur(integration_client): body = response.json() assert body["success"] is True - assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]" + assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]." def test_input_guardrails_with_lexical_slur_clean_text(integration_client): @@ -123,7 +123,7 @@ def test_input_guardrails_with_multiple_validators(integration_client): assert body["success"] is True assert ( body["data"][SAFE_TEXT_FIELD] - == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus" + == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus." 
) diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py new file mode 100644 index 0000000..fc3c673 --- /dev/null +++ b/backend/app/tests/test_validator_configs.py @@ -0,0 +1,135 @@ +import uuid +from unittest.mock import MagicMock + +import pytest +from sqlmodel import Session + +from app.crud.validator_config_crud import validator_config_crud +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.models.config.validator_config_table import ValidatorConfig + +# Test data constants +TEST_ORGANIZATION_ID = 1 +TEST_PROJECT_ID = 1 +TEST_VALIDATOR_ID = uuid.uuid4() +TEST_TYPE = ValidatorType.LexicalSlur +TEST_STAGE = Stage.Input +TEST_ON_FAIL = GuardrailOnFail.Fix + +@pytest.fixture +def mock_session(): + """Create a mock session for database operations.""" + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_validator(): + """Create a sample validator config for testing.""" + return ValidatorConfig( + id=TEST_VALIDATOR_ID, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={"severity": "all", "languages": ["en", "hi"]}, + ) + + +class TestFlatten: + def test_flatten_includes_config_fields(self, sample_validator): + result = validator_config_crud.flatten(sample_validator) + + assert result["severity"] == "all" + assert result["languages"] == ["en", "hi"] + assert result["id"] == TEST_VALIDATOR_ID + + def test_flatten_empty_config(self): + validator = ValidatorConfig( + id=TEST_VALIDATOR_ID, + organization_id=TEST_ORGANIZATION_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={}, + ) + + result = validator_config_crud.flatten(validator) + + assert "severity" not in result + + +class TestGetOr404: + def test_success(self, sample_validator, mock_session): + mock_session.get.return_value = 
sample_validator + + result = validator_config_crud.get_or_404( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORGANIZATION_ID, + TEST_PROJECT_ID, + ) + + assert result == sample_validator + mock_session.get.assert_called_once() + + def test_not_found(self, mock_session): + mock_session.get.return_value = None + + with pytest.raises(Exception) as exc: + validator_config_crud.get_or_404( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORGANIZATION_ID, + TEST_PROJECT_ID, + ) + + assert "Validator not found" in str(exc.value) + + +class TestUpdate: + def test_update_base_fields(self, sample_validator, mock_session): + update_data = { + "type": ValidatorType.PIIRemover, + "on_fail_action": GuardrailOnFail.Exception, + } + + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) + + assert result["type"] == ValidatorType.PIIRemover + assert result["on_fail_action"] == GuardrailOnFail.Exception + + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + def test_update_extra_fields(self, sample_validator, mock_session): + update_data = {"severity": "high", "new_field": "new_value"} + + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) + + assert result["severity"] == "high" + assert result["new_field"] == "new_value" + assert result["languages"] == ["en", "hi"] + + def test_merge_config(self, sample_validator, mock_session): + sample_validator.config = {"severity": "all", "languages": ["en"]} + + result = validator_config_crud.update( + mock_session, + sample_validator, + {"languages": ["en", "hi"]}, + ) + + assert result["languages"] == ["en", "hi"] + assert result["severity"] == "all" diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py new file mode 100644 index 0000000..6965796 --- /dev/null +++ b/backend/app/tests/test_validator_configs_integration.py @@ -0,0 +1,301 @@ +import uuid + 
class BaseValidatorTest:
    """Shared HTTP helpers for the validator-config endpoint tests.

    Each helper sends one request to ``BASE_URL`` using the default test
    organization/project query string and returns the client's response
    unchanged, leaving all assertions to the calling test.
    """

    def create_validator(self, client, payload_key="minimal", **kwargs):
        """POST a validator config built from a named base payload.

        Args:
            client: HTTP test client exposing ``post``.
            payload_key: Key into ``VALIDATOR_PAYLOADS`` selecting the base
                payload.
            **kwargs: Fields that override or extend the base payload.

        Returns:
            The response returned by ``client.post``.
        """
        payload = {**VALIDATOR_PAYLOADS[payload_key], **kwargs}
        return client.post(f"{BASE_URL}{DEFAULT_QUERY_PARAMS}", json=payload)

    def get_validator(self, client, validator_id):
        """GET a single validator config by id.

        Returns:
            The response returned by ``client.get``.
        """
        return client.get(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}")

    def list_validators(self, client, **query_params):
        """GET the validator list, optionally narrowed by extra filters.

        Args:
            client: HTTP test client exposing ``get``.
            **query_params: Additional query-string filters (for example
                ``stage="input"`` or ``type="pii_remover"``) appended after
                the default organization/project parameters.

        Returns:
            The response returned by ``client.get``.
        """
        # Function-scope import keeps this helper self-contained.
        from urllib.parse import urlencode

        query = DEFAULT_QUERY_PARAMS
        if query_params:
            # urlencode percent-escapes keys and values, so a filter value
            # containing "&", "=", spaces or non-ASCII text cannot corrupt
            # the query string; doseq expands list values into repeated keys.
            query += "&" + urlencode(query_params, doseq=True)
        return client.get(f"{BASE_URL}{query}")

    def update_validator(self, client, validator_id, payload):
        """PATCH a validator config with the given JSON payload.

        Returns:
            The response returned by ``client.patch``.
        """
        return client.patch(
            f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}", json=payload
        )

    def delete_validator(self, client, validator_id):
        """DELETE a validator config by id.

        Returns:
            The response returned by ``client.delete``.
        """
        return client.delete(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}")
class TestGetValidator(BaseValidatorTest):
    """Tests for GET /guardrails/validators/configs/{id} endpoint."""

    def test_get_validator_success(self, integration_client, clear_database):
        """A freshly created validator can be fetched back by its id."""
        created = self.create_validator(
            integration_client, "lexical_slur", severity="all"
        ).json()
        validator_id = created["id"]

        fetched = self.get_validator(integration_client, validator_id)

        assert fetched.status_code == 200
        body = fetched.json()
        assert body["id"] == validator_id
        assert body["severity"] == "all"

    def test_get_validator_not_found(self, integration_client, clear_database):
        """Fetching an id that was never created yields 404."""
        missing = self.get_validator(integration_client, uuid.uuid4())

        assert missing.status_code == 404

    def test_get_validator_wrong_org(self, integration_client, clear_database):
        """A validator is not visible when requested under another organization."""
        # The helper creates the validator under the default test organization.
        validator_id = self.create_validator(integration_client, "minimal").json()["id"]

        # Same id, but the request is scoped to organization 2.
        other_org_url = f"{BASE_URL}{validator_id}/?organization_id=2&project_id=1"
        foreign = integration_client.get(other_org_url)

        assert foreign.status_code == 404
{"severity": "low"} + response = self.update_validator(integration_client, validator_id, update_payload) + + assert response.status_code == 200 + data = response.json() + assert data["severity"] == "low" + assert data["languages"] == ["en", "hi"] # Original preserved + + def test_update_validator_not_found(self, integration_client, clear_database): + """Test updating non-existent validator returns 404.""" + fake_id = uuid.uuid4() + update_payload = {"severity": "low"} + + response = self.update_validator(integration_client, fake_id, update_payload) + + assert response.status_code == 404 + + +class TestDeleteValidator(BaseValidatorTest): + """Tests for DELETE /guardrails/validators/configs/{id} endpoint.""" + + def test_delete_validator_success(self, integration_client, clear_database): + """Test successful validator deletion.""" + # Create a validator + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Delete it + response = self.delete_validator(integration_client, validator_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + # Verify it's deleted + get_response = self.get_validator(integration_client, validator_id) + assert get_response.status_code == 404 + + def test_delete_validator_not_found(self, integration_client, clear_database): + """Test deleting non-existent validator returns 404.""" + fake_id = uuid.uuid4() + response = self.delete_validator(integration_client, fake_id) + + assert response.status_code == 404 + + def test_delete_validator_wrong_org(self, integration_client, clear_database): + """Test that deleting validator from different org returns 404.""" + # Create a validator for org 1 + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Try to delete it as different org + response = integration_client.delete( + 
def split_validator_payload(data: dict):
    """Partition a payload into system fields and remaining fields.

    Keys present in ``SYSTEM_FIELDS`` go into the first dict; every other
    key goes into the second. The input dict is not modified.

    Args:
        data: Mapping of field name to value.

    Returns:
        A ``(model_fields, config_fields)`` tuple of two new dicts.
    """
    system_part = {
        name: value for name, value in data.items() if name in SYSTEM_FIELDS
    }
    remaining_part = {
        name: value for name, value in data.items() if name not in SYSTEM_FIELDS
    }
    return system_part, remaining_part