diff --git a/backend/app/alembic/versions/004_added_ban_list_config.py b/backend/app/alembic/versions/004_added_ban_list_config.py new file mode 100644 index 0000000..c024911 --- /dev/null +++ b/backend/app/alembic/versions/004_added_ban_list_config.py @@ -0,0 +1,39 @@ +"""Added ban_list table + +Revision ID: 004 +Revises: 003 +Create Date: 2026-02-05 09:42:54.128852 + +""" +from typing import Sequence, Union + +from alembic import op +from sqlalchemy.dialects import postgresql +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = '004' +down_revision: Union[str, Sequence[str], None] = "003" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table('ban_list', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=False), + sa.Column('org_id', sa.Integer(), nullable=False), + sa.Column('project_id', sa.Integer(), nullable=False), + sa.Column('domain', sa.String(), nullable=False), + sa.Column('is_public', sa.Boolean(), nullable=False, server_default=sa.false()), + sa.Column("banned_words", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + + sa.PrimaryKeyConstraint('id'), + ) + + +def downgrade() -> None: + op.drop_table('validator_config') diff --git a/backend/app/api/main.py b/backend/app/api/main.py index bf78ade..cac0cb1 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,11 +1,12 @@ from fastapi import APIRouter -from app.api.routes import utils, guardrails, validator_configs +from app.api.routes import ban_list_configs, guardrails, utils, validator_configs api_router = APIRouter() api_router.include_router(utils.router) api_router.include_router(guardrails.router) api_router.include_router(validator_configs.router) +api_router.include_router(ban_list_configs.router) # if settings.ENVIRONMENT == "local": # api_router.include_router(private.router) diff --git a/backend/app/api/routes/ban_list_configs.py b/backend/app/api/routes/ban_list_configs.py new file mode 100644 index 0000000..9d8b0a0 --- /dev/null +++ b/backend/app/api/routes/ban_list_configs.py @@ -0,0 +1,118 @@ +from typing import List, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from sqlmodel import Session + +from app.api.deps import AuthDep, SessionDep +from app.crud.ban_list_crud import ban_list_crud +from app.schemas.ban_list_config import ( + BanListCreate, + BanListUpdate, + BanListResponse +) + +router = APIRouter( + prefix="/guardrails/ban-lists", + tags=["Ban Lists"] +) + + +def check_owner(obj, org_id, project_id): + if obj.org_id != org_id or obj.project_id != project_id: + raise HTTPException(status_code=403, detail="Not owner") + + +@router.post( + "/", + response_model=BanListResponse + ) +def create_ban_list( + payload: BanListCreate, + session: SessionDep, + org_id: int, + project_id: int, + _: AuthDep, +): + return ban_list_crud.create( + session, + data=payload, + org_id=org_id, + project_id=project_id, + ) + + +@router.get( + "/", + response_model=list[BanListResponse] + ) +def list_ban_lists( + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, + domain: Optional[str] = None, +): + return ban_list_crud.list( + session, + org_id=org_id, + project_id=project_id, + domain=domain, + ) + + +@router.get( + "/{id}", + response_model=BanListResponse + ) +def get_ban_list( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = ban_list_crud.get(session, id) + if not obj: + raise HTTPException(404) + + if not obj.is_public: + check_owner(obj, org_id, project_id) + return obj + + +@router.patch( + "/{id}", + response_model=BanListResponse + ) +def update_ban_list( + id: UUID, + org_id: int, + project_id: int, + payload: BanListUpdate, + session: SessionDep, + _: AuthDep, +): + obj = ban_list_crud.get(session, id) + if not obj: + raise HTTPException(404) + + check_owner(obj, org_id, project_id) + return ban_list_crud.update(session, obj=obj, data=payload) + + +@router.delete("/{id}") +def delete_ban_list( + id: UUID, + org_id: int, + project_id: int, + session: SessionDep, + _: AuthDep, +): + obj = ban_list_crud.get(session, id) + if not obj: + raise HTTPException(404) + + check_owner(obj, org_id, project_id) + ban_list_crud.delete(session, obj) + return {"success": True} diff --git a/backend/app/crud/ban_list_crud.py b/backend/app/crud/ban_list_crud.py new file mode 100644 index 0000000..215012d --- /dev/null +++ b/backend/app/crud/ban_list_crud.py @@ -0,0 +1,79 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from sqlmodel import Session, select + +from app.models.config.ban_list_table import BanList +from app.schemas.ban_list_config import BanListCreate, BanListUpdate +from app.utils import now + +class BanListCrud: + + def create( + self, + db: Session, + *, + data: BanListCreate, + org_id: int, + project_id: int, + ) -> BanList: + obj = BanList( + **data.model_dump(), + org_id=org_id, + project_id=project_id, + ) + db.add(obj) + db.commit() + db.refresh(obj) + return obj + + def get(self, db: Session, id: UUID) -> Optional[BanList]: + return db.get(BanList, id) + + def list( + self, + db: Session, + *, + org_id: int, + project_id: int, + domain: Optional[str] = None, + ) -> List[BanList]: + stmt = select(BanList).where( + ( + (BanList.org_id == org_id) & + (BanList.project_id == project_id) + ) | + (BanList.is_public == True) + ) + + if domain: + stmt = stmt.where(BanList.domain == domain) + + return list(db.exec(stmt)) + + def update( + self, + db: Session, + *, + obj: BanList, + data: BanListUpdate, + ) -> BanList: + update_data = data.model_dump(exclude_unset=True) + + for k, v in update_data.items(): + setattr(obj, k, v) + + obj.updated_at = now() + + db.add(obj) + db.commit() + db.refresh(obj) + return obj + + def delete(self, db: Session, obj: BanList): + db.delete(obj) + db.commit() + + +ban_list_crud = BanListCrud() diff --git a/backend/app/models/config/ban_list_table.py b/backend/app/models/config/ban_list_table.py new file mode 100644 index 0000000..b34b351 --- /dev/null +++ b/backend/app/models/config/ban_list_table.py @@ -0,0 +1,75 @@ +from datetime import datetime +from typing import List, Optional +from uuid import UUID, uuid4 + +from sqlalchemy import Column, String +from sqlalchemy.dialects.postgresql import ARRAY +from sqlmodel import Field, SQLModel + +from app.utils import now + +class BanList(SQLModel, table=True): + __tablename__ = "ban_list" + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + index=True, + sa_column_kwargs={"comment": "Unique identifier for the ban list entry"} + ) + + name: str = Field( + nullable=False, + sa_column_kwargs={"comment": "Name of the ban list entry"} + ) + + description: Optional[str] = Field( + nullable=False, + sa_column_kwargs={"comment": "Description of the ban list entry"} + ) + + banned_words: list[str] = Field( + default_factory=list, + sa_column=Column( + ARRAY(String), + nullable=False, + comment="List of banned words", + ), + description=("List of banned words") + ) + + org_id: int = Field( + index=True, + nullable=False, + sa_column_kwargs={"comment": "Identifier for the organization"}, + ) + + project_id: int = Field( + index=True, + nullable=False, + sa_column_kwargs={"comment": "Identifier for the project"}, + ) + + domain: str = Field( + default=None, + index=False, + nullable=False, + sa_column_kwargs={"comment": "Domain or context for the ban list entry"} + ) + + is_public: bool = Field( + default=False, + sa_column_kwargs={"comment": "Whether the ban list entry is public or private"} + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the ban list entry was created"} + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the ban list entry was last updated"} + ) diff --git a/backend/app/schemas/ban_list_config.py b/backend/app/schemas/ban_list_config.py new file mode 100644 index 0000000..8b39d44 --- /dev/null +++ b/backend/app/schemas/ban_list_config.py @@ -0,0 +1,33 @@ +from uuid import UUID +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel, Field +from sqlmodel import SQLModel + +class BanListBase(SQLModel): + name: str + description: str + banned_words: list[str] + domain: str + is_public: bool = False + + +class BanListCreate(BanListBase): + pass + + +class BanListUpdate(SQLModel): + name: Optional[str] = None + description: Optional[str] = None + banned_words: Optional[list[str]] = None + domain: Optional[str] = None + is_public: Optional[bool] = None + + +class BanListResponse(BanListBase): + id: UUID + org_id: int + project_id: int + created_at: datetime + updated_at: datetime diff --git a/backend/app/tests/test_ban_list_configs.py b/backend/app/tests/test_ban_list_configs.py new file mode 100644 index 0000000..da4b5ef --- /dev/null +++ b/backend/app/tests/test_ban_list_configs.py @@ -0,0 +1,200 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException +from sqlmodel import Session + +from app.api.routes.ban_list_configs import ( + create_ban_list, + list_ban_lists, + get_ban_list, + update_ban_list, + delete_ban_list, +) +from app.schemas.ban_list_config import ( + BanListCreate, + BanListUpdate, +) + +TEST_ID = uuid.uuid4() +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 10 + +@pytest.fixture +def mock_session(): + return MagicMock(spec=Session) + + +@pytest.fixture +def sample_ban_list(): + obj = MagicMock() + obj.id = TEST_ID + obj.name = "test" + obj.description = "desc" + obj.banned_words = ["bad"] + obj.org_id = TEST_ORG_ID + obj.project_id = TEST_PROJECT_ID + obj.domain = "health" + obj.is_public = False + return obj + + +@pytest.fixture +def create_payload(): + return BanListCreate( + name="test", + description="desc", + banned_words=["bad"], + domain="health", + is_public=False, + ) + + +def test_create_calls_crud(mock_session, create_payload, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.create.return_value = sample_ban_list + + result = create_ban_list( + payload=create_payload, + session=mock_session, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + _=None, + ) + + crud.create.assert_called_once() + assert result == sample_ban_list + + +def test_list_returns_data(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.list.return_value = [sample_ban_list] + + result = list_ban_lists( + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert len(result) == 1 + crud.list.assert_called_once() + + +def test_get_success(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + result = get_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert result == sample_ban_list + + +def test_get_not_found(mock_session): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = None + + with pytest.raises(HTTPException) as exc: + get_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert exc.value.status_code == 404 + + +def test_get_forbidden(mock_session, sample_ban_list): + sample_ban_list.org_id = 999 # different owner + + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + with pytest.raises(HTTPException) as exc: + get_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + assert exc.value.status_code == 403 + + +def test_update_success(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + crud.update.return_value = sample_ban_list + + result = update_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + payload=BanListUpdate(name="new"), + session=mock_session, + _=None, + ) + + crud.update.assert_called_once() + assert result == sample_ban_list + + +def test_update_forbidden(mock_session, sample_ban_list): + sample_ban_list.org_id = 999 + + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + with pytest.raises(HTTPException) as exc: + update_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + payload=BanListUpdate(name="new"), + session=mock_session, + _=None, + ) + + assert exc.value.status_code == 403 + + +def test_delete_success(mock_session, sample_ban_list): + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + result = delete_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) + + crud.delete.assert_called_once() + assert result["success"] is True + + +def test_delete_forbidden(mock_session, sample_ban_list): + sample_ban_list.org_id = 999 + + with patch("app.api.routes.ban_list_configs.ban_list_crud") as crud: + crud.get.return_value = sample_ban_list + + with pytest.raises(HTTPException): + delete_ban_list( + id=TEST_ID, + org_id=TEST_ORG_ID, + project_id=TEST_PROJECT_ID, + session=mock_session, + _=None, + ) diff --git a/backend/app/tests/test_ban_lists_integration.py b/backend/app/tests/test_ban_lists_integration.py new file mode 100644 index 0000000..956e932 --- /dev/null +++ b/backend/app/tests/test_ban_lists_integration.py @@ -0,0 +1,221 @@ +import uuid +import pytest +from sqlmodel import Session, delete + +from app.core.db import engine +from app.models.config.ban_list_table import BanList + +pytestmark = pytest.mark.integration + + +# Test data constants +TEST_ORG_ID = 1 +TEST_PROJECT_ID = 1 +BASE_URL = "/api/v1/guardrails/ban-lists/" +DEFAULT_QUERY = f"?org_id={TEST_ORG_ID}&project_id={TEST_PROJECT_ID}" + + +BAN_LIST_PAYLOADS = { + "minimal": { + "name": "default", + "description": "basic list", + "banned_words": ["bad"], + "domain": "general", + }, + "health": { + "name": "health-list", + "description": "healthcare words", + "banned_words": ["gender detection", "sonography"], + "domain": "health", + }, + "edu": { + "name": "edu-list", + "description": "education words", + "banned_words": ["cheating"], + "domain": "edu", + }, + "public": { + "name": "public-list", + "description": "shared", + "banned_words": ["shared"], + "is_public": True, + "domain": "general", + }, +} + + +@pytest.fixture +def clear_database(): + with Session(engine) as session: + session.exec(delete(BanList)) + session.commit() + + yield + + with Session(engine) as session: + session.exec(delete(BanList)) + session.commit() + + +class BaseBanListTest: + + def create(self, client, payload_key="minimal", **kwargs): + payload = {**BAN_LIST_PAYLOADS[payload_key], **kwargs} + return client.post(f"{BASE_URL}{DEFAULT_QUERY}", json=payload) + + def list(self, client, **filters): + params = DEFAULT_QUERY + if filters: + params += "&" + "&".join(f"{k}={v}" for k, v in filters.items()) + return client.get(f"{BASE_URL}{params}") + + def get(self, client, id, org=TEST_ORG_ID, project=TEST_PROJECT_ID): + return client.get(f"{BASE_URL}{id}/?org_id={org}&project_id={project}") + + def update(self, client, id, payload): + return client.patch(f"{BASE_URL}{id}/{DEFAULT_QUERY}", json=payload) + + def delete(self, client, id): + return client.delete(f"{BASE_URL}{id}/{DEFAULT_QUERY}") + + +class TestCreateBanList(BaseBanListTest): + + def test_create_success(self, integration_client, clear_database): + response = self.create(integration_client, "minimal") + + assert response.status_code == 200 + data = response.json() + + assert data["name"] == "default" + assert data["banned_words"] == ["bad"] + assert "id" in data + + def test_create_validation_error(self, integration_client, clear_database): + response = integration_client.post( + f"{BASE_URL}{DEFAULT_QUERY}", + json={"name": "missing words"}, + ) + + assert response.status_code == 422 + + +class TestListBanLists(BaseBanListTest): + + def test_list_success(self, integration_client, clear_database): + self.create(integration_client, "minimal") + self.create(integration_client, "health") + + response = self.list(integration_client) + + assert response.status_code == 200 + assert len(response.json()) == 2 + + def test_filter_by_domain(self, integration_client, clear_database): + self.create(integration_client, "health") + self.create(integration_client, "edu") + + response = self.list(integration_client, domain="health") + + data = response.json() + assert len(data) == 1 + assert data[0]["domain"] == "health" + + def test_list_empty(self, integration_client, clear_database): + response = self.list(integration_client) + assert response.json() == [] + + +class TestPublicAccess(BaseBanListTest): + + def test_public_visible_to_other_org(self, integration_client, clear_database): + create_resp = self.create(integration_client, "public") + ban_id = create_resp.json()["id"] + + response = self.get(integration_client, ban_id, org=999, project=999) + + # public lists should still be readable + assert response.status_code == 200 + + +class TestGetBanList(BaseBanListTest): + + def test_get_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.get(integration_client, ban_id) + + assert response.status_code == 200 + + def test_get_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + response = self.get(integration_client, fake) + + assert response.status_code == 404 + + def test_get_wrong_owner_private(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.get(integration_client, ban_id, org=2, project=2) + + assert response.status_code in (403, 404) + + +class TestUpdateBanList(BaseBanListTest): + + def test_update_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.update( + integration_client, + ban_id, + {"banned_words": ["bad", "worse"]}, + ) + + assert response.status_code == 200 + assert response.json()["banned_words"] == ["bad", "worse"] + + def test_partial_update(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.update(integration_client, ban_id, {"name": "updated"}) + + assert response.json()["name"] == "updated" + + def test_update_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.update(integration_client, fake, {"name": "x"}) + assert response.status_code == 404 + + +class TestDeleteBanList(BaseBanListTest): + + def test_delete_success(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = self.delete(integration_client, ban_id) + + assert response.status_code == 200 + assert response.json()["success"] is True + + def test_delete_not_found(self, integration_client, clear_database): + fake = uuid.uuid4() + + response = self.delete(integration_client, fake) + assert response.status_code == 404 + + def test_delete_wrong_owner(self, integration_client, clear_database): + create_resp = self.create(integration_client, "minimal") + ban_id = create_resp.json()["id"] + + response = integration_client.delete( + f"{BASE_URL}{ban_id}/?org_id=999&project_id=999" + ) + + assert response.status_code in (403, 404)