From a537bf03a02d6affc9f60a70997c58557a31cc49 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Wed, 4 Feb 2026 14:04:22 +0530 Subject: [PATCH 01/10] Rearranged files in folder --- backend/app/api/routes/guardrails.py | 10 +++++----- backend/app/core/enum.py | 12 +++++++++++- backend/app/core/guardrail_controller.py | 2 +- backend/app/core/validators/__init__.py | 0 .../config}/ban_list_safety_validator_config.py | 2 +- .../validators/config}/base_validator_config.py | 0 ...gender_assumption_bias_safety_validator_config.py | 2 +- .../config}/lexical_slur_safety_validator_config.py | 2 +- .../config}/pii_remover_safety_validator_config.py | 2 +- backend/app/crud/__init__.py | 2 +- .../app/crud/{request_log.py => request_log_repo.py} | 2 +- .../crud/{validator_log.py => validator_log_repo.py} | 2 +- backend/app/models/__init__.py | 4 ++-- .../logging/{request.py => request_log_table.py} | 0 .../logging/{validator.py => validator_log_table.py} | 0 backend/app/schemas/__init__.py | 0 backend/app/{models => schemas}/guardrail_config.py | 8 ++++---- 17 files changed, 30 insertions(+), 20 deletions(-) create mode 100644 backend/app/core/validators/__init__.py rename backend/app/{models/validators => core/validators/config}/ban_list_safety_validator_config.py (81%) rename backend/app/{models => core/validators/config}/base_validator_config.py (100%) rename backend/app/{models/validators => core/validators/config}/gender_assumption_bias_safety_validator_config.py (86%) rename backend/app/{models/validators => core/validators/config}/lexical_slur_safety_validator_config.py (88%) rename backend/app/{models/validators => core/validators/config}/pii_remover_safety_validator_config.py (87%) rename backend/app/crud/{request_log.py => request_log_repo.py} (93%) rename backend/app/crud/{validator_log.py => validator_log_repo.py} (85%) rename backend/app/models/logging/{request.py => request_log_table.py} (100%) rename backend/app/models/logging/{validator.py => validator_log_table.py} (100%) create mode 100644 backend/app/schemas/__init__.py rename backend/app/{models => schemas}/guardrail_config.py (64%) diff --git a/backend/app/api/routes/guardrails.py b/backend/app/api/routes/guardrails.py index 750ac71..90f7b3c 100644 --- a/backend/app/api/routes/guardrails.py +++ b/backend/app/api/routes/guardrails.py @@ -8,11 +8,11 @@ from app.api.deps import AuthDep, SessionDep from app.core.constants import REPHRASE_ON_FAIL_PREFIX from app.core.guardrail_controller import build_guard, get_validator_config_models -from app.crud.request_log import RequestLogCrud -from app.crud.validator_log import ValidatorLogCrud -from app.models.guardrail_config import GuardrailRequest, GuardrailResponse -from app.models.logging.request import RequestLogUpdate, RequestStatus -from app.models.logging.validator import ValidatorLog, ValidatorOutcome +from app.crud.request_log_repo import RequestLogCrud +from app.crud.validator_log_repo import ValidatorLogCrud +from app.models.logging.request_log_table import RequestLogUpdate, RequestStatus +from app.models.logging.validator_log_table import ValidatorLog, ValidatorOutcome +from app.schemas.guardrail_config import GuardrailRequest, GuardrailResponse from app.utils import APIResponse router = APIRouter(prefix="/guardrails", tags=["guardrails"]) diff --git a/backend/app/core/enum.py b/backend/app/core/enum.py index 38418e9..6b8351f 100644 --- a/backend/app/core/enum.py +++ b/backend/app/core/enum.py @@ -15,4 +15,14 @@ class BiasCategories(Enum): class GuardrailOnFail(Enum): Exception = "exception" Fix = "fix" - Rephrase = "rephrase" \ No newline at end of file + Rephrase = "rephrase" + +class Stage(Enum): + Input = "input" + Output = "output" + +class ValidatorType(Enum): + LexicalSlur = "uli_slur_match" + PIIRemover = "pii_remover" + GenderAssumptionBias = "gender_assumption_bias" + BanList = "ban_list" diff --git a/backend/app/core/guardrail_controller.py b/backend/app/core/guardrail_controller.py index a935636..c4578e0 100644 --- a/backend/app/core/guardrail_controller.py +++ b/backend/app/core/guardrail_controller.py @@ -2,7 +2,7 @@ from guardrails import Guard -from app.models.guardrail_config import ValidatorConfigItem +from app.schemas.guardrail_config import ValidatorConfigItem def build_guard(validator_items): validators = [v_item.build() for v_item in validator_items] diff --git a/backend/app/core/validators/__init__.py b/backend/app/core/validators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/validators/ban_list_safety_validator_config.py b/backend/app/core/validators/config/ban_list_safety_validator_config.py similarity index 81% rename from backend/app/models/validators/ban_list_safety_validator_config.py rename to backend/app/core/validators/config/ban_list_safety_validator_config.py index 4a853f0..260399e 100644 --- a/backend/app/models/validators/ban_list_safety_validator_config.py +++ b/backend/app/core/validators/config/ban_list_safety_validator_config.py @@ -2,7 +2,7 @@ from guardrails.hub import BanList -from app.models.base_validator_config import BaseValidatorConfig +from app.core.validators.config.base_validator_config import BaseValidatorConfig class BanListSafetyValidatorConfig(BaseValidatorConfig): type: Literal["ban_list"] diff --git a/backend/app/models/base_validator_config.py b/backend/app/core/validators/config/base_validator_config.py similarity index 100% rename from backend/app/models/base_validator_config.py rename to backend/app/core/validators/config/base_validator_config.py diff --git a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py similarity index 86% rename from backend/app/models/validators/gender_assumption_bias_safety_validator_config.py rename to backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py index 116c281..7cd3687 100644 --- a/backend/app/models/validators/gender_assumption_bias_safety_validator_config.py +++ b/backend/app/core/validators/config/gender_assumption_bias_safety_validator_config.py @@ -1,8 +1,8 @@ from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.enum import BiasCategories from app.core.validators.gender_assumption_bias import GenderAssumptionBias +from app.core.validators.config.base_validator_config import BaseValidatorConfig class GenderAssumptionBiasSafetyValidatorConfig(BaseValidatorConfig): type: Literal["gender_assumption_bias"] diff --git a/backend/app/models/validators/lexical_slur_safety_validator_config.py b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py similarity index 88% rename from backend/app/models/validators/lexical_slur_safety_validator_config.py rename to backend/app/core/validators/config/lexical_slur_safety_validator_config.py index 6378182..d86c0d6 100644 --- a/backend/app/models/validators/lexical_slur_safety_validator_config.py +++ b/backend/app/core/validators/config/lexical_slur_safety_validator_config.py @@ -2,7 +2,7 @@ from app.core.enum import SlurSeverity from app.core.validators.lexical_slur import LexicalSlur -from app.models.base_validator_config import BaseValidatorConfig +from app.core.validators.config.base_validator_config import BaseValidatorConfig class LexicalSlurSafetyValidatorConfig(BaseValidatorConfig): type: Literal["uli_slur_match"] diff --git a/backend/app/models/validators/pii_remover_safety_validator_config.py b/backend/app/core/validators/config/pii_remover_safety_validator_config.py similarity index 87% rename from backend/app/models/validators/pii_remover_safety_validator_config.py rename to backend/app/core/validators/config/pii_remover_safety_validator_config.py index d8d3a18..fc18fa5 100644 --- a/backend/app/models/validators/pii_remover_safety_validator_config.py +++ b/backend/app/core/validators/config/pii_remover_safety_validator_config.py @@ -1,8 +1,8 @@ from __future__ import annotations from typing import List, Literal, Optional -from app.models.base_validator_config import BaseValidatorConfig from app.core.validators.pii_remover import PIIRemover +from app.core.validators.config.base_validator_config import BaseValidatorConfig class PIIRemoverSafetyValidatorConfig(BaseValidatorConfig): type: Literal["pii_remover"] diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index c955a67..37bc215 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -1 +1 @@ -from app.crud.request_log import RequestLogCrud \ No newline at end of file +from app.crud.request_log_repo import RequestLogCrud \ No newline at end of file diff --git a/backend/app/crud/request_log.py b/backend/app/crud/request_log_repo.py similarity index 93% rename from backend/app/crud/request_log.py rename to backend/app/crud/request_log_repo.py index 74d5ece..9d3b1e5 100644 --- a/backend/app/crud/request_log.py +++ b/backend/app/crud/request_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.request import RequestLog, RequestLogUpdate, RequestStatus +from app.models.logging.request_log_table import RequestLog, RequestLogUpdate, RequestStatus from app.utils import now class RequestLogCrud: diff --git a/backend/app/crud/validator_log.py b/backend/app/crud/validator_log_repo.py similarity index 85% rename from backend/app/crud/validator_log.py rename to backend/app/crud/validator_log_repo.py index 6eb1c1a..649d6aa 100644 --- a/backend/app/crud/validator_log.py +++ b/backend/app/crud/validator_log_repo.py @@ -2,7 +2,7 @@ from sqlmodel import Session -from app.models.logging.validator import ValidatorLog +from app.models.logging.validator_log_table import ValidatorLog from app.utils import now class ValidatorLogCrud: diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 5672003..2ba735f 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,4 +1,4 @@ from sqlmodel import SQLModel -from app.models.logging.request import RequestLog -from app.models.logging.validator import ValidatorLog +from app.models.logging.request_log_table import RequestLog +from app.models.logging.validator_log_table import ValidatorLog diff --git a/backend/app/models/logging/request.py b/backend/app/models/logging/request_log_table.py similarity index 100% rename from backend/app/models/logging/request.py rename to backend/app/models/logging/request_log_table.py diff --git a/backend/app/models/logging/validator.py b/backend/app/models/logging/validator_log_table.py similarity index 100% rename from backend/app/models/logging/validator.py rename to backend/app/models/logging/validator_log_table.py diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/models/guardrail_config.py b/backend/app/schemas/guardrail_config.py similarity index 64% rename from backend/app/models/guardrail_config.py rename to backend/app/schemas/guardrail_config.py index bfe36a6..6dfde4c 100644 --- a/backend/app/models/guardrail_config.py +++ b/backend/app/schemas/guardrail_config.py @@ -5,10 +5,10 @@ # todo this could be improved by having some auto-discovery mechanism inside # validators. We'll not have to list every new validator like this. -from app.models.validators.ban_list_safety_validator_config import BanListSafetyValidatorConfig -from app.models.validators.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig -from app.models.validators.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig -from app.models.validators.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig +from app.core.validators.config.ban_list_safety_validator_config import BanListSafetyValidatorConfig +from app.core.validators.config.gender_assumption_bias_safety_validator_config import GenderAssumptionBiasSafetyValidatorConfig +from app.core.validators.config.lexical_slur_safety_validator_config import LexicalSlurSafetyValidatorConfig +from app.core.validators.config.pii_remover_safety_validator_config import PIIRemoverSafetyValidatorConfig ValidatorConfigItem = Annotated[ # future validators will come here From 4174e90b6d656deb72cec24e6af364563a4e33ee Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Wed, 4 Feb 2026 17:12:01 +0530 Subject: [PATCH 02/10] Added validator config management changes --- backend/app/api/main.py | 3 +- backend/app/api/routes/validator_configs.py | 143 ++++++++++++++++++ backend/app/core/constants.py | 9 ++ .../models/config/validator_config_table.py | 76 ++++++++++ backend/app/schemas/validator_config.py | 35 +++++ backend/app/utils.py | 15 ++ 6 files changed, 280 insertions(+), 1 deletion(-) create mode 100644 backend/app/api/routes/validator_configs.py create mode 100644 backend/app/models/config/validator_config_table.py create mode 100644 backend/app/schemas/validator_config.py diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 0fc8026..bf78ade 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,10 +1,11 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails +from app.api.routes import utils, guardrails, validator_configs api_router = APIRouter() api_router.include_router(utils.router) api_router.include_router(guardrails.router) +api_router.include_router(validator_configs.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py new file mode 100644 index 0000000..3fb440a --- /dev/null +++ b/backend/app/api/routes/validator_configs.py @@ -0,0 +1,143 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import select + +from app.api.deps import AuthDep, SessionDep +from app.models.config.validator_config_table import ValidatorConfig +from app.schemas.validator_config import * +from app.utils import split_validator_payload + +router = APIRouter(prefix="/guardrails/validators/configs", tags=["validator configs"]) + + +@router.post( + "/", + response_model=ValidatorResponse + ) +async def create_validator( + payload: ValidatorCreate, + session: SessionDep, + org_id: int, + project_id: int, + _: AuthDep, +): + data = payload.model_dump() + base, config = split_validator_payload(data) + obj = ValidatorConfig( + org_id=org_id, + project_id=project_id, + config=config, + **base, + ) + + session.add(obj) + session.commit() + session.refresh(obj) + return obj + +@router.get( + "/", + response_model=List[ValidatorResponse] + ) +async def list_validators( + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, +): + query = select(ValidatorConfig).where( + ValidatorConfig.org_id == org_id, + ValidatorConfig.project_id == project_id + ) + + if stage: + query = query.where(ValidatorConfig.stage == stage) + + if type: + query = query.where(ValidatorConfig.type == type) + + rows = session.exec(query).all() + return [flatten_validator(r) for r in rows] + + +@router.get( + "/{id}", + response_model=ValidatorResponse + ) +async def get_validator( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404) + + return flatten_validator(obj) + + +@router.patch( + "/{id}", + response_model=ValidatorResponse + ) +async def update_validator( + id: UUID, + org_id: int, + project_id: int, + payload: ValidatorUpdate, + session: SessionDep, + _: AuthDep, +): + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404) + + data = payload.model_dump(exclude_unset=True) + base, config = split_validator_payload(data) + + print("base", base) + print("config", config) + for k, v in base.items(): + setattr(obj, k, v) + + if config: + obj.config = {**(obj.config or {}), **config} + + session.add(obj) + session.commit() + session.refresh(obj) + + return flatten_validator(obj) + + +@router.delete("/{id}") +async def delete_validator( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404) + + session.delete(obj) + session.commit() + + return {"success": True} + +def flatten_validator(row: ValidatorConfig) -> dict: + base = row.model_dump(exclude={"config"}) + + print(base) + return {**base, **(row.config or {})} diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index d6e3a7a..115ad21 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -6,3 +6,12 @@ SCORE = "score" REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content." + +VALIDATOR_CONFIG_SYSTEM_FIELDS = { + "org_id", + "project_id", + "type", + "stage", + "on_fail_action", + "is_enabled", +} diff --git a/backend/app/models/config/validator_config_table.py b/backend/app/models/config/validator_config_table.py new file mode 100644 index 0000000..8e030e0 --- /dev/null +++ b/backend/app/models/config/validator_config_table.py @@ -0,0 +1,76 @@ +from datetime import datetime +from typing import Any, Optional +from uuid import UUID, uuid4 + +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel, Field + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.utils import now + +class ValidatorConfig(SQLModel, table=True): + __tablename__ = "validator_config" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the validator configuration"}, + ) + + org_id: int = Field( + index=True, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: Optional[int] = Field( + default=None, + index=True, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + type: ValidatorType = Field( + nullable=False, + sa_column_kwargs={"comment": "Type of the validator"}, + ) + + stage: Stage = Field( + nullable=False, + sa_column_kwargs={"comment": "Stage at which the validator is applied"}, + ) + + on_fail_action: GuardrailOnFail = Field( + default=GuardrailOnFail.Fix, + nullable=False, + sa_column_kwargs={"comment": "Action to take when the validator fails"}, + ) + + config: dict[str, Any] = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=False, + comment="Configuration for the validator", + ), + description=( + "Configuration for the validator" + ), + ) + + is_enabled: bool = Field( + default=True, + sa_column_kwargs={"comment": "Indicates if the validator is enabled"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the validator config was inserted"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the validator config was last updated"}, + ) diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py new file mode 100644 index 0000000..4452e24 --- /dev/null +++ b/backend/app/schemas/validator_config.py @@ -0,0 +1,35 @@ +from typing import Optional +from uuid import UUID, uuid4 + +from sqlmodel import SQLModel + +from app.core.enum import GuardrailOnFail, Stage, ValidatorType + + +class ValidatorBase(SQLModel): + model_config = {"extra": "allow"} + + type: ValidatorType + stage: Stage + on_fail_action: GuardrailOnFail + is_enabled: bool = True + + +class ValidatorCreate(ValidatorBase): + pass + + +class ValidatorUpdate(SQLModel): + # also allow extras for partial updates + model_config = {"extra": "allow"} + + type: Optional[ValidatorType] = None + stage: Optional[Stage] = None + on_fail_action: Optional[GuardrailOnFail] = None + is_enabled: Optional[bool] = None + + +class ValidatorResponse(ValidatorBase): + id: UUID + org_id: int + project_id: Optional[int] = None diff --git a/backend/app/utils.py b/backend/app/utils.py index 4e10f52..30543d5 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -3,6 +3,8 @@ from pydantic import BaseModel from typing import Any, Dict, Generic, Optional, TypeVar +from app.core.constants import VALIDATOR_CONFIG_SYSTEM_FIELDS as SYSTEM_FIELDS + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -11,6 +13,19 @@ def now(): return datetime.now(timezone.utc).replace(tzinfo=None) +def split_validator_payload(data: dict): + base = {} + config = {} + + for k, v in data.items(): + if k in SYSTEM_FIELDS: + base[k] = v + else: + config[k] = v + + return base, config + + class APIResponse(BaseModel, Generic[T]): success: bool data: Optional[T] = None From f691bcb92b77ffc3849014cc49085b6ebb09f164 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 12:44:46 +0530 Subject: [PATCH 03/10] Updated validator config management API --- .../versions/003_added_validator_config.py | 50 +++++++++ backend/app/api/routes/validator_configs.py | 105 ++++++++++++------ .../models/config/validator_config_table.py | 16 ++- backend/app/schemas/validator_config.py | 4 +- 4 files changed, 131 insertions(+), 44 deletions(-) create mode 100644 backend/app/alembic/versions/003_added_validator_config.py diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py new file mode 100644 index 0000000..2ee47d3 --- /dev/null +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -0,0 +1,50 @@ +"""Added validator_config table + +Revision ID: 003 +Revises: 001 +Create Date: 2026-02-05 09:42:54.128852 + +""" +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '003' +down_revision: Union[str, Sequence[str], None] = "002" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table('validator_config', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('org_id', sa.Integer(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('type', sa.String(), nullable=False), + sa.Column('stage', sa.String(), nullable=False), + sa.Column('on_fail_action', sa.String(), nullable=False), + sa.Column( + "config", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + server_default=sa.text("'{}'::jsonb"), + ), + sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('org_id', 'project_id', 'type', 'stage', name='uq_validator_identity') + ) + + op.create_index("idx_validator_org", "validator_config", ["org_id"]) + op.create_index("idx_validator_project", "validator_config", ["project_id"]) + op.create_index("idx_validator_type", "validator_config", ["type"]) + op.create_index("idx_validator_stage", "validator_config", ["stage"]) + + +def downgrade() -> None: + op.drop_table('validator_config') diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index 3fb440a..e35ce1c 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -1,7 +1,8 @@ from typing import List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, HTTPException +from sqlalchemy.exc import IntegrityError from sqlmodel import select from app.api.deps import AuthDep, SessionDep @@ -33,9 +34,18 @@ async def create_validator( ) session.add(obj) - session.commit() + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + status_code=400, + detail="Validator already exists for this type and stage", + ) + session.refresh(obj) - return obj + return flatten_validator(obj) @router.get( "/", @@ -75,11 +85,7 @@ async def get_validator( session: SessionDep, _: AuthDep, ): - obj = session.get(ValidatorConfig, id) - - if not obj or obj.org_id != org_id or obj.project_id != project_id: - raise HTTPException(404) - + obj = get_validator_or_404(id, org_id, project_id, session) return flatten_validator(obj) @@ -95,27 +101,13 @@ async def update_validator( session: SessionDep, _: AuthDep, ): - obj = session.get(ValidatorConfig, id) - - if not obj or obj.org_id != org_id or obj.project_id != project_id: - raise HTTPException(404) - - data = payload.model_dump(exclude_unset=True) - base, config = split_validator_payload(data) - - print("base", base) - print("config", config) - for k, v in base.items(): - setattr(obj, k, v) - - if config: - obj.config = {**(obj.config or {}), **config} - - session.add(obj) - session.commit() - session.refresh(obj) - - return flatten_validator(obj) + obj = get_validator_or_404(id, org_id, project_id, session) + updated_obj = update_validator_config( + obj, + payload.model_dump(exclude_unset=True), + session + ) + return flatten_validator(updated_obj) @router.delete("/{id}") @@ -126,18 +118,57 @@ async def delete_validator( session: SessionDep, _: AuthDep, ): - obj = session.get(ValidatorConfig, id) - - if not obj or obj.org_id != org_id or obj.project_id != project_id: - raise HTTPException(404) - + obj = get_validator_or_404(id, org_id, project_id, session) session.delete(obj) session.commit() - return {"success": True} def flatten_validator(row: ValidatorConfig) -> dict: + """ + Flatten validator config: combines base fields with config dict. + Returns a dict with all fields including config extras. + """ base = row.model_dump(exclude={"config"}) + flattened = {**base, **(row.config or {})} + print("FLATTENED:", flattened) + return flattened - print(base) - return {**base, **(row.config or {})} + +def get_validator_or_404( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, +) -> ValidatorConfig: + """Fetch validator by id, org_id, and project_id, or raise 404.""" + obj = session.query(ValidatorConfig).filter( + ValidatorConfig.id == id, + ValidatorConfig.org_id == org_id, + ValidatorConfig.project_id == project_id + ).first() + + if not obj: + raise HTTPException(status_code=404, detail="Validator not found") + + return obj + + +def update_validator_config( + obj: ValidatorConfig, + update_data: dict, + session: SessionDep, +) -> ValidatorConfig: + """Update validator config fields and return the updated object.""" + base, config = split_validator_payload(update_data) + + for k, v in base.items(): + setattr(obj, k, v) + + if config: + obj.config = {**(obj.config or {}), **config} + + session.add(obj) + session.commit() + session.refresh(obj) + + return obj diff --git a/backend/app/models/config/validator_config_table.py b/backend/app/models/config/validator_config_table.py index 8e030e0..68d57cf 100644 --- a/backend/app/models/config/validator_config_table.py +++ b/backend/app/models/config/validator_config_table.py @@ -1,8 +1,8 @@ from datetime import datetime -from typing import Any, Optional +from typing import Any from uuid import UUID, uuid4 -from sqlalchemy import Column +from sqlalchemy import Column, UniqueConstraint from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField from sqlmodel import SQLModel, Field @@ -24,9 +24,8 @@ class ValidatorConfig(SQLModel, table=True): sa_column_kwargs={"comment": "Identifier for the organization"}, ) - project_id: Optional[int] = Field( - default=None, - index=True, + project_id: int = Field( + nullable=False, sa_column_kwargs={"comment": "Identifier for the project"}, ) @@ -74,3 +73,10 @@ class ValidatorConfig(SQLModel, table=True): nullable=False, sa_column_kwargs={"comment": "Timestamp when the validator config was last updated"}, ) + + __table_args__ = ( + UniqueConstraint( + "org_id", "project_id", "type", "stage", + name="uq_validator_identity" + ), + ) \ No newline at end of file diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py index 4452e24..c8a2ae1 100644 --- a/backend/app/schemas/validator_config.py +++ b/backend/app/schemas/validator_config.py @@ -1,5 +1,5 @@ from typing import Optional -from uuid import UUID, uuid4 +from uuid import UUID from sqlmodel import SQLModel @@ -32,4 +32,4 @@ class ValidatorUpdate(SQLModel): class ValidatorResponse(ValidatorBase): id: UUID org_id: int - project_id: Optional[int] = None + project_id: int From cfe8a849e14dbbe1b9cd0817b810254fff3b33a4 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 15:57:28 +0530 Subject: [PATCH 04/10] Updated tests --- backend/app/tests/conftest.py | 67 ++-- .../tests/test_guardrails_api_integration.py | 4 +- backend/app/tests/test_validator_configs.py | 173 ++++++++++ .../test_validator_configs_integration.py | 302 ++++++++++++++++++ backend/app/tests/utils/constants.py | 2 +- 5 files changed, 508 insertions(+), 40 deletions(-) create mode 100644 backend/app/tests/test_validator_configs.py create mode 100644 backend/app/tests/test_validator_configs_integration.py diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index f76b33d..2ff34d0 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -1,50 +1,43 @@ +# conftest.py import os -from unittest.mock import MagicMock +os.environ["ENVIRONMENT"] = "testing" import pytest from fastapi.testclient import TestClient +from sqlmodel import Session, create_engine, SQLModel -# MUST be set before app import -os.environ["ENVIRONMENT"] = "testing" - -from app.api.deps import SessionDep, verify_bearer_token -from app.api.routes import guardrails from app.main import app +from app.api.deps import SessionDep, verify_bearer_token +from app.core.config import settings + +test_engine = create_engine( + str(settings.SQLALCHEMY_DATABASE_URI), + echo=False, + pool_pre_ping=True, +) + +def override_session(): + with Session(test_engine) as session: + yield session + +@pytest.fixture(scope="session", autouse=True) +def setup_test_db(): + SQLModel.metadata.create_all(test_engine) + yield + SQLModel.metadata.drop_all(test_engine) + +@pytest.fixture(scope="function", autouse=True) +def clean_db(): + with Session(test_engine) as session: + for table in reversed(SQLModel.metadata.sorted_tables): + session.exec(table.delete()) + session.commit() @pytest.fixture(scope="function", autouse=True) -def override_dependencies(monkeypatch): - """ - Override ALL external dependencies: - - Auth - - DB session - - CRUDs - """ - - # ---- Auth override ---- +def override_dependencies(): app.dependency_overrides[verify_bearer_token] = lambda: True - # ---- DB session override ---- - mock_session = MagicMock() - app.dependency_overrides[SessionDep] = lambda: mock_session - - # ---- CRUD override ---- - mock_request_log_crud = MagicMock() - mock_request_log_crud.create.return_value = MagicMock(id=1) - mock_request_log_crud.update.return_value = None - - mock_validator_log_crud = MagicMock() - mock_validator_log_crud.create.return_value = None - - monkeypatch.setattr( - guardrails, - "RequestLogCrud", - lambda session: mock_request_log_crud, - ) - monkeypatch.setattr( - guardrails, - "ValidatorLogCrud", - lambda session: mock_validator_log_crud, - ) + app.dependency_overrides[SessionDep] = override_session yield diff --git a/backend/app/tests/test_guardrails_api_integration.py b/backend/app/tests/test_guardrails_api_integration.py index 48ca1e8..485b532 100644 --- a/backend/app/tests/test_guardrails_api_integration.py +++ b/backend/app/tests/test_guardrails_api_integration.py @@ -70,7 +70,7 @@ def test_input_guardrails_with_lexical_slur(integration_client): body = response.json() assert body["success"] is True - assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]" + assert body["data"][SAFE_TEXT_FIELD] == "this sentence contains [REDACTED_SLUR]." def test_input_guardrails_with_lexical_slur_clean_text(integration_client): @@ -123,7 +123,7 @@ def test_input_guardrails_with_multiple_validators(integration_client): assert body["success"] is True assert ( body["data"][SAFE_TEXT_FIELD] - == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus" + == "this sentence contains [REDACTED_SLUR] cause i want a s done to kill the female foetus." ) diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py new file mode 100644 index 0000000..08c82ac --- /dev/null +++ b/backend/app/tests/test_validator_configs.py @@ -0,0 +1,173 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from sqlmodel import Session, delete + +from app.api.routes.validator_configs import ( + flatten_validator, + get_validator_or_404, + update_validator_config, +) +from app.core.enum import GuardrailOnFail, Stage, ValidatorType +from app.core.db import engine +from app.models.config.validator_config_table import ValidatorConfig + +# Test data constants +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 1 +TEST_VALIDATOR_ID = uuid.uuid4() +TEST_TYPE = ValidatorType.LexicalSlur +TEST_STAGE = Stage.Input +TEST_ON_FAIL = GuardrailOnFail.Fix + + +@pytest.fixture +def clear_database(): + """Clear ValidatorConfig table before and after each test.""" + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + yield + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + +@pytest.fixture +def mock_session(): + """Create a mock session for database operations.""" + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_validator(): + """Create a sample validator config for testing.""" + return ValidatorConfig( + id=TEST_VALIDATOR_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={"severity": "all", "languages": ["en", "hi"]}, + ) + + +class TestFlattenValidator: + """Tests for flatten_validator helper function.""" + + def test_flatten_validator_includes_config_fields(self, sample_validator): + """Test that flatten_validator includes config fields in output.""" + result = flatten_validator(sample_validator) + + assert result["id"] == TEST_VALIDATOR_ID + assert result["org_id"] == TEST_ORG_ID + assert result["project_id"] == TEST_PROJECT_ID + assert result["type"] == TEST_TYPE + assert result["severity"] == "all" + assert result["languages"] == ["en", "hi"] + + def test_flatten_validator_with_empty_config(self): + """Test flatten_validator with empty config dict.""" + validator = ValidatorConfig( + id=TEST_VALIDATOR_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config={}, + ) + + result = flatten_validator(validator) + + assert result["id"] == TEST_VALIDATOR_ID + assert "severity" not in result + # Base fields: id, org_id, project_id, type, stage, on_fail_action, is_enabled, created_at, updated_at + assert len(result) == 9 + + def test_flatten_validator_with_none_config(self): + """Test flatten_validator with None config.""" + validator = ValidatorConfig( + id=TEST_VALIDATOR_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + type=TEST_TYPE, + stage=TEST_STAGE, + on_fail_action=TEST_ON_FAIL, + is_enabled=True, + config=None, + ) + + result = flatten_validator(validator) + + assert result["id"] == TEST_VALIDATOR_ID + assert "severity" not in result + + +class TestGetValidatorOr404: + """Tests for get_validator_or_404 helper function.""" + + def test_get_validator_success(self, sample_validator, mock_session): + """Test successful validator retrieval.""" + mock_session.query.return_value.filter.return_value.first.return_value = ( + sample_validator + ) + + result = get_validator_or_404( + TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session + ) + + assert result == sample_validator + mock_session.query.assert_called_once_with(ValidatorConfig) + + def test_get_validator_not_found(self, mock_session): + """Test validator not found raises 404.""" + mock_session.query.return_value.filter.return_value.first.return_value = None + + with pytest.raises(Exception) as exc_info: + get_validator_or_404(TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session) + + assert "404" in str(exc_info.value) + + +class TestUpdateValidatorConfig: + """Tests for update_validator_config helper function.""" + + def test_update_validator_config_base_fields(self, sample_validator, mock_session): + """Test updating base validator fields.""" + update_data = { + "type": ValidatorType.PIIRemover, + "on_fail_action": GuardrailOnFail.Exception, + } + + result = update_validator_config(sample_validator, update_data, mock_session) + + assert result.type == ValidatorType.PIIRemover + assert result.on_fail_action == GuardrailOnFail.Exception + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + def test_update_validator_config_extra_fields(self, sample_validator, mock_session): + """Test updating extra config fields.""" + update_data = {"severity": "high", "new_field": "new_value"} + + result = update_validator_config(sample_validator, update_data, mock_session) + + assert result.config["severity"] == "high" + assert result.config["new_field"] == "new_value" + assert result.config["languages"] == ["en", "hi"] # Original values preserved + + def test_update_validator_merges_config(self, sample_validator, mock_session): + """Test that updating config merges with existing config.""" + sample_validator.config = {"severity": "all", "languages": ["en"]} + update_data = {"languages": ["en", "hi", "mr"]} + + result = update_validator_config(sample_validator, update_data, mock_session) + + assert result.config["languages"] == ["en", "hi", "mr"] + assert result.config["severity"] == "all" + diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py new file mode 100644 index 0000000..effffd5 --- /dev/null +++ b/backend/app/tests/test_validator_configs_integration.py @@ -0,0 +1,302 @@ +import uuid + +import pytest +from sqlalchemy.exc import OperationalError +from sqlmodel import Session, delete + +from app.core.db import engine +from app.models.config.validator_config_table import ValidatorConfig + +pytestmark = pytest.mark.integration + +# Test data constants +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 1 +BASE_URL = "/api/v1/guardrails/validators/configs/" +DEFAULT_QUERY_PARAMS = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + +VALIDATOR_PAYLOADS = { + "lexical_slur": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + "severity": "all", + "languages": ["en", "hi"], + }, + "pii_remover_input": { + "type": "pii_remover", + "stage": "input", + "on_fail_action": "fix", + }, + "pii_remover_output": { + "type": "pii_remover", + "stage": "output", + "on_fail_action": "fix", + }, + "minimal": { + "type": "uli_slur_match", + "stage": "input", + "on_fail_action": "fix", + }, +} + + +@pytest.fixture +def clear_database(): + """Clear ValidatorConfig table before and after each test.""" + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + yield + + with Session(engine) as session: + session.exec(delete(ValidatorConfig)) + session.commit() + + +class BaseValidatorTest: + """Base class with helper methods for validator tests.""" + + def create_validator(self, client, payload_key="minimal", **kwargs): + """Helper to create a validator.""" + payload = {**VALIDATOR_PAYLOADS[payload_key], **kwargs} + return client.post(f"{BASE_URL}{DEFAULT_QUERY_PARAMS}", json=payload) + + def get_validator(self, client, validator_id): + """Helper to get a specific validator.""" + return client.get(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}") + + def list_validators(self, client, **query_params): + """Helper to list validators with optional filters.""" + params_str = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + if query_params: + params_str += "&" + "&".join(f"{k}={v}" for k, v in query_params.items()) + return client.get(f"{BASE_URL}{params_str}") + + def update_validator(self, client, validator_id, payload): + """Helper to update a validator.""" + return client.patch(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}", json=payload) + + def delete_validator(self, client, validator_id): + """Helper to delete a validator.""" + return client.delete(f"{BASE_URL}{validator_id}/{DEFAULT_QUERY_PARAMS}") + + +class TestCreateValidator(BaseValidatorTest): + """Tests for POST /guardrails/validators/configs endpoint.""" + + def test_create_validator_success(self, integration_client, clear_database): + """Test successful validator creation.""" + response = self.create_validator(integration_client, "lexical_slur") + + assert response.status_code == 200 + data = response.json() + assert data["type"] == "uli_slur_match" + assert data["stage"] == "input" + assert data["severity"] == "all" + assert data["languages"] == ["en", "hi"] + assert "id" in data + + def test_create_validator_duplicate_raises_400(self, integration_client, clear_database): + """Test that creating duplicate validator raises 400.""" + # First request should succeed + response1 = self.create_validator(integration_client, "minimal") + assert response1.status_code == 200 + + # Second request with same unique keys should fail + response2 = self.create_validator(integration_client, "minimal") + assert response2.status_code == 400 + + def test_create_validator_missing_required_fields(self, integration_client, clear_database): + """Test that missing required fields returns validation error.""" + response = integration_client.post( + f"{BASE_URL}{DEFAULT_QUERY_PARAMS}", + json={"type": "uli_slur_match"}, + ) + + assert response.status_code == 422 + + +class TestListValidators(BaseValidatorTest): + """Tests for GET /guardrails/validators/configs endpoint.""" + + def test_list_validators_success(self, integration_client, clear_database): + """Test successful validator listing.""" + # Create validators first + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_input") + + response = self.list_validators(integration_client) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + def test_list_validators_filter_by_stage(self, integration_client, clear_database): + """Test filtering validators by stage.""" + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_output") + + response = self.list_validators(integration_client, stage="input") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["stage"] == "input" + + def test_list_validators_filter_by_type(self, integration_client, clear_database): + """Test filtering validators by type.""" + self.create_validator(integration_client, "lexical_slur") + self.create_validator(integration_client, "pii_remover_input") + + response = self.list_validators(integration_client, type="pii_remover") + + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["type"] == "pii_remover" + + def test_list_validators_empty(self, integration_client, clear_database): + """Test listing validators when none exist.""" + response = integration_client.get( + f"{BASE_URL}?org_id=999&project_id=999", + ) + + assert response.status_code == 200 + data = response.json() + assert len(data) == 0 + + +class TestGetValidator(BaseValidatorTest): + """Tests for GET /guardrails/validators/configs/{id} endpoint.""" + + def test_get_validator_success(self, integration_client, clear_database): + """Test successful validator retrieval.""" + # Create a validator + create_response = self.create_validator( + integration_client, "lexical_slur", severity="all" + ) + validator_id = create_response.json()["id"] + + # Retrieve it + response = self.get_validator(integration_client, validator_id) + + assert response.status_code == 200 + data = response.json() + assert data["id"] == validator_id + assert data["severity"] == "all" + + def test_get_validator_not_found(self, integration_client, clear_database): + """Test retrieving non-existent validator returns 404.""" + fake_id = uuid.uuid4() + response = self.get_validator(integration_client, fake_id) + + assert response.status_code == 404 + + def test_get_validator_wrong_org(self, integration_client, clear_database): + """Test that accessing validator from different org returns 404.""" + # Create a validator for org 1 + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Try to access it as different org + response = integration_client.get( + f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + ) + + assert response.status_code == 404 + + +class TestUpdateValidator(BaseValidatorTest): + """Tests for PATCH /guardrails/validators/configs/{id} endpoint.""" + + def test_update_validator_success(self, integration_client, clear_database): + """Test successful validator update.""" + # Create a validator + create_response = self.create_validator( + integration_client, "lexical_slur", severity="all" + ) + validator_id = create_response.json()["id"] + + # Update it + update_payload = {"on_fail_action": "exception", "severity": "high"} + response = self.update_validator(integration_client, validator_id, update_payload) + + assert response.status_code == 200 + data = response.json() + assert data["on_fail_action"] == "exception" + assert data["severity"] == "high" + + def test_update_validator_partial(self, integration_client, clear_database): + """Test partial update preserves original fields.""" + # Create a validator + create_response = self.create_validator( + integration_client, + "lexical_slur", + severity="all", + languages=["en", "hi"], + ) + validator_id = create_response.json()["id"] + + # Update only one field + update_payload = {"severity": "low"} + response = self.update_validator(integration_client, validator_id, update_payload) + + assert response.status_code == 200 + data = response.json() + assert data["severity"] == "low" + assert data["languages"] == ["en", "hi"] # Original preserved + + def test_update_validator_not_found(self, integration_client, clear_database): + """Test updating non-existent validator returns 404.""" + fake_id = uuid.uuid4() + update_payload = {"severity": "low"} + + response = self.update_validator(integration_client, fake_id, update_payload) + + assert response.status_code == 404 + + +class TestDeleteValidator(BaseValidatorTest): + """Tests for DELETE /guardrails/validators/configs/{id} endpoint.""" + + def test_delete_validator_success(self, integration_client, clear_database): + """Test successful validator deletion.""" + # Create a validator + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Delete it + response = self.delete_validator(integration_client, validator_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + # Verify it's deleted + get_response = self.get_validator(integration_client, validator_id) + assert get_response.status_code == 404 + + def test_delete_validator_not_found(self, integration_client, clear_database): + """Test deleting non-existent validator returns 404.""" + fake_id = uuid.uuid4() + response = self.delete_validator(integration_client, fake_id) + + assert response.status_code == 404 + + def test_delete_validator_wrong_org(self, integration_client, clear_database): + """Test that deleting validator from different org returns 404.""" + # Create a validator for org 1 + create_response = self.create_validator(integration_client, "minimal") + validator_id = create_response.json()["id"] + + # Try to delete it as different org + response = integration_client.delete( + f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + ) + + assert response.status_code == 404 + + # Verify original is still there + get_response = self.get_validator(integration_client, validator_id) + assert get_response.status_code == 200 diff --git a/backend/app/tests/utils/constants.py b/backend/app/tests/utils/constants.py index 568bc19..e642c67 100644 --- a/backend/app/tests/utils/constants.py +++ b/backend/app/tests/utils/constants.py @@ -1,2 +1,2 @@ -VALIDATE_API_PATH = "/api/v1/guardrails/validate/" +VALIDATE_API_PATH = "/api/v1/guardrails/" SAFE_TEXT_FIELD = "safe_text" From 3a0fa811e83fed86305927d7d0539b8de4e84bde Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 16:47:54 +0530 Subject: [PATCH 05/10] resolved comments --- backend/app/api/routes/validator_configs.py | 118 +++-------------- backend/app/crud/validator_config_crud.py | 109 ++++++++++++++++ backend/app/tests/test_validator_configs.py | 132 ++++++++------------ 3 files changed, 180 insertions(+), 179 deletions(-) create mode 100644 backend/app/crud/validator_config_crud.py diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index e35ce1c..2829a01 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -1,20 +1,21 @@ from typing import List, Optional from uuid import UUID -from fastapi import APIRouter, HTTPException -from sqlalchemy.exc import IntegrityError -from sqlmodel import select +from fastapi import APIRouter from app.api.deps import AuthDep, SessionDep -from app.models.config.validator_config_table import ValidatorConfig from app.schemas.validator_config import * -from app.utils import split_validator_payload +from app.crud.validator_config_crud import validator_config_crud -router = APIRouter(prefix="/guardrails/validators/configs", tags=["validator configs"]) + +router = APIRouter( + prefix="/guardrails/validators/configs", + tags=["validator configs"], +) @router.post( - "/", + "/", response_model=ValidatorResponse ) async def create_validator( @@ -24,28 +25,7 @@ async def create_validator( project_id: int, _: AuthDep, ): - data = payload.model_dump() - base, config = split_validator_payload(data) - obj = ValidatorConfig( - org_id=org_id, - project_id=project_id, - config=config, - **base, - ) - - session.add(obj) - - try: - session.commit() - except IntegrityError: - session.rollback() - raise HTTPException( - status_code=400, - detail="Validator already exists for this type and stage", - ) - - session.refresh(obj) - return flatten_validator(obj) + return validator_config_crud.create(session, org_id, project_id, payload) @router.get( "/", @@ -59,19 +39,7 @@ async def list_validators( stage: Optional[Stage] = None, type: Optional[ValidatorType] = None, ): - query = select(ValidatorConfig).where( - ValidatorConfig.org_id == org_id, - ValidatorConfig.project_id == project_id - ) - - if stage: - query = query.where(ValidatorConfig.stage == stage) - - if type: - query = query.where(ValidatorConfig.type == type) - - rows = session.exec(query).all() - return [flatten_validator(r) for r in rows] + return validator_config_crud.list(session, org_id, project_id, stage, type) @router.get( @@ -85,8 +53,8 @@ async def get_validator( session: SessionDep, _: AuthDep, ): - obj = get_validator_or_404(id, org_id, project_id, session) - return flatten_validator(obj) + obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + return validator_config_crud._flatten(obj) @router.patch( @@ -101,13 +69,12 @@ async def update_validator( session: SessionDep, _: AuthDep, ): - obj = get_validator_or_404(id, org_id, project_id, session) - updated_obj = update_validator_config( + obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + return validator_config_crud.update( + session, obj, payload.model_dump(exclude_unset=True), - session ) - return flatten_validator(updated_obj) @router.delete("/{id}") @@ -118,57 +85,6 @@ async def delete_validator( session: SessionDep, _: AuthDep, ): - obj = get_validator_or_404(id, org_id, project_id, session) - session.delete(obj) - session.commit() + obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + validator_config_crud.delete(session, obj) return {"success": True} - -def flatten_validator(row: ValidatorConfig) -> dict: - """ - Flatten validator config: combines base fields with config dict. - Returns a dict with all fields including config extras. - """ - base = row.model_dump(exclude={"config"}) - flattened = {**base, **(row.config or {})} - print("FLATTENED:", flattened) - return flattened - - -def get_validator_or_404( - id: UUID, - org_id: int, - project_id: int, - session: SessionDep, -) -> ValidatorConfig: - """Fetch validator by id, org_id, and project_id, or raise 404.""" - obj = session.query(ValidatorConfig).filter( - ValidatorConfig.id == id, - ValidatorConfig.org_id == org_id, - ValidatorConfig.project_id == project_id - ).first() - - if not obj: - raise HTTPException(status_code=404, detail="Validator not found") - - return obj - - -def update_validator_config( - obj: ValidatorConfig, - update_data: dict, - session: SessionDep, -) -> ValidatorConfig: - """Update validator config fields and return the updated object.""" - base, config = split_validator_payload(update_data) - - for k, v in base.items(): - setattr(obj, k, v) - - if config: - obj.config = {**(obj.config or {}), **config} - - session.add(obj) - session.commit() - session.refresh(obj) - - return obj diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py new file mode 100644 index 0000000..803dcf3 --- /dev/null +++ b/backend/app/crud/validator_config_crud.py @@ -0,0 +1,109 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select + +from app.models.config.validator_config_table import ValidatorConfig +from app.schemas.validator_config import Stage, ValidatorType +from app.utils import split_validator_payload + + +class ValidatorConfigCrud: + def create( + self, + session: Session, + org_id: int, + project_id: int, + payload + ): + data = payload.model_dump() + base, config = split_validator_payload(data) + + obj = ValidatorConfig( + org_id=org_id, + project_id=project_id, + config=config, + **base, + ) + + session.add(obj) + + try: + session.commit() + except IntegrityError: + session.rollback() + raise HTTPException( + 400, + "Validator already exists for this type and stage", + ) + + session.refresh(obj) + return self._flatten(obj) + + def list( + self, + session: Session, + org_id: int, + project_id: int, + stage: Optional[Stage] = None, + type: Optional[ValidatorType] = None, + ) -> List[dict]: + query = select(ValidatorConfig).where( + ValidatorConfig.org_id == org_id, + ValidatorConfig.project_id == project_id, + ) + + if stage: + query = query.where(ValidatorConfig.stage == stage) + + if type: + query = query.where(ValidatorConfig.type == type) + + rows = session.exec(query).all() + return [self._flatten(r) for r in rows] + + def get_or_404( + self, + session: Session, + id: UUID, + org_id: int, + project_id: int, + ) -> ValidatorConfig: + obj = session.get(ValidatorConfig, id) + + if not obj or obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(404, "Validator not found") + + return obj + + def update( + self, + session: Session, + obj: ValidatorConfig, + update_data: dict + ): + base, config = split_validator_payload(update_data) + + for k, v in base.items(): + setattr(obj, k, v) + + if config: + obj.config = {**(obj.config or {}), **config} + + session.commit() + session.refresh(obj) + + return self._flatten(obj) + + def delete(self, session: Session, obj: ValidatorConfig): + session.delete(obj) + session.commit() + + def _flatten(self, row: ValidatorConfig) -> dict: + base = row.model_dump(exclude={"config"}) + return {**base, **(row.config or {})} + + +validator_config_crud = ValidatorConfigCrud() diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 08c82ac..51a0739 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -1,14 +1,10 @@ import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from sqlmodel import Session, delete -from app.api.routes.validator_configs import ( - flatten_validator, - get_validator_or_404, - update_validator_config, -) +from app.crud.validator_config_crud import validator_config_crud from app.core.enum import GuardrailOnFail, Stage, ValidatorType from app.core.db import engine from app.models.config.validator_config_table import ValidatorConfig @@ -55,42 +51,15 @@ def sample_validator(): ) -class TestFlattenValidator: - """Tests for flatten_validator helper function.""" +class TestFlatten: + def test_flatten_includes_config_fields(self, sample_validator): + result = validator_config_crud._flatten(sample_validator) - def test_flatten_validator_includes_config_fields(self, sample_validator): - """Test that flatten_validator includes config fields in output.""" - result = flatten_validator(sample_validator) - - assert result["id"] == TEST_VALIDATOR_ID - assert result["org_id"] == TEST_ORG_ID - assert result["project_id"] == TEST_PROJECT_ID - assert result["type"] == TEST_TYPE assert result["severity"] == "all" assert result["languages"] == ["en", "hi"] - - def test_flatten_validator_with_empty_config(self): - """Test flatten_validator with empty config dict.""" - validator = ValidatorConfig( - id=TEST_VALIDATOR_ID, - org_id=TEST_ORG_ID, - project_id=TEST_PROJECT_ID, - type=TEST_TYPE, - stage=TEST_STAGE, - on_fail_action=TEST_ON_FAIL, - is_enabled=True, - config={}, - ) - - result = flatten_validator(validator) - assert result["id"] == TEST_VALIDATOR_ID - assert "severity" not in result - # Base fields: id, org_id, project_id, type, stage, on_fail_action, is_enabled, created_at, updated_at - assert len(result) == 9 - def test_flatten_validator_with_none_config(self): - """Test flatten_validator with None config.""" + def test_flatten_empty_config(self): validator = ValidatorConfig( id=TEST_VALIDATOR_ID, org_id=TEST_ORG_ID, @@ -99,75 +68,82 @@ def test_flatten_validator_with_none_config(self): stage=TEST_STAGE, on_fail_action=TEST_ON_FAIL, is_enabled=True, - config=None, + config={}, ) - result = flatten_validator(validator) + result = validator_config_crud._flatten(validator) - assert result["id"] == TEST_VALIDATOR_ID assert "severity" not in result -class TestGetValidatorOr404: - """Tests for get_validator_or_404 helper function.""" +class TestGetOr404: + def test_success(self, sample_validator, mock_session): + mock_session.get.return_value = sample_validator - def test_get_validator_success(self, sample_validator, mock_session): - """Test successful validator retrieval.""" - mock_session.query.return_value.filter.return_value.first.return_value = ( - sample_validator - ) - - result = get_validator_or_404( - TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session + result = validator_config_crud.get_or_404( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORG_ID, + TEST_PROJECT_ID, ) assert result == sample_validator - mock_session.query.assert_called_once_with(ValidatorConfig) - - def test_get_validator_not_found(self, mock_session): - """Test validator not found raises 404.""" - mock_session.query.return_value.filter.return_value.first.return_value = None + mock_session.get.assert_called_once() - with pytest.raises(Exception) as exc_info: - get_validator_or_404(TEST_VALIDATOR_ID, TEST_ORG_ID, TEST_PROJECT_ID, mock_session) + def test_not_found(self, mock_session): + mock_session.get.return_value = None - assert "404" in str(exc_info.value) + with pytest.raises(Exception) as exc: + validator_config_crud.get_or_404( + mock_session, + TEST_VALIDATOR_ID, + TEST_ORG_ID, + TEST_PROJECT_ID, + ) + assert "Validator not found" in str(exc.value) -class TestUpdateValidatorConfig: - """Tests for update_validator_config helper function.""" - def test_update_validator_config_base_fields(self, sample_validator, mock_session): - """Test updating base validator fields.""" +class TestUpdate: + def test_update_base_fields(self, sample_validator, mock_session): update_data = { "type": ValidatorType.PIIRemover, "on_fail_action": GuardrailOnFail.Exception, } - result = update_validator_config(sample_validator, update_data, mock_session) + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) + + assert result["type"] == ValidatorType.PIIRemover + assert result["on_fail_action"] == GuardrailOnFail.Exception - assert result.type == ValidatorType.PIIRemover - assert result.on_fail_action == GuardrailOnFail.Exception mock_session.commit.assert_called_once() mock_session.refresh.assert_called_once() - def test_update_validator_config_extra_fields(self, sample_validator, mock_session): - """Test updating extra config fields.""" + def test_update_extra_fields(self, sample_validator, mock_session): update_data = {"severity": "high", "new_field": "new_value"} - result = update_validator_config(sample_validator, update_data, mock_session) + result = validator_config_crud.update( + mock_session, + sample_validator, + update_data, + ) - assert result.config["severity"] == "high" - assert result.config["new_field"] == "new_value" - assert result.config["languages"] == ["en", "hi"] # Original values preserved + assert result["severity"] == "high" + assert result["new_field"] == "new_value" + assert result["languages"] == ["en", "hi"] - def test_update_validator_merges_config(self, sample_validator, mock_session): - """Test that updating config merges with existing config.""" + def test_merge_config(self, sample_validator, mock_session): sample_validator.config = {"severity": "all", "languages": ["en"]} - update_data = {"languages": ["en", "hi", "mr"]} - - result = update_validator_config(sample_validator, update_data, mock_session) - assert result.config["languages"] == ["en", "hi", "mr"] - assert result.config["severity"] == "all" + result = validator_config_crud.update( + mock_session, + sample_validator, + {"languages": ["en", "hi"]}, + ) + assert result["languages"] == ["en", "hi"] + assert result["severity"] == "all" From 0d2d609efb2a92a1512bbcfcc88ff065c958fb5e Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:05:07 +0530 Subject: [PATCH 06/10] removed .env.test --- .env.test | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 .env.test diff --git a/.env.test b/.env.test deleted file mode 100644 index 40e9ad8..0000000 --- a/.env.test +++ /dev/null @@ -1,31 +0,0 @@ -DOMAIN=localhost - -ENVIRONMENT=testing - -PROJECT_NAME="Kaapi-Guardrails" -STACK_NAME=Kaapi-Guardrails - -# API Base URL for cron scripts (defaults to http://localhost:8000 if not set) -API_BASE_URL=http://localhost:8000 - -# Postgres -POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails -POSTGRES_USER=postgres -POSTGRES_PASSWORD=postgres - -SENTRY_DSN= - -# Configure these with your own Docker registry images - -DOCKER_IMAGE_BACKEND=kaapi-guardrails-backend - -# Callback Timeouts (in seconds) -CALLBACK_CONNECT_TIMEOUT=3 -CALLBACK_READ_TIMEOUT=10 - -# require as a env if you want to use doc transformation -OPENAI_API_KEY="" -GUARDRAILS_HUB_API_KEY="" -AUTH_TOKEN="" \ No newline at end of file From d445ad96d06be21fae3bc30268c97d3c9c8a09e3 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:17:32 +0530 Subject: [PATCH 07/10] resolved comments --- .../versions/003_added_validator_config.py | 2 +- backend/app/api/routes/validator_configs.py | 9 ++++---- backend/app/crud/validator_config_crud.py | 21 ++++++++++--------- .../models/config/validator_config_table.py | 7 +++++-- backend/app/tests/conftest.py | 2 +- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py index 2ee47d3..29a5a06 100644 --- a/backend/app/alembic/versions/003_added_validator_config.py +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -1,7 +1,7 @@ """Added validator_config table Revision ID: 003 -Revises: 001 +Revises: 002 Create Date: 2026-02-05 09:42:54.128852 """ diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index 2829a01..755cf82 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -1,10 +1,11 @@ -from typing import List, Optional +from typing import Optional from uuid import UUID from fastapi import APIRouter from app.api.deps import AuthDep, SessionDep -from app.schemas.validator_config import * +from app.core.enum import Stage, ValidatorType +from app.schemas.validator_config import ValidatorCreate, ValidatorResponse, ValidatorUpdate from app.crud.validator_config_crud import validator_config_crud @@ -29,7 +30,7 @@ async def create_validator( @router.get( "/", - response_model=List[ValidatorResponse] + response_model=list[ValidatorResponse] ) async def list_validators( org_id: int, @@ -54,7 +55,7 @@ async def get_validator( _: AuthDep, ): obj = validator_config_crud.get_or_404(session, id, org_id, project_id) - return validator_config_crud._flatten(obj) + return validator_config_crud.flatten(obj) @router.patch( diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py index 803dcf3..73afd33 100644 --- a/backend/app/crud/validator_config_crud.py +++ b/backend/app/crud/validator_config_crud.py @@ -1,13 +1,14 @@ -from typing import List, Optional +from typing import Optional from uuid import UUID from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select +from app.core.enum import Stage, ValidatorType from app.models.config.validator_config_table import ValidatorConfig -from app.schemas.validator_config import Stage, ValidatorType -from app.utils import split_validator_payload +from app.schemas.validator_config import ValidatorCreate +from app.utils import now, split_validator_payload class ValidatorConfigCrud: @@ -16,7 +17,7 @@ def create( session: Session, org_id: int, project_id: int, - payload + payload: ValidatorCreate ): data = payload.model_dump() base, config = split_validator_payload(data) @@ -40,7 +41,7 @@ def create( ) session.refresh(obj) - return self._flatten(obj) + return self.flatten(obj) def list( self, @@ -49,7 +50,7 @@ def list( project_id: int, stage: Optional[Stage] = None, type: Optional[ValidatorType] = None, - ) -> List[dict]: + ) -> list[dict]: query = select(ValidatorConfig).where( ValidatorConfig.org_id == org_id, ValidatorConfig.project_id == project_id, @@ -62,7 +63,7 @@ def list( query = query.where(ValidatorConfig.type == type) rows = session.exec(query).all() - return [self._flatten(r) for r in rows] + return [self.flatten(r) for r in rows] def get_or_404( self, @@ -91,17 +92,17 @@ def update( if config: obj.config = {**(obj.config or {}), **config} - + obj.updated_at = now() session.commit() session.refresh(obj) - return self._flatten(obj) + return self.flatten(obj) def delete(self, session: Session, obj: ValidatorConfig): session.delete(obj) session.commit() - def _flatten(self, row: ValidatorConfig) -> dict: + def flatten(self, row: ValidatorConfig) -> dict: base = row.model_dump(exclude={"config"}) return {**base, **(row.config or {})} diff --git a/backend/app/models/config/validator_config_table.py b/backend/app/models/config/validator_config_table.py index 68d57cf..0f33590 100644 --- a/backend/app/models/config/validator_config_table.py +++ b/backend/app/models/config/validator_config_table.py @@ -71,7 +71,10 @@ class ValidatorConfig(SQLModel, table=True): updated_at: datetime = Field( default_factory=now, nullable=False, - sa_column_kwargs={"comment": "Timestamp when the validator config was last updated"}, + sa_column_kwargs={ + "comment": "Timestamp when the validator config was last updated", + "onupdate": now, + }, ) __table_args__ = ( @@ -79,4 +82,4 @@ class ValidatorConfig(SQLModel, table=True): "org_id", "project_id", "type", "stage", name="uq_validator_identity" ), - ) \ No newline at end of file + ) diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 2ff34d0..d595b97 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -30,7 +30,7 @@ def setup_test_db(): def clean_db(): with Session(test_engine) as session: for table in reversed(SQLModel.metadata.sorted_tables): - session.exec(table.delete()) + session.execute(table.delete()) session.commit() @pytest.fixture(scope="function", autouse=True) From 59deffe0dc11cdb46f11a4392035d0feaba84eee Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:50:10 +0530 Subject: [PATCH 08/10] resolved comments --- backend/app/tests/test_validator_configs.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 51a0739..e014b5c 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock import pytest -from sqlmodel import Session, delete +from sqlmodel import Session from app.crud.validator_config_crud import validator_config_crud from app.core.enum import GuardrailOnFail, Stage, ValidatorType @@ -17,19 +17,6 @@ TEST_STAGE = Stage.Input TEST_ON_FAIL = GuardrailOnFail.Fix - -@pytest.fixture -def clear_database(): - """Clear ValidatorConfig table before and after each test.""" - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - yield - with Session(engine) as session: - session.exec(delete(ValidatorConfig)) - session.commit() - - @pytest.fixture def mock_session(): """Create a mock session for database operations.""" @@ -53,7 +40,7 @@ def sample_validator(): class TestFlatten: def test_flatten_includes_config_fields(self, sample_validator): - result = validator_config_crud._flatten(sample_validator) + result = validator_config_crud.flatten(sample_validator) assert result["severity"] == "all" assert result["languages"] == ["en", "hi"] @@ -71,7 +58,7 @@ def test_flatten_empty_config(self): config={}, ) - result = validator_config_crud._flatten(validator) + result = validator_config_crud.flatten(validator) assert "severity" not in result From 97dae6c09d5ae75823711de641f673270eb38395 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Thu, 5 Feb 2026 17:57:11 +0530 Subject: [PATCH 09/10] resolved comments --- .env.test.example | 2 +- backend/app/crud/validator_config_crud.py | 3 ++- backend/app/tests/test_validator_configs.py | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.env.test.example b/.env.test.example index 40e9ad8..b7e4927 100644 --- a/.env.test.example +++ b/.env.test.example @@ -11,7 +11,7 @@ API_BASE_URL=http://localhost:8000 # Postgres POSTGRES_SERVER=localhost POSTGRES_PORT=5432 -POSTGRES_DB=kaapi-guardrails +POSTGRES_DB=kaapi_guardrails_testing POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py index 73afd33..8663fdd 100644 --- a/backend/app/crud/validator_config_crud.py +++ b/backend/app/crud/validator_config_crud.py @@ -92,7 +92,8 @@ def update( if config: obj.config = {**(obj.config or {}), **config} - obj.updated_at = now() + + obj.updated_at = now() session.commit() session.refresh(obj) diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index e014b5c..0d26ec8 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -6,7 +6,6 @@ from app.crud.validator_config_crud import validator_config_crud from app.core.enum import GuardrailOnFail, Stage, ValidatorType -from app.core.db import engine from app.models.config.validator_config_table import ValidatorConfig # Test data constants From 29431e0de8b5798089718b8625192ac330870c44 Mon Sep 17 00:00:00 2001 From: rkritika1508 Date: Fri, 6 Feb 2026 17:03:42 +0530 Subject: [PATCH 10/10] resolved comments --- .env.test.example | 2 +- .../alembic/versions/001_added_request_log.py | 2 +- .../versions/002_added_validator_log.py | 2 +- .../versions/003_added_validator_config.py | 8 +++--- backend/app/api/routes/validator_configs.py | 21 +++++++-------- backend/app/core/constants.py | 2 +- backend/app/crud/__init__.py | 2 +- backend/app/crud/validator_config_crud.py | 26 +++++++++---------- .../models/config/validator_config_table.py | 4 +-- backend/app/schemas/validator_config.py | 5 +++- backend/app/tests/test_validator_configs.py | 10 +++---- .../test_validator_configs_integration.py | 13 +++++----- backend/app/utils.py | 14 +++++----- 13 files changed, 56 insertions(+), 55 deletions(-) diff --git a/.env.test.example b/.env.test.example index b7e4927..368b7d7 100644 --- a/.env.test.example +++ b/.env.test.example @@ -10,8 +10,8 @@ API_BASE_URL=http://localhost:8000 # Postgres POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 POSTGRES_DB=kaapi_guardrails_testing +POSTGRES_PORT=5432 POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres diff --git a/backend/app/alembic/versions/001_added_request_log.py b/backend/app/alembic/versions/001_added_request_log.py index 706d504..5f2e404 100644 --- a/backend/app/alembic/versions/001_added_request_log.py +++ b/backend/app/alembic/versions/001_added_request_log.py @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. revision: str = '001' -down_revision: Union[str, Sequence[str], None] = None +down_revision: str | None = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/app/alembic/versions/002_added_validator_log.py b/backend/app/alembic/versions/002_added_validator_log.py index d46513f..6b50229 100644 --- a/backend/app/alembic/versions/002_added_validator_log.py +++ b/backend/app/alembic/versions/002_added_validator_log.py @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. revision: str = '002' -down_revision: Union[str, Sequence[str], None] = '001' +down_revision: str = '001' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/app/alembic/versions/003_added_validator_config.py b/backend/app/alembic/versions/003_added_validator_config.py index 29a5a06..72889b0 100644 --- a/backend/app/alembic/versions/003_added_validator_config.py +++ b/backend/app/alembic/versions/003_added_validator_config.py @@ -13,7 +13,7 @@ # revision identifiers, used by Alembic. revision: str = '003' -down_revision: Union[str, Sequence[str], None] = "002" +down_revision: str = '002' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,7 +21,7 @@ def upgrade() -> None: op.create_table('validator_config', sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('org_id', sa.Integer(), nullable=False), + sa.Column('organization_id', sa.Integer(), nullable=False), sa.Column('project_id', sa.Integer(), nullable=False), sa.Column('type', sa.String(), nullable=False), sa.Column('stage', sa.String(), nullable=False), @@ -37,10 +37,10 @@ def upgrade() -> None: sa.Column('updated_at', sa.DateTime(), nullable=False), sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('org_id', 'project_id', 'type', 'stage', name='uq_validator_identity') + sa.UniqueConstraint('organization_id', 'project_id', 'type', 'stage', name='uq_validator_identity') ) - op.create_index("idx_validator_org", "validator_config", ["org_id"]) + op.create_index("idx_validator_organization", "validator_config", ["organization_id"]) op.create_index("idx_validator_project", "validator_config", ["project_id"]) op.create_index("idx_validator_type", "validator_config", ["type"]) op.create_index("idx_validator_stage", "validator_config", ["stage"]) diff --git a/backend/app/api/routes/validator_configs.py b/backend/app/api/routes/validator_configs.py index 755cf82..feedad9 100644 --- a/backend/app/api/routes/validator_configs.py +++ b/backend/app/api/routes/validator_configs.py @@ -22,25 +22,24 @@ async def create_validator( payload: ValidatorCreate, session: SessionDep, - org_id: int, + organization_id: int, project_id: int, _: AuthDep, ): - return validator_config_crud.create(session, org_id, project_id, payload) - + return validator_config_crud.create(session, organization_id, project_id, payload) @router.get( "/", response_model=list[ValidatorResponse] ) async def list_validators( - org_id: int, + organization_id: int, project_id: int, session: SessionDep, _: AuthDep, stage: Optional[Stage] = None, type: Optional[ValidatorType] = None, ): - return validator_config_crud.list(session, org_id, project_id, stage, type) + return validator_config_crud.list(session, organization_id, project_id, stage, type) @router.get( @@ -49,12 +48,12 @@ async def list_validators( ) async def get_validator( id: UUID, - org_id: int, + organization_id: int, project_id: int, session: SessionDep, _: AuthDep, ): - obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + obj = validator_config_crud.get_or_404(session, id, organization_id, project_id) return validator_config_crud.flatten(obj) @@ -64,13 +63,13 @@ async def get_validator( ) async def update_validator( id: UUID, - org_id: int, + organization_id: int, project_id: int, payload: ValidatorUpdate, session: SessionDep, _: AuthDep, ): - obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + obj = validator_config_crud.get_or_404(session, id, organization_id, project_id) return validator_config_crud.update( session, obj, @@ -81,11 +80,11 @@ async def update_validator( @router.delete("/{id}") async def delete_validator( id: UUID, - org_id: int, + organization_id: int, project_id: int, session: SessionDep, _: AuthDep, ): - obj = validator_config_crud.get_or_404(session, id, org_id, project_id) + obj = validator_config_crud.get_or_404(session, id, organization_id, project_id) validator_config_crud.delete(session, obj) return {"success": True} diff --git a/backend/app/core/constants.py b/backend/app/core/constants.py index 115ad21..6c3825d 100644 --- a/backend/app/core/constants.py +++ b/backend/app/core/constants.py @@ -8,7 +8,7 @@ REPHRASE_ON_FAIL_PREFIX = "Please rephrase the query without unsafe content." VALIDATOR_CONFIG_SYSTEM_FIELDS = { - "org_id", + "organization_id", "project_id", "type", "stage", diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 37bc215..a6ddc7b 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -1 +1 @@ -from app.crud.request_log_repo import RequestLogCrud \ No newline at end of file +from app.crud.request_log_repo import RequestLogCrud diff --git a/backend/app/crud/validator_config_crud.py b/backend/app/crud/validator_config_crud.py index 8663fdd..8b7b4ad 100644 --- a/backend/app/crud/validator_config_crud.py +++ b/backend/app/crud/validator_config_crud.py @@ -15,18 +15,18 @@ class ValidatorConfigCrud: def create( self, session: Session, - org_id: int, + organization_id: int, project_id: int, payload: ValidatorCreate ): data = payload.model_dump() - base, config = split_validator_payload(data) + model_fields, config_fields = split_validator_payload(data) obj = ValidatorConfig( - org_id=org_id, + organization_id=organization_id, project_id=project_id, - config=config, - **base, + config=config_fields, + **model_fields, ) session.add(obj) @@ -46,13 +46,13 @@ def create( def list( self, session: Session, - org_id: int, + organization_id: int, project_id: int, stage: Optional[Stage] = None, type: Optional[ValidatorType] = None, ) -> list[dict]: query = select(ValidatorConfig).where( - ValidatorConfig.org_id == org_id, + ValidatorConfig.organization_id == organization_id, ValidatorConfig.project_id == project_id, ) @@ -69,12 +69,12 @@ def get_or_404( self, session: Session, id: UUID, - org_id: int, + organization_id: int, project_id: int, ) -> ValidatorConfig: obj = session.get(ValidatorConfig, id) - if not obj or obj.org_id != org_id or obj.project_id != project_id: + if not obj or obj.organization_id != organization_id or obj.project_id != project_id: raise HTTPException(404, "Validator not found") return obj @@ -85,13 +85,13 @@ def update( obj: ValidatorConfig, update_data: dict ): - base, config = split_validator_payload(update_data) + model_fields, config_fields = split_validator_payload(update_data) - for k, v in base.items(): + for k, v in model_fields.items(): setattr(obj, k, v) - if config: - obj.config = {**(obj.config or {}), **config} + if config_fields: + obj.config = {**(obj.config or {}), **config_fields} obj.updated_at = now() session.commit() diff --git a/backend/app/models/config/validator_config_table.py b/backend/app/models/config/validator_config_table.py index 0f33590..5c4a7c1 100644 --- a/backend/app/models/config/validator_config_table.py +++ b/backend/app/models/config/validator_config_table.py @@ -19,7 +19,7 @@ class ValidatorConfig(SQLModel, table=True): sa_column_kwargs={"comment": "Unique identifier for the validator configuration"}, ) - org_id: int = Field( + organization_id: int = Field( index=True, sa_column_kwargs={"comment": "Identifier for the organization"}, ) @@ -79,7 +79,7 @@ class ValidatorConfig(SQLModel, table=True): __table_args__ = ( UniqueConstraint( - "org_id", "project_id", "type", "stage", + "organization_id", "project_id", "type", "stage", name="uq_validator_identity" ), ) diff --git a/backend/app/schemas/validator_config.py b/backend/app/schemas/validator_config.py index c8a2ae1..7e693b5 100644 --- a/backend/app/schemas/validator_config.py +++ b/backend/app/schemas/validator_config.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Optional from uuid import UUID @@ -31,5 +32,7 @@ class ValidatorUpdate(SQLModel): class ValidatorResponse(ValidatorBase): id: UUID - org_id: int + organization_id: int project_id: int + created_at: datetime + updated_at: datetime diff --git a/backend/app/tests/test_validator_configs.py b/backend/app/tests/test_validator_configs.py index 0d26ec8..fc3c673 100644 --- a/backend/app/tests/test_validator_configs.py +++ b/backend/app/tests/test_validator_configs.py @@ -9,7 +9,7 @@ from app.models.config.validator_config_table import ValidatorConfig # Test data constants -TEST_ORG_ID = 1 +TEST_ORGANIZATION_ID = 1 TEST_PROJECT_ID = 1 TEST_VALIDATOR_ID = uuid.uuid4() TEST_TYPE = ValidatorType.LexicalSlur @@ -27,7 +27,7 @@ def sample_validator(): """Create a sample validator config for testing.""" return ValidatorConfig( id=TEST_VALIDATOR_ID, - org_id=TEST_ORG_ID, + organization_id=TEST_ORGANIZATION_ID, project_id=TEST_PROJECT_ID, type=TEST_TYPE, stage=TEST_STAGE, @@ -48,7 +48,7 @@ def test_flatten_includes_config_fields(self, sample_validator): def test_flatten_empty_config(self): validator = ValidatorConfig( id=TEST_VALIDATOR_ID, - org_id=TEST_ORG_ID, + organization_id=TEST_ORGANIZATION_ID, project_id=TEST_PROJECT_ID, type=TEST_TYPE, stage=TEST_STAGE, @@ -69,7 +69,7 @@ def test_success(self, sample_validator, mock_session): result = validator_config_crud.get_or_404( mock_session, TEST_VALIDATOR_ID, - TEST_ORG_ID, + TEST_ORGANIZATION_ID, TEST_PROJECT_ID, ) @@ -83,7 +83,7 @@ def test_not_found(self, mock_session): validator_config_crud.get_or_404( mock_session, TEST_VALIDATOR_ID, - TEST_ORG_ID, + TEST_ORGANIZATION_ID, TEST_PROJECT_ID, ) diff --git a/backend/app/tests/test_validator_configs_integration.py b/backend/app/tests/test_validator_configs_integration.py index effffd5..6965796 100644 --- a/backend/app/tests/test_validator_configs_integration.py +++ b/backend/app/tests/test_validator_configs_integration.py @@ -1,7 +1,6 @@ import uuid import pytest -from sqlalchemy.exc import OperationalError from sqlmodel import Session, delete from app.core.db import engine @@ -10,10 +9,10 @@ pytestmark = pytest.mark.integration # Test data constants -TEST_ORG_ID = 1 +TEST_ORGANIZATION_ID = 1 TEST_PROJECT_ID = 1 BASE_URL = "/api/v1/guardrails/validators/configs/" -DEFAULT_QUERY_PARAMS = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" +DEFAULT_QUERY_PARAMS = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" VALIDATOR_PAYLOADS = { "lexical_slur": { @@ -69,7 +68,7 @@ def get_validator(self, client, validator_id): def list_validators(self, client, **query_params): """Helper to list validators with optional filters.""" - params_str = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + params_str = f"?organization_id={TEST_ORGANIZATION_ID}&project_id={TEST_PROJECT_ID}" if query_params: params_str += "&" + "&".join(f"{k}={v}" for k, v in query_params.items()) return client.get(f"{BASE_URL}{params_str}") @@ -160,7 +159,7 @@ def test_list_validators_filter_by_type(self, integration_client, clear_database def test_list_validators_empty(self, integration_client, clear_database): """Test listing validators when none exist.""" response = integration_client.get( - f"{BASE_URL}?org_id=999&project_id=999", + f"{BASE_URL}?organization_id=999&project_id=999", ) assert response.status_code == 200 @@ -202,7 +201,7 @@ def test_get_validator_wrong_org(self, integration_client, clear_database): # Try to access it as different org response = integration_client.get( - f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + f"{BASE_URL}{validator_id}/?organization_id=2&project_id=1", ) assert response.status_code == 404 @@ -292,7 +291,7 @@ def test_delete_validator_wrong_org(self, integration_client, clear_database): # Try to delete it as different org response = integration_client.delete( - f"{BASE_URL}{validator_id}/?org_id=2&project_id=1", + f"{BASE_URL}{validator_id}/?organization_id=2&project_id=1", ) assert response.status_code == 404 diff --git a/backend/app/utils.py b/backend/app/utils.py index 30543d5..4df593b 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -14,16 +14,16 @@ def now(): return datetime.now(timezone.utc).replace(tzinfo=None) def split_validator_payload(data: dict): - base = {} - config = {} + model_fields = {} + config_fields = {} - for k, v in data.items(): - if k in SYSTEM_FIELDS: - base[k] = v + for key, value in data.items(): + if key in SYSTEM_FIELDS: + model_fields[key] = value else: - config[k] = v + config_fields[key] = value - return base, config + return model_fields, config_fields class APIResponse(BaseModel, Generic[T]):