diff --git a/backend/app/alembic/versions/044_add_stt_evaluation_tables.py b/backend/app/alembic/versions/044_add_stt_evaluation_tables.py new file mode 100644 index 000000000..be474224f --- /dev/null +++ b/backend/app/alembic/versions/044_add_stt_evaluation_tables.py @@ -0,0 +1,462 @@ +"""add stt evaluation tables + +Revision ID: 044 +Revises: 043 +Create Date: 2026-01-28 12:00:00.000000 + +""" + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "044" +down_revision = "043" +branch_labels = None +depends_on = None + + +def upgrade(): + # Add type and language columns to evaluation_dataset table + op.add_column( + "evaluation_dataset", + sa.Column( + "type", + sa.String(length=20), + nullable=False, + server_default="text", + comment="Evaluation type: text, stt, or tts", + ), + ) + op.add_column( + "evaluation_dataset", + sa.Column( + "language_id", + sa.Integer(), + nullable=True, + comment="Foreign key to global.languages table", + ), + ) + op.create_foreign_key( + "fk_evaluation_dataset_language_id", + "evaluation_dataset", + "languages", + ["language_id"], + ["id"], + referent_schema="global", + ) + + # Add type, language, and providers columns to evaluation_run table + op.add_column( + "evaluation_run", + sa.Column( + "type", + sa.String(length=20), + nullable=False, + server_default="text", + comment="Evaluation type: text, stt, or tts", + ), + ) + op.add_column( + "evaluation_run", + sa.Column( + "language_id", + sa.Integer(), + nullable=True, + comment="Foreign key to global.languages table", + ), + ) + op.create_foreign_key( + "fk_evaluation_run_language_id", + "evaluation_run", + "languages", + ["language_id"], + ["id"], + referent_schema="global", + ) + op.add_column( + "evaluation_run", + sa.Column( + "providers", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="List of STT/TTS providers used (e.g., ['gemini-2.5-pro'])", + ), + ) + + # Create file table for storing uploaded file metadata + op.create_table( + "file", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the file", + ), + sa.Column( + "object_store_url", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + comment="S3 URL where the file is stored", + ), + sa.Column( + "filename", + sa.String(length=255), + nullable=False, + comment="Original filename as uploaded", + ), + sa.Column( + "size_bytes", + sa.Integer(), + nullable=False, + comment="File size in bytes", + ), + sa.Column( + "content_type", + sa.String(length=100), + nullable=False, + comment="MIME type of the file (e.g., audio/mp3)", + ), + sa.Column( + "file_type", + sa.String(length=20), + nullable=False, + server_default="other", + comment="Type of file: audio, document, image, other", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the file was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the file was last updated", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + 
) + op.create_index( + "idx_file_org_project", + "file", + ["organization_id", "project_id"], + unique=False, + ) + op.create_index( + "idx_file_type", + "file", + ["file_type"], + unique=False, + ) + + # Create stt_sample table with file_id reference + op.create_table( + "stt_sample", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the STT sample", + ), + sa.Column( + "file_id", + sa.Integer(), + nullable=False, + comment="Reference to the uploaded audio file in file table", + ), + sa.Column( + "language_id", + sa.Integer(), + nullable=True, + comment="Foreign key to global.languages table", + ), + sa.Column( + "ground_truth", + sa.Text(), + nullable=True, + comment="Reference transcription for comparison (optional)", + ), + sa.Column( + "sample_metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + server_default=sa.text("'{}'::jsonb"), + comment="Additional metadata (format, bitrate, original filename, etc.)", + ), + sa.Column( + "dataset_id", + sa.Integer(), + nullable=False, + comment="Reference to the parent evaluation dataset", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the sample was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the sample was last updated", + ), + sa.ForeignKeyConstraint( + ["file_id"], + ["file.id"], + name="fk_stt_sample_file_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["language_id"], + ["global.languages.id"], + name="fk_stt_sample_language_id", + ), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["evaluation_dataset.id"], + name="fk_stt_sample_dataset_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_stt_sample_file_id", + "stt_sample", + ["file_id"], + unique=False, + ) + op.create_index( + "ix_stt_sample_dataset_id", + "stt_sample", + ["dataset_id"], + unique=False, + ) + op.create_index( + "idx_stt_sample_org_project", + "stt_sample", + ["organization_id", "project_id"], + unique=False, + ) + + # Create stt_result table + op.create_table( + "stt_result", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the STT result", + ), + sa.Column( + "transcription", + sa.Text(), + nullable=True, + comment="Generated transcription from STT provider", + ), + sa.Column( + "provider", + sa.String(length=50), + nullable=False, + comment="STT provider used (e.g., gemini-2.5-pro)", + ), + sa.Column( + "status", + sa.String(length=20), + nullable=False, + server_default="pending", + comment="Result status: pending, completed, failed", + ), + sa.Column( + "score", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Evaluation metrics (e.g., wer, cer, mer, wil) - extensible for future metrics", + ), + sa.Column( + "is_correct", + sa.Boolean(), + nullable=True, + comment="Human feedback: transcription correctness (null=not reviewed)", + ), + sa.Column( + "comment", + sa.Text(), + nullable=True, + comment="Human feedback comment", + ), + sa.Column( + "error_message", + sa.Text(), + 
nullable=True, + comment="Error message if transcription failed", + ), + sa.Column( + "stt_sample_id", + sa.Integer(), + nullable=False, + comment="Reference to the STT sample", + ), + sa.Column( + "evaluation_run_id", + sa.Integer(), + nullable=False, + comment="Reference to the evaluation run", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the result was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the result was last updated", + ), + sa.ForeignKeyConstraint( + ["stt_sample_id"], + ["stt_sample.id"], + name="fk_stt_result_sample_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["evaluation_run_id"], + ["evaluation_run.id"], + name="fk_stt_result_run_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_stt_result_sample_id", + "stt_result", + ["stt_sample_id"], + unique=False, + ) + op.create_index( + "ix_stt_result_run_id", + "stt_result", + ["evaluation_run_id"], + unique=False, + ) + op.create_index( + "idx_stt_result_feedback", + "stt_result", + ["evaluation_run_id", "is_correct"], + unique=False, + ) + op.create_index( + "idx_stt_result_status", + "stt_result", + ["evaluation_run_id", "status"], + unique=False, + ) + + +def downgrade(): + # Drop stt_result table + op.drop_index("idx_stt_result_status", table_name="stt_result") + op.drop_index("idx_stt_result_feedback", table_name="stt_result") + op.drop_index("ix_stt_result_run_id", table_name="stt_result") + op.drop_index("ix_stt_result_sample_id", table_name="stt_result") + op.drop_table("stt_result") + + # Drop stt_sample table + op.drop_index("idx_stt_sample_org_project", table_name="stt_sample") + op.drop_index("ix_stt_sample_dataset_id", table_name="stt_sample") + op.drop_index("ix_stt_sample_file_id", table_name="stt_sample") + op.drop_table("stt_sample") + + # Drop file table + op.drop_index("idx_file_type", table_name="file") + op.drop_index("idx_file_org_project", table_name="file") + op.drop_table("file") + + # Remove columns from evaluation_run table + op.drop_column("evaluation_run", "providers") + op.drop_constraint( + "fk_evaluation_run_language_id", "evaluation_run", type_="foreignkey" + ) + op.drop_column("evaluation_run", "language_id") + op.drop_column("evaluation_run", "type") + + # Remove columns from evaluation_dataset table + op.drop_constraint( + "fk_evaluation_dataset_language_id", + "evaluation_dataset", + type_="foreignkey", + ) + op.drop_column("evaluation_dataset", "language_id") + op.drop_column("evaluation_dataset", "type") diff --git a/backend/app/api/docs/stt_evaluation/create_dataset.md b/backend/app/api/docs/stt_evaluation/create_dataset.md new file mode 100644 index 000000000..c17b3a06c --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/create_dataset.md @@ -0,0 +1,5 @@ +Create a new STT evaluation dataset with audio samples. 
+
+Each sample includes:
+- **object_store_url** (required): S3 URL of the audio file (from the /evaluations/stt/files/audio endpoint)
+- **ground_truth** (optional): Reference transcription used to compute WER/CER metrics
diff --git a/backend/app/api/docs/stt_evaluation/get_dataset.md b/backend/app/api/docs/stt_evaluation/get_dataset.md
new file mode 100644
index 000000000..e965504e7
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/get_dataset.md
@@ -0,0 +1 @@
+Get an STT dataset with its samples.
diff --git a/backend/app/api/docs/stt_evaluation/get_result.md b/backend/app/api/docs/stt_evaluation/get_result.md
new file mode 100644
index 000000000..728710ec5
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/get_result.md
@@ -0,0 +1 @@
+Get a single STT transcription result.
diff --git a/backend/app/api/docs/stt_evaluation/get_run.md b/backend/app/api/docs/stt_evaluation/get_run.md
new file mode 100644
index 000000000..ce88c37d4
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/get_run.md
@@ -0,0 +1 @@
+Get an STT evaluation run with its results.
diff --git a/backend/app/api/docs/stt_evaluation/list_datasets.md b/backend/app/api/docs/stt_evaluation/list_datasets.md
new file mode 100644
index 000000000..32a3671ad
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/list_datasets.md
@@ -0,0 +1 @@
+List all STT evaluation datasets for the current project.
diff --git a/backend/app/api/docs/stt_evaluation/list_runs.md b/backend/app/api/docs/stt_evaluation/list_runs.md
new file mode 100644
index 000000000..40ffa1ff2
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/list_runs.md
@@ -0,0 +1 @@
+List all STT evaluation runs for the current project.
diff --git a/backend/app/api/docs/stt_evaluation/start_evaluation.md b/backend/app/api/docs/stt_evaluation/start_evaluation.md
new file mode 100644
index 000000000..cfc2e03ed
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/start_evaluation.md
@@ -0,0 +1,8 @@
+Start an STT evaluation run on a dataset.
+
+The evaluation will:
+1. Process each audio sample through the specified providers
+2. Generate transcriptions using the Gemini Batch API
+3. Store results for human review
+
+**Supported providers:** gemini-2.5-pro, gemini-2.5-flash, gemini-2.0-flash
diff --git a/backend/app/api/docs/stt_evaluation/update_feedback.md b/backend/app/api/docs/stt_evaluation/update_feedback.md
new file mode 100644
index 000000000..71a44ce81
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/update_feedback.md
@@ -0,0 +1,5 @@
+Update human feedback on an STT transcription result.
+
+**Fields:**
+- **is_correct**: Boolean indicating whether the transcription is correct
+- **comment**: Optional feedback comment explaining issues or observations
diff --git a/backend/app/api/docs/stt_evaluation/upload_audio.md b/backend/app/api/docs/stt_evaluation/upload_audio.md
new file mode 100644
index 000000000..e4fb4429b
--- /dev/null
+++ b/backend/app/api/docs/stt_evaluation/upload_audio.md
@@ -0,0 +1,7 @@
+Upload a single audio file to S3 for STT evaluation.
+
+**Supported formats:** mp3, wav, flac, m4a, ogg, webm
+
+**Maximum file size:** 200 MB
+
+Returns the S3 URL, which can be used when creating an STT dataset.
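# ---------------------------------------------------------------------------
# Usage sketch for the endpoints documented above: upload audio, create a
# dataset from the returned URL, start a run, then poll it. This is a minimal
# illustration, not part of the PR: the base URL, the X-API-KEY auth header,
# the APIResponse envelope key ("data"), and the AudioUploadResponse field
# name ("object_store_url") are assumptions.
# ---------------------------------------------------------------------------
import requests

BASE = "https://api.example.com/api/v1/evaluations/stt"  # assumed deployment URL
HEADERS = {"X-API-KEY": "<project-api-key>"}  # assumed auth scheme

# 1. Upload an audio file; the response carries the S3 URL.
with open("sample.mp3", "rb") as f:
    upload = requests.post(
        f"{BASE}/files/audio", headers=HEADERS, files={"file": f}
    ).json()

# 2. Create a dataset referencing the uploaded file.
dataset = requests.post(
    f"{BASE}/datasets",
    headers=HEADERS,
    json={
        "name": "smoke-test",
        "samples": [
            {
                "object_store_url": upload["data"]["object_store_url"],
                "ground_truth": "hello world",  # optional; enables WER/CER
            }
        ],
    },
).json()

# 3. Start an evaluation run with one of the supported providers.
run = requests.post(
    f"{BASE}/runs",
    headers=HEADERS,
    json={
        "run_name": "smoke-run",
        "dataset_id": dataset["data"]["id"],
        "providers": ["gemini-2.5-flash"],
    },
).json()

# 4. Poll until the Gemini batch completes, then review per-sample results.
status = requests.get(f"{BASE}/runs/{run['data']['id']}", headers=HEADERS).json()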
diff --git a/backend/app/api/main.py b/backend/app/api/main.py index d62cb1c79..65fd0dc3e 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -26,6 +26,7 @@ collection_job, ) from app.api.routes.evaluations import dataset as evaluation_dataset, evaluation +from app.api.routes import stt_evaluations from app.core.config import settings api_router = APIRouter() @@ -40,6 +41,7 @@ api_router.include_router(doc_transformation_job.router) api_router.include_router(evaluation_dataset.router) api_router.include_router(evaluation.router) +api_router.include_router(stt_evaluations.router) api_router.include_router(languages.router) api_router.include_router(llm.router) api_router.include_router(login.router) diff --git a/backend/app/api/routes/stt_evaluations/__init__.py b/backend/app/api/routes/stt_evaluations/__init__.py new file mode 100644 index 000000000..fa71b5c0d --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/__init__.py @@ -0,0 +1,5 @@ +"""STT Evaluation API routes.""" + +from .router import router + +__all__ = ["router"] diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py new file mode 100644 index 000000000..b7a3864d5 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -0,0 +1,193 @@ +"""STT dataset API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.file import get_files_by_ids +from app.crud.language import get_language_by_id +from app.crud.stt_evaluations import ( + get_stt_dataset_by_id, + list_stt_datasets, + get_samples_by_dataset_id, +) +from app.models.stt_evaluation import ( + STTDatasetCreate, + STTDatasetPublic, + STTDatasetWithSamples, + STTSamplePublic, +) +from app.services.stt_evaluations.dataset import upload_stt_dataset +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/datasets", + response_model=APIResponse[STTDatasetPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Create STT dataset", + description=load_description("stt_evaluation/create_dataset.md"), +) +def create_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_create: STTDatasetCreate = Body(...), +) -> APIResponse[STTDatasetPublic]: + """Create an STT evaluation dataset.""" + # Validate language_id if provided + if dataset_create.language_id is not None: + language = get_language_by_id( + session=_session, language_id=dataset_create.language_id + ) + if not language: + raise HTTPException( + status_code=400, detail="Invalid language_id: language not found" + ) + + dataset, samples = upload_stt_dataset( + session=_session, + name=dataset_create.name, + samples=dataset_create.samples, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + description=dataset_create.description, + language_id=dataset_create.language_id, + ) + + return APIResponse.success_response( + data=STTDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language_id=dataset.language_id, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + sample_count=len(samples), + organization_id=dataset.organization_id, + project_id=dataset.project_id, + 
inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ) + ) + + +@router.get( + "/datasets", + response_model=APIResponse[list[STTDatasetPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List STT datasets", + description=load_description("stt_evaluation/list_datasets.md"), +) +def list_datasets( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[STTDatasetPublic]]: + """List STT evaluation datasets.""" + datasets, total = list_stt_datasets( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=datasets, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/datasets/{dataset_id}", + response_model=APIResponse[STTDatasetWithSamples], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get STT dataset", + description=load_description("stt_evaluation/get_dataset.md"), +) +def get_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int, + include_samples: bool = Query(True, description="Include samples in response"), + sample_limit: int = Query(100, ge=1, le=1000, description="Max samples to return"), + sample_offset: int = Query(0, ge=0, description="Sample offset"), +) -> APIResponse[STTDatasetWithSamples]: + """Get an STT evaluation dataset.""" + dataset = get_stt_dataset_by_id( + session=_session, + dataset_id=dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + samples = [] + samples_total = (dataset.dataset_metadata or {}).get("sample_count", 0) + + if include_samples: + sample_records = get_samples_by_dataset_id( + session=_session, + dataset_id=dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=sample_limit, + offset=sample_offset, + ) + + # Fetch file records to get object_store_url + file_ids = [s.file_id for s in sample_records] + file_records = get_files_by_ids( + session=_session, + file_ids=file_ids, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + file_map = {f.id: f for f in file_records} + + samples = [ + STTSamplePublic( + id=s.id, + file_id=s.file_id, + object_store_url=file_map.get(s.file_id).object_store_url + if s.file_id in file_map + else None, + language_id=s.language_id, + ground_truth=s.ground_truth, + sample_metadata=s.sample_metadata, + dataset_id=s.dataset_id, + organization_id=s.organization_id, + project_id=s.project_id, + inserted_at=s.inserted_at, + updated_at=s.updated_at, + ) + for s in sample_records + ] + + return APIResponse.success_response( + data=STTDatasetWithSamples( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language_id=dataset.language_id, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + samples=samples, + ), + metadata={"samples_total": samples_total}, + ) diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py 
b/backend/app/api/routes/stt_evaluations/evaluation.py new file mode 100644 index 000000000..23b1a9655 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -0,0 +1,261 @@ +"""STT evaluation run API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.language import get_language_by_id +from app.crud.stt_evaluations import ( + create_stt_run, + create_stt_results, + get_results_by_run_id, + get_samples_by_dataset_id, + get_stt_dataset_by_id, + get_stt_run_by_id, + list_stt_runs, + start_stt_evaluation_batch, + update_stt_run, +) +from app.models.stt_evaluation import ( + STTEvaluationRunCreate, + STTEvaluationRunPublic, + STTEvaluationRunWithResults, +) +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/runs", + response_model=APIResponse[STTEvaluationRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Start STT evaluation", + description=load_description("stt_evaluation/start_evaluation.md"), +) +def start_stt_evaluation( + _session: SessionDep, + auth_context: AuthContextDep, + run_create: STTEvaluationRunCreate = Body(...), +) -> APIResponse[STTEvaluationRunPublic]: + """Start an STT evaluation run.""" + logger.info( + f"[start_stt_evaluation] Starting STT evaluation | " + f"run_name: {run_create.run_name}, dataset_id: {run_create.dataset_id}, " + f"providers: {run_create.providers}" + ) + + # Validate dataset exists + dataset = get_stt_dataset_by_id( + session=_session, + dataset_id=run_create.dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + sample_count = (dataset.dataset_metadata or {}).get("sample_count", 0) + + if sample_count == 0: + raise HTTPException(status_code=400, detail="Dataset has no samples") + + # Validate language_id if provided + language_id = run_create.language_id or dataset.language_id + if language_id is not None: + language = get_language_by_id(session=_session, language_id=language_id) + if not language: + raise HTTPException( + status_code=400, detail="Invalid language_id: language not found" + ) + + # Create run record + run = create_stt_run( + session=_session, + run_name=run_create.run_name, + dataset_id=run_create.dataset_id, + dataset_name=dataset.name, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + providers=run_create.providers, + language_id=language_id, + total_items=sample_count * len(run_create.providers), + ) + + # Get samples for the dataset + samples = get_samples_by_dataset_id( + session=_session, + dataset_id=run_create.dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + # Create result records for each sample and provider + results = create_stt_results( + session=_session, + samples=samples, + evaluation_run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + providers=run_create.providers, + ) + + sample_to_result = {r.stt_sample_id: r.id for r in results} + + try: + batch_result = start_stt_evaluation_batch( + session=_session, + run=run, + samples=samples, + sample_to_result=sample_to_result, + org_id=auth_context.organization_.id, + 
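+            # What gets submitted (a sketch): start_stt_evaluation_batch signs
+            # each sample's audio URL, and create_stt_batch_requests (added in
+            # app/core/batch/gemini.py below) turns every sample into one
+            # Gemini JSONL request keyed by its sample id, roughly:
+            #   {"key": "<sample_id>", "request": {"contents": [{"parts": [
+            #       {"text": <transcription prompt>},
+            #       {"file_data": {"mime_type": <detected>, "file_uri": <signed URL>}}],
+            #     "role": "user"}]}}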
project_id=auth_context.project_.id, + ) + logger.info( + f"[start_stt_evaluation] STT evaluation batch submitted | " + f"run_id: {run.id}, batch_jobs: {batch_result.get('batch_jobs', {}).keys()}" + ) + except Exception as e: + logger.error( + f"[start_stt_evaluation] Batch submission failed | " + f"run_id: {run.id}, error: {str(e)}" + ) + update_stt_run( + session=_session, + run_id=run.id, + status="failed", + error_message=str(e), + ) + raise HTTPException(status_code=500, detail=f"Batch submission failed: {e}") + + # Refresh run to get updated status + run = get_stt_run_by_id( + session=_session, + run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response( + data=STTEvaluationRunPublic( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language_id=run.language_id, + providers=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + ) + ) + + +@router.get( + "/runs", + response_model=APIResponse[list[STTEvaluationRunPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List STT evaluation runs", + description=load_description("stt_evaluation/list_runs.md"), +) +def list_stt_evaluation_runs( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int | None = Query(None, description="Filter by dataset ID"), + status: str | None = Query(None, description="Filter by status"), + limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[STTEvaluationRunPublic]]: + """List STT evaluation runs.""" + runs, total = list_stt_runs( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + dataset_id=dataset_id, + status=status, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=runs, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/runs/{run_id}", + response_model=APIResponse[STTEvaluationRunWithResults], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get STT evaluation run", + description=load_description("stt_evaluation/get_run.md"), +) +def get_stt_evaluation_run( + _session: SessionDep, + auth_context: AuthContextDep, + run_id: int, + include_results: bool = Query(True, description="Include results in response"), + result_limit: int = Query(100, ge=1, le=1000, description="Max results to return"), + result_offset: int = Query(0, ge=0, description="Result offset"), + provider: str | None = Query(None, description="Filter results by provider"), + status: str | None = Query(None, description="Filter results by status"), +) -> APIResponse[STTEvaluationRunWithResults]: + """Get an STT evaluation run with results.""" + run = get_stt_run_by_id( + session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not run: + raise HTTPException(status_code=404, detail="Evaluation run not found") + + results = [] + results_total = 0 + + if include_results: + results, results_total = get_results_by_run_id( + session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + 
project_id=auth_context.project_.id, + provider=provider, + status=status, + limit=result_limit, + offset=result_offset, + ) + + return APIResponse.success_response( + data=STTEvaluationRunWithResults( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language_id=run.language_id, + providers=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + results=results, + results_total=results_total, + ), + metadata={"results_total": results_total}, + ) diff --git a/backend/app/api/routes/stt_evaluations/files.py b/backend/app/api/routes/stt_evaluations/files.py new file mode 100644 index 000000000..e3cb50c10 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/files.py @@ -0,0 +1,43 @@ +"""Audio file upload API routes for STT evaluation.""" + +import logging + +from fastapi import APIRouter, Depends, File, UploadFile + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.models.stt_evaluation import AudioUploadResponse +from app.services.stt_evaluations.audio import upload_audio_file +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/files/audio", + response_model=APIResponse[AudioUploadResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Upload audio file", + description=load_description("stt_evaluation/upload_audio.md"), +) +def upload_audio( + _session: SessionDep, + auth_context: AuthContextDep, + file: UploadFile = File(..., description="Audio file to upload"), +) -> APIResponse[AudioUploadResponse]: + """Upload an audio file for STT evaluation.""" + logger.info( + f"[upload_audio] Uploading audio file | " + f"project_id: {auth_context.project_.id}, filename: {file.filename}" + ) + + result = upload_audio_file( + session=_session, + file=file, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response(data=result) diff --git a/backend/app/api/routes/stt_evaluations/result.py b/backend/app/api/routes/stt_evaluations/result.py new file mode 100644 index 000000000..b2bc48bb0 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/result.py @@ -0,0 +1,124 @@ +"""STT result feedback API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.stt_evaluations import ( + get_stt_result_by_id, + update_human_feedback, +) +from app.models.stt_evaluation import ( + STTFeedbackUpdate, + STTResultPublic, +) +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.patch( + "/results/{result_id}", + response_model=APIResponse[STTResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Update human feedback", + description=load_description("stt_evaluation/update_feedback.md"), +) +def update_result_feedback( + _session: SessionDep, + auth_context: AuthContextDep, + result_id: int, + feedback: STTFeedbackUpdate = Body(...), +) -> APIResponse[STTResultPublic]: + """Update human feedback 
on an STT result.""" + logger.info( + f"[update_result_feedback] Updating feedback | " + f"result_id: {result_id}, is_correct: {feedback.is_correct}" + ) + + # Verify result exists and belongs to this project + existing = get_stt_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not existing: + raise HTTPException(status_code=404, detail="Result not found") + + # Update feedback + result = update_human_feedback( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + is_correct=feedback.is_correct, + comment=feedback.comment, + ) + + return APIResponse.success_response( + data=STTResultPublic( + id=result.id, + transcription=result.transcription, + provider=result.provider, + status=result.status, + score=result.score, + is_correct=result.is_correct, + comment=result.comment, + error_message=result.error_message, + stt_sample_id=result.stt_sample_id, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + ) + ) + + +@router.get( + "/results/{result_id}", + response_model=APIResponse[STTResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get STT result", + description=load_description("stt_evaluation/get_result.md"), +) +def get_result( + _session: SessionDep, + auth_context: AuthContextDep, + result_id: int, +) -> APIResponse[STTResultPublic]: + """Get an STT result by ID.""" + result = get_stt_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not result: + raise HTTPException(status_code=404, detail="Result not found") + + return APIResponse.success_response( + data=STTResultPublic( + id=result.id, + transcription=result.transcription, + provider=result.provider, + status=result.status, + score=result.score, + is_correct=result.is_correct, + comment=result.comment, + error_message=result.error_message, + stt_sample_id=result.stt_sample_id, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + ) + ) diff --git a/backend/app/api/routes/stt_evaluations/router.py b/backend/app/api/routes/stt_evaluations/router.py new file mode 100644 index 000000000..07a958f17 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/router.py @@ -0,0 +1,13 @@ +"""Main router for STT evaluation API routes.""" + +from fastapi import APIRouter + +from . 
import dataset, evaluation, files, result + +router = APIRouter(prefix="/evaluations/stt", tags=["STT Evaluation"]) + +# Include all sub-routers +router.include_router(files.router) +router.include_router(dataset.router) +router.include_router(evaluation.router) +router.include_router(result.router) diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py index 0cd10bd51..382fcc4ff 100644 --- a/backend/app/core/batch/__init__.py +++ b/backend/app/core/batch/__init__.py @@ -1,6 +1,7 @@ """Batch processing infrastructure for LLM providers.""" from .base import BatchProvider +from .gemini import BatchJobState, GeminiBatchProvider, create_stt_batch_requests from .openai import OpenAIBatchProvider from .operations import ( download_batch_results, @@ -12,7 +13,10 @@ __all__ = [ "BatchProvider", + "BatchJobState", + "GeminiBatchProvider", "OpenAIBatchProvider", + "create_stt_batch_requests", "start_batch_job", "download_batch_results", "process_completed_batch", diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py new file mode 100644 index 000000000..1dd28a952 --- /dev/null +++ b/backend/app/core/batch/gemini.py @@ -0,0 +1,443 @@ +"""Gemini batch provider implementation.""" + +import json +import logging +import os +import tempfile +import time +from enum import Enum +from typing import Any + +from google import genai +from google.genai import types + +from app.core.storage_utils import get_mime_from_url + +from .base import BatchProvider + +logger = logging.getLogger(__name__) + + +class BatchJobState(str, Enum): + """Gemini batch job states.""" + + PENDING = "JOB_STATE_PENDING" + RUNNING = "JOB_STATE_RUNNING" + SUCCEEDED = "JOB_STATE_SUCCEEDED" + FAILED = "JOB_STATE_FAILED" + CANCELLED = "JOB_STATE_CANCELLED" + EXPIRED = "JOB_STATE_EXPIRED" + + +# Terminal states that indicate the batch is done +_TERMINAL_STATES = { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + +# Failed terminal states +_FAILED_STATES = { + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + + +class GeminiBatchProvider(BatchProvider): + """Gemini implementation of the BatchProvider interface. + + Supports both inline requests and JSONL file-based batch submissions. + Each JSONL line follows the Gemini format: + {"key": "request-1", "request": {"contents": [{"parts": [...]}]}} + """ + + DEFAULT_MODEL = "models/gemini-2.5-pro" + + def __init__(self, client: genai.Client, model: str | None = None) -> None: + """Initialize the Gemini batch provider. + + Args: + client: Configured Gemini client + model: Model to use (defaults to gemini-2.5-pro) + """ + self._client = client + self._model = model or self.DEFAULT_MODEL + + def create_batch( + self, jsonl_data: list[dict[str, Any]], config: dict[str, Any] + ) -> dict[str, Any]: + """Upload JSONL file and create a batch job with Gemini. + + Args: + jsonl_data: List of dictionaries in Gemini JSONL format. 
+ Each dict should have the structure: + {"key": "request-1", "request": {"contents": [{"parts": [...]}]}} + config: Provider-specific configuration with: + - display_name: Optional batch display name + - model: Optional model override + + Returns: + Dictionary containing: + - provider_batch_id: Gemini batch job name + - provider_file_id: Uploaded JSONL file name + - provider_status: Initial status from Gemini + - total_items: Number of items in the batch + """ + model = config.get("model", self._model) + display_name = config.get("display_name", f"batch-{int(time.time())}") + + logger.info( + f"[create_batch] Creating Gemini batch | items={len(jsonl_data)} | " + f"model={model} | display_name={display_name}" + ) + + try: + # Create JSONL content + jsonl_content = "\n".join(json.dumps(item) for item in jsonl_data) + + # Upload JSONL file to Gemini File API + uploaded_file = self.upload_file(jsonl_content, purpose="batch") + + logger.info( + f"[create_batch] Uploaded JSONL file | file_name={uploaded_file}" + ) + + # Create batch job using uploaded file + batch_job = self._client.batches.create( + model=model, + src=uploaded_file, + config={"display_name": display_name}, + ) + + initial_state = batch_job.state.name if batch_job.state else "UNKNOWN" + + result = { + "provider_batch_id": batch_job.name, + "provider_file_id": uploaded_file, + "provider_status": initial_state, + "total_items": len(jsonl_data), + } + + logger.info( + f"[create_batch] Created Gemini batch | batch_id={batch_job.name} | " + f"file_id={uploaded_file} | status={initial_state} | items={len(jsonl_data)}" + ) + + return result + + except Exception as e: + logger.error(f"[create_batch] Failed to create Gemini batch | {e}") + raise + + def get_batch_status(self, batch_id: str) -> dict[str, Any]: + """Poll Gemini for batch job status. + + Args: + batch_id: Gemini batch job name + + Returns: + Dictionary containing: + - provider_status: Current Gemini state + - provider_output_file_id: batch_id (used to fetch results) + - error_message: Error message (if failed) + """ + logger.info( + f"[get_batch_status] Polling Gemini batch status | batch_id={batch_id}" + ) + + try: + batch_job = self._client.batches.get(name=batch_id) + state = batch_job.state.name if batch_job.state else "UNKNOWN" + + result: dict[str, Any] = { + "provider_status": state, + # Gemini uses the same batch name to fetch results + "provider_output_file_id": batch_id, + } + + if state in _FAILED_STATES: + result["error_message"] = f"Batch {state}" + + logger.info( + f"[get_batch_status] Gemini batch status | batch_id={batch_id} | " + f"status={state}" + ) + + return result + + except Exception as e: + logger.error( + f"[get_batch_status] Failed to poll Gemini batch status | " + f"batch_id={batch_id} | {e}" + ) + raise + + def download_batch_results( + self, output_file_id: str, request_keys: list[str] | None = None + ) -> list[dict[str, Any]]: + """Download and parse batch results from Gemini. + + Gemini returns results as a downloadable JSONL file where each line + contains the key and response. + + Args: + output_file_id: Gemini batch job name (used to fetch the batch) + request_keys: Deprecated, kept for interface compatibility. 
+ + Returns: + List of result dictionaries, each containing: + - custom_id: Item key from input + - response: Dict with "text" key containing the generated text + - error: Error info (if item failed), None otherwise + """ + logger.info( + f"[download_batch_results] Downloading Gemini batch results | " + f"batch_id={output_file_id}" + ) + + try: + batch_job = self._client.batches.get(name=output_file_id) + state = batch_job.state.name if batch_job.state else "UNKNOWN" + + if state != BatchJobState.SUCCEEDED.value: + raise ValueError(f"Batch job not complete. Current state: {state}") + + results: list[dict[str, Any]] = [] + + # Handle file-based results (keys are included in each response line) + if ( + batch_job.dest + and hasattr(batch_job.dest, "file_name") + and batch_job.dest.file_name + ): + file_content = self.download_file(batch_job.dest.file_name) + lines = file_content.strip().split("\n") + for i, line in enumerate(lines): + try: + parsed = json.loads(line) + custom_id = parsed.get("key", str(i)) + + # Extract text from response + response_obj = parsed.get("response") + if response_obj: + text = self._extract_text_from_response_dict(response_obj) + results.append( + { + "custom_id": custom_id, + "response": {"text": text}, + "error": None, + } + ) + elif parsed.get("error"): + results.append( + { + "custom_id": custom_id, + "response": None, + "error": str(parsed["error"]), + } + ) + except json.JSONDecodeError as e: + logger.error( + f"[download_batch_results] Failed to parse JSON | " + f"line={i + 1} | {e}" + ) + continue + + logger.info( + f"[download_batch_results] Downloaded Gemini batch results | " + f"batch_id={output_file_id} | results={len(results)}" + ) + + return results + + except Exception as e: + logger.error( + f"[download_batch_results] Failed to download Gemini batch results | " + f"batch_id={output_file_id} | {e}" + ) + raise + + @staticmethod + def _extract_text_from_response_dict(response: dict[str, Any]) -> str: + """Extract text content from a Gemini response dictionary. + + Args: + response: Gemini response as a dictionary + + Returns: + str: Extracted text + """ + # Try direct text field first + if "text" in response: + return response["text"] + + # Extract from candidates structure + text = "" + for candidate in response.get("candidates", []): + content = candidate.get("content", {}) + for part in content.get("parts", []): + if "text" in part: + text += part["text"] + return text + + def upload_file(self, content: str, purpose: str = "batch") -> str: + """Upload a JSONL file to Gemini Files API. 
+ + Args: + content: File content (JSONL string) + purpose: Purpose of the file (unused for Gemini, kept for interface) + + Returns: + Gemini file name (e.g., "files/xxx") + """ + logger.info(f"[upload_file] Uploading file to Gemini | bytes={len(content)}") + + try: + with tempfile.NamedTemporaryFile( + suffix=".jsonl", delete=False, mode="w", encoding="utf-8" + ) as tmp_file: + tmp_file.write(content) + tmp_path = tmp_file.name + + try: + uploaded_file = self._client.files.upload( + file=tmp_path, + config=types.UploadFileConfig( + display_name=f"batch-input-{int(time.time())}", + mime_type="jsonl", + ), + ) + + logger.info( + f"[upload_file] Uploaded file to Gemini | " + f"file_name={uploaded_file.name}" + ) + + return uploaded_file.name + + finally: + os.unlink(tmp_path) + + except Exception as e: + logger.error(f"[upload_file] Failed to upload file to Gemini | {e}") + raise + + def download_file(self, file_id: str) -> str: + """Download a file from Gemini Files API. + + Args: + file_id: Gemini file name (e.g., "files/xxx") + + Returns: + File content as UTF-8 string + """ + logger.info(f"[download_file] Downloading file from Gemini | file_id={file_id}") + + try: + file_content = self._client.files.download(file=file_id) + content = file_content.decode("utf-8") + + logger.info( + f"[download_file] Downloaded file from Gemini | " + f"file_id={file_id} | bytes={len(content)}" + ) + + return content + + except Exception as e: + logger.error( + f"[download_file] Failed to download file from Gemini | " + f"file_id={file_id} | {e}" + ) + raise + + @staticmethod + def _extract_text_from_response(response: Any) -> str: + """Extract text content from a Gemini response object. + + Args: + response: Gemini GenerateContentResponse + + Returns: + str: Extracted text + """ + if hasattr(response, "text"): + return response.text + + text = "" + if hasattr(response, "candidates"): + for candidate in response.candidates: + if hasattr(candidate, "content"): + for part in candidate.content.parts: + if hasattr(part, "text"): + text += part.text + return text + + +def create_stt_batch_requests( + signed_urls: list[str], + prompt: str, + keys: list[str] | None = None, +) -> list[dict[str, Any]]: + """ + Create batch API requests for Gemini STT using signed URLs. + + This function generates request payloads in Gemini's JSONL batch format + using signed URLs directly. MIME types are automatically detected from the URL path. + + Args: + signed_urls: List of signed URLs pointing to audio files + prompt: Transcription prompt/instructions for the model + keys: Optional list of custom IDs for tracking results. If not provided, + uses 0-indexed integers as strings. + + Returns: + List of batch request dicts in Gemini JSONL format: + {"key": "sample-1", "request": {"contents": [...]}} + + Example: + >>> urls = ["https://bucket.s3.amazonaws.com/audio.mp3?..."] + >>> prompt = "Transcribe this audio file." 
+ >>> requests = create_stt_batch_requests(urls, prompt, keys=["sample-1"]) + >>> provider.create_batch(requests, {"display_name": "stt-batch"}) + """ + if keys is not None and len(keys) != len(signed_urls): + raise ValueError( + f"Length of keys ({len(keys)}) must match signed_urls ({len(signed_urls)})" + ) + + requests = [] + for i, url in enumerate(signed_urls): + mime_type = get_mime_from_url(url) + if mime_type is None: + logger.warning( + f"[create_stt_batch_requests] Could not determine MIME type for URL | " + f"index={i} | defaulting to audio/mpeg" + ) + mime_type = "audio/mpeg" + + # Use provided key or generate from index + key = keys[i] if keys is not None else str(i) + + # Gemini JSONL format: {"key": "...", "request": {...}} + request = { + "key": key, + "request": { + "contents": [ + { + "parts": [ + {"text": prompt}, + {"file_data": {"mime_type": mime_type, "file_uri": url}}, + ], + "role": "user", + } + ] + }, + } + requests.append(request) + + logger.info(f"[create_stt_batch_requests] Created {len(requests)} batch requests") + + return requests diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py index a3247b74b..727380726 100644 --- a/backend/app/core/cloud/storage.py +++ b/backend/app/core/cloud/storage.py @@ -213,13 +213,19 @@ def get_file_size_kb(self, url: str) -> float: ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err + # Maximum allowed expiry for signed URLs (24 hours) + MAX_SIGNED_URL_EXPIRY = 86400 + def get_signed_url(self, url: str, expires_in: int = 3600) -> str: """ Generate a signed S3 URL for the given file. :param url: S3 url (e.g., s3://bucket/key) - :param expires_in: Expiry time in seconds (default: 1 hour) + :param expires_in: Expiry time in seconds (default: 1 hour, max: 24 hours) :return: Signed URL as string """ + # Cap expiry at maximum allowed value to prevent excessively long-lived URLs + expires_in = min(expires_in, self.MAX_SIGNED_URL_EXPIRY) + name = SimpleStorageName.from_url(url) try: signed_url = self.aws.client.generate_presigned_url( diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index dfaae233a..01769e8c8 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -12,6 +12,7 @@ class Provider(str, Enum): OPENAI = "openai" AWS = "aws" LANGFUSE = "langfuse" + GEMINI = "gemini" @dataclass @@ -30,6 +31,7 @@ class ProviderConfig: Provider.LANGFUSE: ProviderConfig( required_fields=["secret_key", "public_key", "host"] ), + Provider.GEMINI: ProviderConfig(required_fields=["api_key"]), } diff --git a/backend/app/core/storage_utils.py b/backend/app/core/storage_utils.py index 63830d7d0..aa1b72199 100644 --- a/backend/app/core/storage_utils.py +++ b/backend/app/core/storage_utils.py @@ -5,12 +5,13 @@ to cloud object storage, abstracting away provider-specific details. """ -import io import json import logging +import mimetypes from datetime import datetime from io import BytesIO from pathlib import Path +from urllib.parse import unquote, urlparse from starlette.datastructures import Headers, UploadFile @@ -19,20 +20,44 @@ logger = logging.getLogger(__name__) -def upload_csv_to_object_store( +def get_mime_from_url(url: str) -> str | None: + """ + Extract MIME type from a URL by parsing its path component. + + Works with signed URLs by ignoring query parameters and extracting + the file extension from the path. 
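+
+    Example (hypothetical signed URL; the query string is ignored):
+        >>> get_mime_from_url("https://bucket.s3.amazonaws.com/audio/clip.mp3?X-Amz-Signature=abc")
+        'audio/mpeg'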
+ + Args: + url: URL string (can include query parameters like signed URLs) + + Returns: + MIME type string (e.g., 'audio/mpeg') or None if unable to determine + """ + parsed = urlparse(url) + path = unquote(parsed.path) + mime_type, _ = mimetypes.guess_type(path) + return mime_type + + +def upload_to_object_store( storage: CloudStorage, - csv_content: bytes, + content: bytes, filename: str, - subdirectory: str = "datasets", + subdirectory: str, + content_type: str = "application/octet-stream", ) -> str | None: """ - Upload CSV content to object store. + Upload content to object store. + + This is the generic upload function that handles any content type. + Use this directly or through convenience wrappers like upload_csv_to_object_store. Args: storage: CloudStorage instance - csv_content: Raw CSV content as bytes - filename: Name of the file (can include timestamp) - subdirectory: Subdirectory path in object store (default: "datasets") + content: Raw content as bytes + filename: Name of the file + subdirectory: Subdirectory path in object store (e.g., "datasets", "stt_datasets") + content_type: MIME type of the content (default: "application/octet-stream") Returns: Object store URL as string if successful, None if failed @@ -42,47 +67,74 @@ def upload_csv_to_object_store( Callers should continue without object store URL when this returns None. """ logger.info( - f"[upload_csv_to_object_store] Preparing to upload '{filename}' | " - f"size={len(csv_content)} bytes, subdirectory='{subdirectory}'" + f"[upload_to_object_store] Preparing to upload '{filename}' | " + f"size={len(content)} bytes, subdirectory='{subdirectory}', " + f"content_type='{content_type}'" ) try: - # Create file path file_path = Path(subdirectory) / filename - # Create a mock UploadFile-like object for the storage put method - class CSVFile: - def __init__(self, content: bytes): - self.file = io.BytesIO(content) - self.content_type = "text/csv" - - csv_file = CSVFile(csv_content) + headers = Headers({"content-type": content_type}) + upload_file = UploadFile( + filename=filename, + file=BytesIO(content), + headers=headers, + ) - # Upload to object store - destination = storage.put(source=csv_file, file_path=file_path) + destination = storage.put(source=upload_file, file_path=file_path) object_store_url = str(destination) logger.info( - f"[upload_csv_to_object_store] Upload successful | " + f"[upload_to_object_store] Upload successful | " f"filename='{filename}', url='{object_store_url}'" ) return object_store_url except CloudStorageError as e: logger.warning( - f"[upload_csv_to_object_store] Upload failed for '{filename}': {e}. " + f"[upload_to_object_store] Upload failed for '{filename}': {e}. " "Continuing without object store storage." ) return None except Exception as e: logger.warning( - f"[upload_csv_to_object_store] Unexpected error uploading '{filename}': {e}. " + f"[upload_to_object_store] Unexpected error uploading '{filename}': {e}. " "Continuing without object store storage.", exc_info=True, ) return None +def upload_csv_to_object_store( + storage: CloudStorage, + csv_content: bytes, + filename: str, + subdirectory: str = "datasets", +) -> str | None: + """ + Upload CSV content to object store. + + Convenience wrapper around upload_to_object_store for CSV files. 
+ + Args: + storage: CloudStorage instance + csv_content: Raw CSV content as bytes + filename: Name of the file (can include timestamp) + subdirectory: Subdirectory path in object store (default: "datasets") + + Returns: + Object store URL as string if successful, None if failed + """ + return upload_to_object_store( + storage=storage, + content=csv_content, + filename=filename, + subdirectory=subdirectory, + content_type="text/csv", + ) + + def upload_jsonl_to_object_store( storage: CloudStorage, results: list[dict], @@ -92,65 +144,28 @@ def upload_jsonl_to_object_store( """ Upload JSONL (JSON Lines) content to object store. + Convenience wrapper around upload_to_object_store for JSONL files. + Args: storage: CloudStorage instance results: List of dictionaries to be converted to JSONL filename: Name of the file - subdirectory: Subdirectory path in object store (e.g., "evaluation/batch-123") + subdirectory: Subdirectory path in object store Returns: Object store URL as string if successful, None if failed - - Note: - This function handles errors gracefully and returns None on failure. - Callers should continue without object store URL when this returns None. """ - logger.info( - f"[upload_jsonl_to_object_store] Preparing to upload '{filename}' | " - f"items={len(results)}, subdirectory='{subdirectory}'" + jsonl_content = "\n".join([json.dumps(result) for result in results]) + content_bytes = jsonl_content.encode("utf-8") + + return upload_to_object_store( + storage=storage, + content=content_bytes, + filename=filename, + subdirectory=subdirectory, + content_type="application/jsonl", ) - try: - # Create file path - file_path = Path(subdirectory) / filename - - # Convert results to JSONL - jsonl_content = "\n".join([json.dumps(result) for result in results]) - content_bytes = jsonl_content.encode("utf-8") - - # Create UploadFile-like object - headers = Headers({"content-type": "application/jsonl"}) - upload_file = UploadFile( - filename=filename, - file=BytesIO(content_bytes), - headers=headers, - ) - - # Upload to object store - destination = storage.put(source=upload_file, file_path=file_path) - object_store_url = str(destination) - - logger.info( - f"[upload_jsonl_to_object_store] Upload successful | " - f"filename='{filename}', url='{object_store_url}', " - f"size={len(content_bytes)} bytes" - ) - return object_store_url - - except CloudStorageError as e: - logger.warning( - f"[upload_jsonl_to_object_store] Upload failed for '{filename}': {e}. " - "Continuing without object store storage." - ) - return None - except Exception as e: - logger.warning( - f"[upload_jsonl_to_object_store] Unexpected error uploading '{filename}': {e}. 
" - "Continuing without object store storage.", - exc_info=True, - ) - return None - def generate_timestamped_filename(base_name: str, extension: str = "csv") -> str: """ diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 658193106..9baa5defd 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -80,3 +80,9 @@ ) from .onboarding import onboard_project + +from .file import ( + create_file, + get_file_by_id, + get_files_by_ids, +) diff --git a/backend/app/crud/evaluations/cron.py b/backend/app/crud/evaluations/cron.py index ca6bd2af2..fb9f7a66b 100644 --- a/backend/app/crud/evaluations/cron.py +++ b/backend/app/crud/evaluations/cron.py @@ -12,6 +12,7 @@ from sqlmodel import Session, select from app.crud.evaluations.processing import poll_all_pending_evaluations +from app.crud.stt_evaluations import poll_all_pending_stt_evaluations from app.models import Organization logger = logging.getLogger(__name__) @@ -84,22 +85,38 @@ async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: f"[process_all_pending_evaluations] Processing org_id={org.id} ({org.name})" ) - # Poll all pending evaluations for this org + # Poll all pending text evaluations for this org summary = await poll_all_pending_evaluations( session=session, org_id=org.id ) + # Poll all pending STT evaluations for this org + stt_summary = await poll_all_pending_stt_evaluations( + session=session, org_id=org.id + ) + + # Merge summaries + combined_summary = { + "text": summary, + "stt": stt_summary, + "processed": summary.get("processed", 0) + + stt_summary.get("processed", 0), + "failed": summary.get("failed", 0) + stt_summary.get("failed", 0), + "still_processing": summary.get("still_processing", 0) + + stt_summary.get("still_processing", 0), + } + results.append( { "org_id": org.id, "org_name": org.name, - "summary": summary, + "summary": combined_summary, } ) - total_processed += summary.get("processed", 0) - total_failed += summary.get("failed", 0) - total_still_processing += summary.get("still_processing", 0) + total_processed += combined_summary["processed"] + total_failed += combined_summary["failed"] + total_still_processing += combined_summary["still_processing"] except Exception as e: logger.error( diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py index 99ce39d56..05ca94045 100644 --- a/backend/app/crud/evaluations/processing.py +++ b/backend/app/crud/evaluations/processing.py @@ -664,9 +664,11 @@ async def poll_all_pending_evaluations(session: Session, org_id: int) -> dict[st } """ # Get pending evaluations (status = "processing") + # Filter to only text evaluations - STT/TTS evaluations have their own polling statement = select(EvaluationRun).where( EvaluationRun.status == "processing", EvaluationRun.organization_id == org_id, + EvaluationRun.type == "text", ) pending_runs = session.exec(statement).all() diff --git a/backend/app/crud/file.py b/backend/app/crud/file.py new file mode 100644 index 000000000..5eebb8d32 --- /dev/null +++ b/backend/app/crud/file.py @@ -0,0 +1,124 @@ +"""CRUD operations for file records.""" + +import logging + +from sqlmodel import Session, select + +from app.core.util import now +from app.models.file import File, FileType + +logger = logging.getLogger(__name__) + + +def create_file( + *, + session: Session, + object_store_url: str, + filename: str, + size_bytes: int, + content_type: str, + file_type: str, + organization_id: int, + project_id: int, +) -> File: + """Create a 
new file record. + + Args: + session: Database session + object_store_url: S3 URL where the file is stored + filename: Original filename + size_bytes: File size in bytes + content_type: MIME type of the file + file_type: Type of file (audio, document, image, other) + organization_id: Organization ID + project_id: Project ID + + Returns: + File: Created file record + """ + logger.info( + f"[create_file] Creating file record | " + f"filename: {filename}, file_type: {file_type}, " + f"org_id: {organization_id}, project_id: {project_id}" + ) + + timestamp = now() + file = File( + object_store_url=object_store_url, + filename=filename, + size_bytes=size_bytes, + content_type=content_type, + file_type=file_type, + organization_id=organization_id, + project_id=project_id, + inserted_at=timestamp, + updated_at=timestamp, + ) + + session.add(file) + session.commit() + session.refresh(file) + + logger.info( + f"[create_file] File record created | " + f"file_id: {file.id}, filename: {filename}" + ) + + return file + + +def get_file_by_id( + *, + session: Session, + file_id: int, + organization_id: int, + project_id: int, +) -> File | None: + """Get a file record by ID. + + Args: + session: Database session + file_id: File ID + organization_id: Organization ID + project_id: Project ID + + Returns: + File | None: File record if found + """ + statement = select(File).where( + File.id == file_id, + File.organization_id == organization_id, + File.project_id == project_id, + ) + + return session.exec(statement).one_or_none() + + +def get_files_by_ids( + *, + session: Session, + file_ids: list[int], + organization_id: int, + project_id: int, +) -> list[File]: + """Get multiple file records by IDs. + + Args: + session: Database session + file_ids: List of file IDs + organization_id: Organization ID + project_id: Project ID + + Returns: + list[File]: List of file records found + """ + if not file_ids: + return [] + + statement = select(File).where( + File.id.in_(file_ids), + File.organization_id == organization_id, + File.project_id == project_id, + ) + + return list(session.exec(statement).all()) diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py new file mode 100644 index 000000000..71f0b3f3f --- /dev/null +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -0,0 +1,46 @@ +"""STT Evaluation CRUD operations.""" + +from .batch import start_stt_evaluation_batch +from .cron import poll_all_pending_stt_evaluations +from .dataset import ( + create_stt_dataset, + create_stt_samples, + get_stt_dataset_by_id, + list_stt_datasets, + get_samples_by_dataset_id, +) +from .run import ( + create_stt_run, + get_stt_run_by_id, + list_stt_runs, + update_stt_run, +) +from .result import ( + create_stt_results, + get_stt_result_by_id, + get_results_by_run_id, + update_human_feedback, +) + +__all__ = [ + # Batch + "start_stt_evaluation_batch", + # Cron + "poll_all_pending_stt_evaluations", + # Dataset + "create_stt_dataset", + "create_stt_samples", + "get_stt_dataset_by_id", + "list_stt_datasets", + "get_samples_by_dataset_id", + # Run + "create_stt_run", + "get_stt_run_by_id", + "list_stt_runs", + "update_stt_run", + # Result + "create_stt_results", + "get_stt_result_by_id", + "get_results_by_run_id", + "update_human_feedback", +] diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py new file mode 100644 index 000000000..e3b016335 --- /dev/null +++ b/backend/app/crud/stt_evaluations/batch.py @@ -0,0 +1,178 @@ +"""Batch 
submission functions for STT evaluation processing.""" + +import logging +from typing import Any + +from sqlmodel import Session + +from app.core.batch import ( + GeminiBatchProvider, + create_stt_batch_requests, + start_batch_job, +) +from app.core.cloud.storage import get_cloud_storage +from app.crud.file import get_files_by_ids +from app.crud.stt_evaluations.result import update_stt_result +from app.crud.stt_evaluations.run import update_stt_run +from app.models import EvaluationRun +from app.models.stt_evaluation import STTResultStatus, STTSample +from app.services.stt_evaluations.gemini import GeminiClient + +logger = logging.getLogger(__name__) + +DEFAULT_TRANSCRIPTION_PROMPT = ( + "Generate a verbatim transcript of the speech in this audio file. " + "Return only the transcription text without any formatting, timestamps, or metadata." +) + +DEFAULT_MODEL = "gemini-2.5-pro" + + +def start_stt_evaluation_batch( + *, + session: Session, + run: EvaluationRun, + samples: list[STTSample], + sample_to_result: dict[int, int], + org_id: int, + project_id: int, + signed_url_expires_in: int = 86400, +) -> dict[str, Any]: + """Generate signed URLs and submit Gemini batch job for STT evaluation. + + Args: + session: Database session + run: The evaluation run record + samples: List of STT samples to process + sample_to_result: Mapping of sample_id -> result_id for error handling + org_id: Organization ID + project_id: Project ID + signed_url_expires_in: Signed URL expiry in seconds (default: 24 hours) + + Returns: + dict: Result with batch job information + + Raises: + Exception: If batch submission fails + """ + logger.info( + f"[start_stt_evaluation_batch] Starting batch submission | " + f"run_id: {run.id}, sample_count: {len(samples)}" + ) + + # Initialize Gemini client + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=project_id, + ) + + # Get cloud storage for S3 access + storage = get_cloud_storage(session=session, project_id=project_id) + + # Fetch file records to get object_store_url + file_ids = [sample.file_id for sample in samples] + file_records = get_files_by_ids( + session=session, + file_ids=file_ids, + organization_id=org_id, + project_id=project_id, + ) + file_map = {f.id: f for f in file_records} + + # Generate signed URLs for audio files + signed_urls: list[str] = [] + sample_keys: list[str] = [] + + for sample in samples: + try: + # Get object_store_url from file record + file_record = file_map.get(sample.file_id) + if not file_record: + raise ValueError(f"File record not found for file_id: {sample.file_id}") + + signed_url = storage.get_signed_url( + file_record.object_store_url, expires_in=signed_url_expires_in + ) + signed_urls.append(signed_url) + sample_keys.append(str(sample.id)) + + except Exception as e: + logger.error( + f"[start_stt_evaluation_batch] Failed to generate signed URL | " + f"sample_id: {sample.id}, error: {str(e)}" + ) + if sample.id in sample_to_result: + update_stt_result( + session=session, + result_id=sample_to_result[sample.id], + status=STTResultStatus.FAILED.value, + error_message=f"Failed to generate signed URL: {str(e)}", + ) + + if not signed_urls: + raise Exception("Failed to generate signed URLs for any audio files") + + jsonl_data = create_stt_batch_requests( + signed_urls=signed_urls, + prompt=DEFAULT_TRANSCRIPTION_PROMPT, + keys=sample_keys, + ) + + model = (run.providers or [DEFAULT_MODEL])[0] + model_path = f"models/{model}" + + batch_provider = 
GeminiBatchProvider(client=gemini_client.client, model=model_path) + + try: + batch_job = start_batch_job( + session=session, + provider=batch_provider, + provider_name="gemini", + job_type="stt_evaluation", + organization_id=org_id, + project_id=project_id, + jsonl_data=jsonl_data, + config={"model": model}, + ) + + logger.info( + f"[start_stt_evaluation_batch] Batch job created | " + f"run_id: {run.id}, batch_job_id: {batch_job.id}" + ) + + except Exception as e: + logger.error( + f"[start_stt_evaluation_batch] Failed to submit batch | " + f"model: {model}, error: {str(e)}" + ) + for result_id in sample_to_result.values(): + update_stt_result( + session=session, + result_id=result_id, + status=STTResultStatus.FAILED.value, + error_message=f"Batch submission failed: {str(e)}", + ) + raise Exception(f"Batch submission failed: {str(e)}") + + # Link batch job to the evaluation run + update_stt_run( + session=session, + run_id=run.id, + status="processing", + batch_job_id=batch_job.id, + ) + + logger.info( + f"[start_stt_evaluation_batch] Batch submission complete | " + f"run_id: {run.id}, batch_job_id: {batch_job.id}, " + f"sample_count: {len(signed_urls)}" + ) + + return { + "success": True, + "run_id": run.id, + "batch_job_id": batch_job.id, + "provider_batch_id": batch_job.provider_batch_id, + "sample_count": len(signed_urls), + } diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py new file mode 100644 index 000000000..923fcb322 --- /dev/null +++ b/backend/app/crud/stt_evaluations/cron.py @@ -0,0 +1,310 @@ +"""Cron processing functions for STT evaluations. + +This module provides functions that are called periodically to process +pending STT evaluations - polling batch status and processing completed batches. +""" + +import logging +from typing import Any + +from sqlmodel import Session, select + +from app.core.batch import BatchJobState, GeminiBatchProvider, poll_batch_status +from app.crud.job import get_batch_job +from app.crud.stt_evaluations.result import count_results_by_status, update_stt_result +from app.crud.stt_evaluations.run import get_pending_stt_runs, update_stt_run +from app.models import EvaluationRun +from app.models.stt_evaluation import STTResult, STTResultStatus +from app.services.stt_evaluations.gemini import GeminiClient + +logger = logging.getLogger(__name__) + +# Terminal states that indicate batch processing is complete +TERMINAL_STATES = { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + + +async def poll_all_pending_stt_evaluations( + session: Session, + org_id: int, +) -> dict[str, Any]: + """Poll all pending STT evaluations for an organization. + + This function: + 1. Gets all STT runs in "processing" status + 2. For each run, polls the Gemini batch status + 3. If completed, processes the results + 4. 
Returns a summary of what was processed + + Args: + session: Database session + org_id: Organization ID + + Returns: + dict: Summary with processed, failed, still_processing counts + """ + logger.info( + f"[poll_all_pending_stt_evaluations] Starting STT evaluation polling | " + f"org_id: {org_id}" + ) + + # Get all pending STT runs for this organization + pending_runs = get_pending_stt_runs(session=session, org_id=org_id) + + if not pending_runs: + logger.info( + f"[poll_all_pending_stt_evaluations] No pending STT runs | org_id: {org_id}" + ) + return {"processed": 0, "failed": 0, "still_processing": 0} + + logger.info( + f"[poll_all_pending_stt_evaluations] Found {len(pending_runs)} pending STT runs | " + f"org_id: {org_id}" + ) + + processed = 0 + failed = 0 + still_processing = 0 + + for run in pending_runs: + try: + result = await poll_stt_run(session=session, run=run, org_id=org_id) + + if result["status"] == "completed": + processed += 1 + elif result["status"] == "failed": + failed += 1 + else: # still_processing + still_processing += 1 + + except Exception as e: + logger.error( + f"[poll_all_pending_stt_evaluations] Error polling run | " + f"run_id: {run.id}, error: {str(e)}", + exc_info=True, + ) + failed += 1 + + logger.info( + f"[poll_all_pending_stt_evaluations] Polling complete | " + f"org_id: {org_id}, processed: {processed}, failed: {failed}, " + f"still_processing: {still_processing}" + ) + + return { + "processed": processed, + "failed": failed, + "still_processing": still_processing, + } + + +async def poll_stt_run( + session: Session, + run: EvaluationRun, + org_id: int, +) -> dict[str, Any]: + """Poll a single STT evaluation run's batch status. + + Args: + session: Database session + run: The evaluation run to poll + org_id: Organization ID + + Returns: + dict: Status result with "status" key (completed/failed/still_processing) + """ + logger.info(f"[poll_stt_run] Polling run | run_id: {run.id}") + + # Check if run has batch_job_id + if not run.batch_job_id: + logger.warning(f"[poll_stt_run] Run has no batch_job_id | run_id: {run.id}") + return {"status": "failed", "error": "No batch job found"} + + # Get the batch job record + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if not batch_job: + logger.warning( + f"[poll_stt_run] BatchJob not found | run_id: {run.id}, " + f"batch_job_id: {run.batch_job_id}" + ) + return {"status": "failed", "error": "Batch job not found"} + + try: + # Initialize Gemini client and poll batch status + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=run.project_id, + ) + batch_provider = GeminiBatchProvider(client=gemini_client.client) + + # Poll and update batch job status + poll_batch_status( + session=session, + provider=batch_provider, + batch_job=batch_job, + ) + + # Refresh to get updated status + session.refresh(batch_job) + provider_status = batch_job.provider_status + + logger.info( + f"[poll_stt_run] Batch status | " + f"run_id: {run.id}, batch_job_id: {batch_job.id}, " + f"state: {provider_status}" + ) + + # Check if batch is complete + if provider_status not in TERMINAL_STATES: + return {"status": "still_processing"} + + # Batch is complete - check if succeeded + if provider_status == BatchJobState.SUCCEEDED.value: + await process_completed_stt_batch( + session=session, + run=run, + batch_job=batch_job, + org_id=org_id, + ) + return {"status": "completed"} + else: + # Batch failed + error_msg = batch_job.error_message or f"Batch {provider_status}" + 
update_stt_run( + session=session, + run_id=run.id, + status="failed", + error_message=error_msg, + ) + return {"status": "failed", "error": error_msg} + + except Exception as e: + logger.error( + f"[poll_stt_run] Error polling run | run_id: {run.id}, error: {str(e)}", + exc_info=True, + ) + return {"status": "failed", "error": str(e)} + + +async def process_completed_stt_batch( + session: Session, + run: EvaluationRun, + batch_job: Any, + org_id: int, +) -> None: + """Process a completed Gemini batch: download results and update STT result records. + + Args: + session: Database session + run: The evaluation run + batch_job: The BatchJob record + org_id: Organization ID + """ + logger.info( + f"[process_completed_stt_batch] Processing batch results | " + f"run_id: {run.id}, batch_job_id: {batch_job.id}" + ) + + # Get the STT provider from the batch job config; batch submission stores it under "model" + stt_provider = batch_job.config.get("model", "gemini-2.5-pro") + + # Initialize Gemini client + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=run.project_id, + ) + batch_provider = GeminiBatchProvider(client=gemini_client.client) + + processed_count = 0 + failed_count = 0 + + try: + # Download results using GeminiBatchProvider + # Keys are embedded in the JSONL response file, no separate mapping needed + results = batch_provider.download_batch_results(batch_job.provider_batch_id) + + logger.info( + f"[process_completed_stt_batch] Got batch results | " + f"batch_job_id: {batch_job.id}, result_count: {len(results)}" + ) + + # Match results to samples using key (sample_id) from batch request + for batch_result in results: + custom_id = batch_result["custom_id"] + # custom_id is the sample_id as string (set via key in batch request) + try: + sample_id = int(custom_id) + except (ValueError, TypeError): + logger.warning( + f"[process_completed_stt_batch] Invalid custom_id | " + f"batch_job_id: {batch_job.id}, custom_id: {custom_id}" + ) + failed_count += 1 + continue + + # Find result record for this sample and provider + stmt = select(STTResult).where( + STTResult.evaluation_run_id == run.id, + STTResult.stt_sample_id == sample_id, + STTResult.provider == stt_provider, + ) + result_record = session.exec(stmt).one_or_none() + + if result_record: + if batch_result.get("response"): + text = batch_result["response"].get("text", "") + update_stt_result( + session=session, + result_id=result_record.id, + transcription=text, + status=STTResultStatus.COMPLETED.value, + ) + processed_count += 1 + else: + update_stt_result( + session=session, + result_id=result_record.id, + status=STTResultStatus.FAILED.value, + error_message=batch_result.get("error", "Unknown error"), + ) + failed_count += 1 + + except Exception as e: + logger.error( + f"[process_completed_stt_batch] Failed to process batch results | " + f"batch_job_id: {batch_job.id}, error: {str(e)}", + exc_info=True, + ) + # Re-raise; the run stays in "processing" and its pending results are retried on the next poll + raise + + # Update run status + status_counts = count_results_by_status(session=session, run_id=run.id) + + completed = status_counts.get(STTResultStatus.COMPLETED.value, 0) + failed = status_counts.get(STTResultStatus.FAILED.value, 0) + pending = status_counts.get(STTResultStatus.PENDING.value, 0) + + final_status = "completed" if pending == 0 else "processing" + error_message = None + if failed > 0: + error_message = f"{failed} transcription(s) failed" + + update_stt_run( + session=session, + run_id=run.id, + status=final_status, + 
error_message=error_message, + ) + + logger.info( + f"[process_completed_stt_batch] Batch results processed | " + f"run_id: {run.id}, completed: {completed}, " + f"failed: {failed}, status: {final_status}" + ) diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py new file mode 100644 index 000000000..cc43d8d00 --- /dev/null +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -0,0 +1,329 @@ +"""CRUD operations for STT evaluation datasets and samples.""" + +import logging +from typing import Any + +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select, func + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.crud.file import get_files_by_ids +from app.models import EvaluationDataset +from app.models.file import File +from app.models.stt_evaluation import ( + EvaluationType, + STTSample, + STTSampleCreate, + STTDatasetPublic, + STTSamplePublic, +) + +logger = logging.getLogger(__name__) + + +def create_stt_dataset( + *, + session: Session, + name: str, + org_id: int, + project_id: int, + description: str | None = None, + language_id: int | None = None, + object_store_url: str | None = None, + dataset_metadata: dict[str, Any] | None = None, +) -> EvaluationDataset: + """Create a new STT evaluation dataset. + + Args: + session: Database session + name: Dataset name + org_id: Organization ID + project_id: Project ID + description: Optional description + language_id: Optional reference to global.languages table + object_store_url: Optional object store URL + dataset_metadata: Optional metadata dict + + Returns: + EvaluationDataset: Created dataset + + Raises: + HTTPException: If dataset with same name already exists + """ + logger.info( + f"[create_stt_dataset] Creating STT dataset | " + f"name: {name}, org_id: {org_id}, project_id: {project_id}" + ) + + dataset = EvaluationDataset( + name=name, + description=description, + type=EvaluationType.STT.value, + language_id=language_id, + object_store_url=object_store_url, + dataset_metadata=dataset_metadata or {}, + organization_id=org_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + try: + session.add(dataset) + session.commit() + session.refresh(dataset) + + logger.info( + f"[create_stt_dataset] STT dataset created | " + f"dataset_id: {dataset.id}, name: {name}" + ) + + return dataset + + except IntegrityError as e: + session.rollback() + if "uq_evaluation_dataset_name_org_project" in str(e): + logger.error( + f"[create_stt_dataset] Dataset name already exists | name: {name}" + ) + raise HTTPException( + status_code=400, + detail=f"Dataset with name '{name}' already exists", + ) + raise + + +def validate_file_ids( + *, + session: Session, + file_ids: list[int], + organization_id: int, + project_id: int, +) -> dict[int, File]: + """Validate that all file IDs exist and belong to the organization/project. 
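+
+    Example (illustrative sketch; the IDs below are placeholders):
+
+        # Raises HTTPException(400) listing any IDs that do not exist
+        file_map = validate_file_ids(
+            session=session,
+            file_ids=[1, 2],
+            organization_id=1,
+            project_id=1,
+        )
+        original_name = file_map[1].filename  # File rows keyed by id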
+ + Args: + session: Database session + file_ids: List of file IDs to validate + organization_id: Organization ID + project_id: Project ID + + Returns: + dict[int, File]: Mapping of file_id to File object + + Raises: + HTTPException: If any file IDs are invalid + """ + if not file_ids: + return {} + + files = get_files_by_ids( + session=session, + file_ids=file_ids, + organization_id=organization_id, + project_id=project_id, + ) + + file_map = {f.id: f for f in files} + missing_ids = set(file_ids) - set(file_map.keys()) + + if missing_ids: + raise HTTPException( + status_code=400, + detail=f"File IDs not found: {sorted(missing_ids)}", + ) + + return file_map + + +def create_stt_samples( + *, + session: Session, + dataset: EvaluationDataset, + samples: list[STTSampleCreate], +) -> list[STTSample]: + """Create STT samples for a dataset. + + Args: + session: Database session + dataset: Parent dataset; supplies the organization, project, and language context + samples: List of sample data + + Returns: + list[STTSample]: Created samples + + Raises: + HTTPException: If any file IDs are invalid + """ + logger.info( + f"[create_stt_samples] Creating STT samples | " + f"dataset_id: {dataset.id}, sample_count: {len(samples)}" + ) + + # Validate all file IDs exist + file_ids = [sample.file_id for sample in samples] + file_map = validate_file_ids( + session=session, + file_ids=file_ids, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + ) + + timestamp = now() + created_samples = [ + STTSample( + file_id=sample_data.file_id, + ground_truth=sample_data.ground_truth, + language_id=dataset.language_id, + sample_metadata={ + "original_filename": file_map[sample_data.file_id].filename, + "file_extension": file_map[sample_data.file_id] + .filename.rsplit(".", 1)[-1] + .lower() + if "." in file_map[sample_data.file_id].filename + else None, + }, + dataset_id=dataset.id, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=timestamp, + updated_at=timestamp, + ) + for sample_data in samples + ] + + session.add_all(created_samples) + session.flush() + session.commit() + + logger.info( + f"[create_stt_samples] STT samples created | " + f"dataset_id: {dataset.id}, created_count: {len(created_samples)}" + ) + + return created_samples + + +def get_stt_dataset_by_id( + *, + session: Session, + dataset_id: int, + org_id: int, + project_id: int, +) -> EvaluationDataset | None: + """Get an STT dataset by ID. + + Args: + session: Database session + dataset_id: Dataset ID + org_id: Organization ID + project_id: Project ID + + Returns: + EvaluationDataset | None: Dataset if found + """ + statement = select(EvaluationDataset).where( + EvaluationDataset.id == dataset_id, + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.STT.value, + ) + + return session.exec(statement).one_or_none() + + +def list_stt_datasets( + *, + session: Session, + org_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> tuple[list[STTDatasetPublic], int]: + """List STT datasets for a project. 
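+
+    Example (illustrative sketch; IDs and paging values are placeholders):
+
+        datasets, total = list_stt_datasets(
+            session=session,
+            org_id=1,
+            project_id=1,
+            limit=10,
+            offset=0,
+        )
+        # datasets is one page of STTDatasetPublic models; total is the full count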
+ + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[STTDatasetPublic], int]: Datasets and total count + """ + base_filter = ( + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.STT.value, + ) + + count_stmt = select(func.count(EvaluationDataset.id)).where(*base_filter) + total = session.exec(count_stmt).one() + + statement = ( + select(EvaluationDataset) + .where(*base_filter) + .order_by(EvaluationDataset.inserted_at.desc()) + .offset(offset) + .limit(limit) + ) + + datasets = session.exec(statement).all() + + result = [ + STTDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language_id=dataset.language_id, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ) + for dataset in datasets + ] + + return result, total + + +def get_samples_by_dataset_id( + *, + session: Session, + dataset_id: int, + org_id: int, + project_id: int, + limit: int = 100, + offset: int = 0, +) -> list[STTSample]: + """Get samples for a dataset. + + Args: + session: Database session + dataset_id: Dataset ID + org_id: Organization ID + project_id: Project ID + limit: Maximum results to return + offset: Number of results to skip + + Returns: + list[STTSample]: Samples + """ + statement = ( + select(STTSample) + .where( + STTSample.dataset_id == dataset_id, + STTSample.organization_id == org_id, + STTSample.project_id == project_id, + ) + .order_by(STTSample.id) + .offset(offset) + .limit(limit) + ) + + return list(session.exec(statement).all()) diff --git a/backend/app/crud/stt_evaluations/result.py b/backend/app/crud/stt_evaluations/result.py new file mode 100644 index 000000000..73f6d5be5 --- /dev/null +++ b/backend/app/crud/stt_evaluations/result.py @@ -0,0 +1,353 @@ +"""CRUD operations for STT evaluation results.""" + +import logging +from typing import Any + +from sqlmodel import Session, select, func + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.models.file import File +from app.models.stt_evaluation import ( + STTResult, + STTResultStatus, + STTSample, + STTSamplePublic, + STTResultWithSample, +) + +logger = logging.getLogger(__name__) + + +def create_stt_results( + *, + session: Session, + samples: list[STTSample], + evaluation_run_id: int, + org_id: int, + project_id: int, + providers: list[str], +) -> list[STTResult]: + """Create STT result records for all samples and providers. + + Creates one result per sample per provider. 
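+
+    Example (illustrative sketch; assumes a run and its samples already exist):
+
+        providers = ["gemini-2.5-pro"]
+        results = create_stt_results(
+            session=session,
+            samples=samples,              # list[STTSample] from the dataset
+            evaluation_run_id=run.id,
+            org_id=1,
+            project_id=1,
+            providers=providers,
+        )
+        assert len(results) == len(samples) * len(providers)  # all start "pending"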
+ + Args: + session: Database session + samples: List of samples + evaluation_run_id: Run ID + org_id: Organization ID + project_id: Project ID + providers: List of providers + + Returns: + list[STTResult]: Created results + """ + logger.info( + f"[create_stt_results] Creating STT results | " + f"run_id: {evaluation_run_id}, sample_count: {len(samples)}, " + f"provider_count: {len(providers)}" + ) + + timestamp = now() + results = [ + STTResult( + stt_sample_id=sample.id, + evaluation_run_id=evaluation_run_id, + organization_id=org_id, + project_id=project_id, + provider=provider, + status=STTResultStatus.PENDING.value, + inserted_at=timestamp, + updated_at=timestamp, + ) + for sample in samples + for provider in providers + ] + + session.add_all(results) + session.flush() + session.commit() + + logger.info( + f"[create_stt_results] STT results created | " + f"run_id: {evaluation_run_id}, result_count: {len(results)}" + ) + + return results + + +def get_stt_result_by_id( + *, + session: Session, + result_id: int, + org_id: int, + project_id: int, +) -> STTResult | None: + """Get an STT result by ID. + + Args: + session: Database session + result_id: Result ID + org_id: Organization ID + project_id: Project ID + + Returns: + STTResult | None: Result if found + """ + statement = select(STTResult).where( + STTResult.id == result_id, + STTResult.organization_id == org_id, + STTResult.project_id == project_id, + ) + + return session.exec(statement).one_or_none() + + +def get_results_by_run_id( + *, + session: Session, + run_id: int, + org_id: int, + project_id: int, + provider: str | None = None, + status: str | None = None, + limit: int = 100, + offset: int = 0, +) -> tuple[list[STTResultWithSample], int]: + """Get results for an evaluation run with sample data. 
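+
+    Example (illustrative sketch; the filters shown are optional):
+
+        results, total = get_results_by_run_id(
+            session=session,
+            run_id=run.id,                # an existing EvaluationRun
+            org_id=1,
+            project_id=1,
+            provider="gemini-2.5-pro",
+            status="completed",
+        )
+        # Each STTResultWithSample embeds its sample, including object_store_url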
+ + Args: + session: Database session + run_id: Run ID + org_id: Organization ID + project_id: Project ID + provider: Optional filter by provider + status: Optional filter by status + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[STTResultWithSample], int]: Results with samples and total count + """ + where_clauses = [ + STTResult.evaluation_run_id == run_id, + STTResult.organization_id == org_id, + STTResult.project_id == project_id, + ] + + if provider is not None: + where_clauses.append(STTResult.provider == provider) + + if status is not None: + where_clauses.append(STTResult.status == status) + + count_stmt = select(func.count(STTResult.id)).where(*where_clauses) + total = session.exec(count_stmt).one() + + statement = ( + select(STTResult, STTSample, File) + .join(STTSample, STTResult.stt_sample_id == STTSample.id) + .join(File, STTSample.file_id == File.id) + .where(*where_clauses) + .order_by(STTResult.id) + .offset(offset) + .limit(limit) + ) + + rows = session.exec(statement).all() + + # Convert to response models + results = [] + for result, sample, file in rows: + sample_public = STTSamplePublic( + id=sample.id, + file_id=sample.file_id, + object_store_url=file.object_store_url, + language_id=sample.language_id, + ground_truth=sample.ground_truth, + sample_metadata=sample.sample_metadata, + dataset_id=sample.dataset_id, + organization_id=sample.organization_id, + project_id=sample.project_id, + inserted_at=sample.inserted_at, + updated_at=sample.updated_at, + ) + + result_with_sample = STTResultWithSample( + id=result.id, + transcription=result.transcription, + provider=result.provider, + status=result.status, + score=result.score, + is_correct=result.is_correct, + comment=result.comment, + error_message=result.error_message, + stt_sample_id=result.stt_sample_id, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + sample=sample_public, + ) + results.append(result_with_sample) + + return results, total + + +def update_stt_result( + *, + session: Session, + result_id: int, + transcription: str | None = None, + status: str | None = None, + score: dict[str, Any] | None = None, + error_message: str | None = None, +) -> STTResult | None: + """Update an STT result with transcription data. + + Args: + session: Database session + result_id: Result ID + transcription: Generated transcription + status: New status + score: Evaluation metrics (e.g., wer, cer) + error_message: Error message if failed + + Returns: + STTResult | None: Updated result + """ + statement = select(STTResult).where(STTResult.id == result_id) + result = session.exec(statement).one_or_none() + + if not result: + return None + + updates = { + "transcription": transcription, + "status": status, + "score": score, + "error_message": error_message, + } + + for field, value in updates.items(): + if value is not None: + setattr(result, field, value) + + result.updated_at = now() + + session.add(result) + session.commit() + session.refresh(result) + + return result + + +def update_human_feedback( + *, + session: Session, + result_id: int, + org_id: int, + project_id: int, + is_correct: bool | None = None, + comment: str | None = None, +) -> STTResult | None: + """Update human feedback on an STT result. 
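+
+    Example (illustrative sketch; result_id is a placeholder):
+
+        result = update_human_feedback(
+            session=session,
+            result_id=123,
+            org_id=1,
+            project_id=1,
+            is_correct=False,
+            comment="Speaker names were transcribed incorrectly",
+        )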
+ + Args: + session: Database session + result_id: Result ID + org_id: Organization ID + project_id: Project ID + is_correct: Human verification of correctness + comment: Feedback comment + + Returns: + STTResult | None: Updated result + + Raises: + HTTPException: If result not found + """ + result = get_stt_result_by_id( + session=session, + result_id=result_id, + org_id=org_id, + project_id=project_id, + ) + + if not result: + raise HTTPException(status_code=404, detail="Result not found") + + if is_correct is not None: + result.is_correct = is_correct + + if comment is not None: + result.comment = comment + + result.updated_at = now() + + session.add(result) + session.commit() + session.refresh(result) + + logger.info( + f"[update_human_feedback] Human feedback updated | " + f"result_id: {result_id}, is_correct: {is_correct}" + ) + + return result + + +def get_pending_results_for_run( + *, + session: Session, + run_id: int, + provider: str | None = None, +) -> list[STTResult]: + """Get all pending results for a run. + + Args: + session: Database session + run_id: Run ID + provider: Optional filter by provider + + Returns: + list[STTResult]: Pending results + """ + where_clauses = [ + STTResult.evaluation_run_id == run_id, + STTResult.status == STTResultStatus.PENDING.value, + ] + + if provider is not None: + where_clauses.append(STTResult.provider == provider) + + statement = select(STTResult).where(*where_clauses) + + return list(session.exec(statement).all()) + + +def count_results_by_status( + *, + session: Session, + run_id: int, +) -> dict[str, int]: + """Count results by status for a run. + + Args: + session: Database session + run_id: Run ID + + Returns: + dict[str, int]: Counts by status + """ + statement = ( + select(STTResult.status, func.count(STTResult.id)) + .where(STTResult.evaluation_run_id == run_id) + .group_by(STTResult.status) + ) + + rows = session.exec(statement).all() + + return {status: count for status, count in rows} diff --git a/backend/app/crud/stt_evaluations/run.py b/backend/app/crud/stt_evaluations/run.py new file mode 100644 index 000000000..6f9729968 --- /dev/null +++ b/backend/app/crud/stt_evaluations/run.py @@ -0,0 +1,259 @@ +"""CRUD operations for STT evaluation runs.""" + +import logging +from typing import Any + +from sqlmodel import Session, select, func + +from app.core.util import now +from app.models import EvaluationDataset, EvaluationRun +from app.models.stt_evaluation import ( + EvaluationType, + STTEvaluationRunPublic, +) + +logger = logging.getLogger(__name__) + + +def create_stt_run( + *, + session: Session, + run_name: str, + dataset_id: int, + dataset_name: str, + org_id: int, + project_id: int, + providers: list[str], + language_id: int | None = None, + total_items: int = 0, +) -> EvaluationRun: + """Create a new STT evaluation run. 
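+
+    Example (illustrative sketch; assumes dataset and samples were already created):
+
+        run = create_stt_run(
+            session=session,
+            run_name="nightly-stt-eval",
+            dataset_id=dataset.id,
+            dataset_name=dataset.name,
+            org_id=1,
+            project_id=1,
+            providers=["gemini-2.5-pro"],
+            total_items=len(samples),
+        )
+        # run.status is "pending" until the batch is submitted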
+ + Args: + session: Database session + run_name: Name for the run + dataset_id: ID of the dataset to evaluate + dataset_name: Name of the dataset + org_id: Organization ID + project_id: Project ID + providers: List of STT providers to use + language_id: Optional language ID override (references global.languages) + total_items: Total number of items to process + + Returns: + EvaluationRun: Created run + """ + logger.info( + f"[create_stt_run] Creating STT evaluation run | " + f"run_name: {run_name}, dataset_id: {dataset_id}, " + f"providers: {providers}" + ) + + run = EvaluationRun( + run_name=run_name, + dataset_name=dataset_name, + dataset_id=dataset_id, + type=EvaluationType.STT.value, + language_id=language_id, + providers=providers, + status="pending", + total_items=total_items, + organization_id=org_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + session.add(run) + session.commit() + session.refresh(run) + + logger.info( + f"[create_stt_run] STT evaluation run created | " + f"run_id: {run.id}, run_name: {run_name}" + ) + + return run + + +def get_stt_run_by_id( + *, + session: Session, + run_id: int, + org_id: int, + project_id: int, +) -> EvaluationRun | None: + """Get an STT evaluation run by ID. + + Args: + session: Database session + run_id: Run ID + org_id: Organization ID + project_id: Project ID + + Returns: + EvaluationRun | None: Run if found + """ + statement = select(EvaluationRun).where( + EvaluationRun.id == run_id, + EvaluationRun.organization_id == org_id, + EvaluationRun.project_id == project_id, + EvaluationRun.type == EvaluationType.STT.value, + ) + + return session.exec(statement).one_or_none() + + +def list_stt_runs( + *, + session: Session, + org_id: int, + project_id: int, + dataset_id: int | None = None, + status: str | None = None, + limit: int = 50, + offset: int = 0, +) -> tuple[list[STTEvaluationRunPublic], int]: + """List STT evaluation runs for a project. 
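+
+    Example (illustrative sketch; both filters are optional):
+
+        runs, total = list_stt_runs(
+            session=session,
+            org_id=1,
+            project_id=1,
+            status="processing",
+            limit=20,
+        )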
+ + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + dataset_id: Optional filter by dataset + status: Optional filter by status + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[STTEvaluationRunPublic], int]: Runs and total count + """ + where_clauses = [ + EvaluationRun.organization_id == org_id, + EvaluationRun.project_id == project_id, + EvaluationRun.type == EvaluationType.STT.value, + ] + + if dataset_id is not None: + where_clauses.append(EvaluationRun.dataset_id == dataset_id) + + if status is not None: + where_clauses.append(EvaluationRun.status == status) + + count_stmt = select(func.count(EvaluationRun.id)).where(*where_clauses) + total = session.exec(count_stmt).one() + + statement = ( + select(EvaluationRun) + .where(*where_clauses) + .order_by(EvaluationRun.inserted_at.desc()) + .offset(offset) + .limit(limit) + ) + + runs = session.exec(statement).all() + + result = [ + STTEvaluationRunPublic( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language_id=run.language_id, + providers=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + ) + for run in runs + ] + + return result, total + + +def update_stt_run( + *, + session: Session, + run_id: int, + status: str | None = None, + score: dict[str, Any] | None = None, + error_message: str | None = None, + object_store_url: str | None = None, + batch_job_id: int | None = None, +) -> EvaluationRun | None: + """Update an STT evaluation run. + + Args: + session: Database session + run_id: Run ID + status: New status + score: Score data + error_message: Error message + object_store_url: URL to stored results + batch_job_id: ID of the associated batch job + + Returns: + EvaluationRun | None: Updated run + """ + statement = select(EvaluationRun).where(EvaluationRun.id == run_id) + run = session.exec(statement).one_or_none() + + if not run: + return None + + updates = { + "status": status, + "score": score, + "error_message": error_message, + "object_store_url": object_store_url, + "batch_job_id": batch_job_id, + } + + for field, value in updates.items(): + if value is not None: + setattr(run, field, value) + + run.updated_at = now() + + session.add(run) + session.commit() + session.refresh(run) + + logger.info( + f"[update_stt_run] STT run updated | run_id: {run_id}, status: {run.status}" + ) + + return run + + +def get_pending_stt_runs( + *, + session: Session, + org_id: int, +) -> list[EvaluationRun]: + """Get all pending STT evaluation runs that are ready for polling. + + Only returns runs with status "processing" that have a batch_job_id set, + meaning the batch has been submitted and is ready to be polled. 
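+
+    Example (illustrative sketch of how the polling cron consumes this):
+
+        for run in get_pending_stt_runs(session=session, org_id=org_id):
+            # poll_stt_run lives in app.crud.stt_evaluations.cron
+            await poll_stt_run(session=session, run=run, org_id=org_id)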
+ + Args: + session: Database session + org_id: Organization ID + + Returns: + list[EvaluationRun]: Pending runs ready for polling + """ + statement = select(EvaluationRun).where( + EvaluationRun.type == EvaluationType.STT.value, + EvaluationRun.status == "processing", + EvaluationRun.batch_job_id.is_not(None), + EvaluationRun.organization_id == org_id, + ) + + return list(session.exec(statement).all()) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index ca51576bd..12e6fefcd 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -78,6 +78,8 @@ EvaluationRunPublic, ) +from .file import File, FilePublic, FileType + from .fine_tuning import ( FineTuningJobBase, Fine_Tuning, diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 6ae4542fb..e6c50b912 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -101,6 +101,19 @@ class EvaluationDataset(SQLModel, table=True): description="Optional description of the dataset", sa_column_kwargs={"comment": "Description of the dataset"}, ) + type: str = SQLField( + default="text", + max_length=20, + description="Evaluation type: text, stt, or tts", + sa_column_kwargs={"comment": "Evaluation type: text, stt, or tts"}, + ) + language_id: int | None = SQLField( + default=None, + foreign_key="global.languages.id", + nullable=True, + description="Reference to the language in the global languages table", + sa_column_kwargs={"comment": "Foreign key to global.languages table"}, + ) # Dataset metadata stored as JSONB dataset_metadata: dict[str, Any] = SQLField( @@ -193,6 +206,28 @@ class EvaluationRun(SQLModel, table=True): description="Name of the Langfuse dataset", sa_column_kwargs={"comment": "Name of the Langfuse dataset used"}, ) + type: str = SQLField( + default="text", + max_length=20, + description="Evaluation type: text, stt, or tts", + sa_column_kwargs={"comment": "Evaluation type: text, stt, or tts"}, + ) + language_id: int | None = SQLField( + default=None, + foreign_key="global.languages.id", + nullable=True, + description="Reference to the language in the global languages table", + sa_column_kwargs={"comment": "Foreign key to global.languages table"}, + ) + providers: list[str] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="List of STT/TTS providers used (e.g., ['gemini-2.5-pro'])", + ), + description="List of STT/TTS providers used", + ) config_id: UUID = SQLField( foreign_key="config.id", diff --git a/backend/app/models/file.py b/backend/app/models/file.py new file mode 100644 index 000000000..54b06600e --- /dev/null +++ b/backend/app/models/file.py @@ -0,0 +1,99 @@ +"""File model for storing uploaded files metadata.""" + +from datetime import datetime +from enum import Enum + +from pydantic import BaseModel +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel + +from app.core.util import now + + +class FileType(str, Enum): + """Type of file stored.""" + + AUDIO = "audio" + DOCUMENT = "document" + IMAGE = "image" + OTHER = "other" + + +class File(SQLModel, table=True): + """Database table for storing uploaded file metadata.""" + + __tablename__ = "file" + + id: int = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the file"}, + ) + + object_store_url: str = SQLField( + description="S3 URL of the file", + sa_column_kwargs={"comment": "S3 URL where the file is stored"}, + ) + + filename: str = SQLField( + 
max_length=255, + description="Original filename", + sa_column_kwargs={"comment": "Original filename as uploaded"}, + ) + + size_bytes: int = SQLField( + description="File size in bytes", + sa_column_kwargs={"comment": "File size in bytes"}, + ) + + content_type: str = SQLField( + max_length=100, + description="MIME type of the file", + sa_column_kwargs={"comment": "MIME type of the file (e.g., audio/mp3)"}, + ) + + file_type: str = SQLField( + default=FileType.OTHER.value, + max_length=20, + description="Type of file: audio, document, image, other", + sa_column_kwargs={"comment": "Type of file: audio, document, image, other"}, + ) + + organization_id: int = SQLField( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = SQLField( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the file was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the file was last updated"}, + ) + + +class FilePublic(BaseModel): + """Public model for file responses.""" + + id: int + object_store_url: str + filename: str + size_bytes: int + content_type: str + file_type: str + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py new file mode 100644 index 000000000..c43655b9c --- /dev/null +++ b/backend/app/models/stt_evaluation.py @@ -0,0 +1,378 @@ +"""STT Evaluation models for Speech-to-Text evaluation feature.""" + +from datetime import datetime +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field, field_validator +from sqlalchemy import Column, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel + +from app.core.util import now + +# Supported STT providers for evaluation +SUPPORTED_STT_PROVIDERS = ["gemini-2.5-pro"] + + +class EvaluationType(str, Enum): + """Type of evaluation dataset/run.""" + + TEXT = "text" + STT = "stt" + TTS = "tts" + + +class STTResultStatus(str, Enum): + """Status of an STT result.""" + + PENDING = "pending" + COMPLETED = "completed" + FAILED = "failed" + + +class STTSample(SQLModel, table=True): + """Database table for STT audio samples within a dataset.""" + + __tablename__ = "stt_sample" + + id: int = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the STT sample"}, + ) + + file_id: int = SQLField( + foreign_key="file.id", + nullable=False, + ondelete="CASCADE", + description="Reference to the uploaded audio file", + sa_column_kwargs={ + "comment": "Reference to the uploaded audio file in file table" + }, + ) + + language_id: int | None = SQLField( + default=None, + foreign_key="global.languages.id", + nullable=True, + description="Reference to the language in the global languages table", + sa_column_kwargs={"comment": "Foreign key to global.languages table"}, + ) + + ground_truth: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Reference transcription for comparison (optional)", + ), + description="Reference transcription for comparison", + ) + + 
sample_metadata: dict[str, Any] | None = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=True, + comment="Additional metadata (format, bitrate, original filename, etc.)", + ), + description="Additional metadata about the audio sample", + ) + + dataset_id: int = SQLField( + foreign_key="evaluation_dataset.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the parent evaluation dataset"}, + ) + organization_id: int = SQLField( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = SQLField( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the sample was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the sample was last updated"}, + ) + + +class STTResult(SQLModel, table=True): + """Database table for STT transcription results.""" + + __tablename__ = "stt_result" + + id: int = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the STT result"}, + ) + + transcription: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Generated transcription from STT provider", + ), + description="Generated transcription from STT provider", + ) + + provider: str = SQLField( + max_length=50, + description="STT provider used (e.g., gemini-2.5-pro)", + sa_column_kwargs={"comment": "STT provider used (e.g., gemini-2.5-pro)"}, + ) + + status: str = SQLField( + default=STTResultStatus.PENDING.value, + max_length=20, + description="Result status: pending, completed, failed", + sa_column_kwargs={"comment": "Result status: pending, completed, failed"}, + ) + + score: dict[str, Any] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="Evaluation metrics (e.g., wer, cer, mer, wil) - extensible for future metrics", + ), + description="Evaluation metrics such as WER, CER, etc.", + ) + + is_correct: bool | None = SQLField( + default=None, + description="Human feedback: transcription correctness", + sa_column_kwargs={ + "comment": "Human feedback: transcription correctness (null=not reviewed)" + }, + ) + comment: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Human feedback comment", + ), + description="Human feedback comment", + ) + + error_message: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Error message if transcription failed", + ), + description="Error message if transcription failed", + ) + + stt_sample_id: int = SQLField( + foreign_key="stt_sample.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the STT sample"}, + ) + evaluation_run_id: int = SQLField( + foreign_key="evaluation_run.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the evaluation run"}, + ) + organization_id: int = SQLField( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = SQLField( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + 
sa_column_kwargs={"comment": "Reference to the project"}, + ) + + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the result was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the result was last updated"}, + ) + + +class STTSampleCreate(BaseModel): + """Request model for creating an STT sample.""" + + file_id: int = Field(..., description="ID of the uploaded audio file") + ground_truth: str | None = Field( + None, description="Reference transcription (optional)" + ) + + +class STTSamplePublic(BaseModel): + """Public model for STT samples.""" + + id: int + file_id: int + object_store_url: str | None = None # Populated from file record when needed + language_id: int | None + ground_truth: str | None + sample_metadata: dict[str, Any] | None + dataset_id: int + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTResultPublic(BaseModel): + """Public model for STT results.""" + + id: int + transcription: str | None + provider: str + status: str + score: dict[str, Any] | None + is_correct: bool | None + comment: str | None + error_message: str | None + stt_sample_id: int + evaluation_run_id: int + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTResultWithSample(STTResultPublic): + """STT result with embedded sample data.""" + + sample: STTSamplePublic + + +class STTFeedbackUpdate(BaseModel): + """Request model for updating human feedback on a result.""" + + is_correct: bool | None = Field(None, description="Is the transcription correct?") + comment: str | None = Field(None, description="Feedback comment") + + +class STTDatasetCreate(BaseModel): + """Request model for creating an STT dataset.""" + + name: str = Field(..., description="Dataset name", min_length=1) + description: str | None = Field(None, description="Dataset description") + language_id: int | None = Field( + None, description="ID of the language from global languages table" + ) + samples: list[STTSampleCreate] = Field( + ..., description="List of audio samples", min_length=1 + ) + + +class STTDatasetPublic(BaseModel): + """Public model for STT datasets.""" + + id: int + name: str + description: str | None + type: str + language_id: int | None + object_store_url: str | None + dataset_metadata: dict[str, Any] + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTDatasetWithSamples(STTDatasetPublic): + """STT dataset with embedded samples.""" + + samples: list[STTSamplePublic] + + +class STTEvaluationRunCreate(BaseModel): + """Request model for starting an STT evaluation run.""" + + run_name: str = Field(..., description="Name for this evaluation run", min_length=1) + dataset_id: int = Field(..., description="ID of the STT dataset to evaluate") + providers: list[str] = Field( + default_factory=lambda: ["gemini-2.5-pro"], + description="List of STT providers to use", + min_length=1, + ) + language_id: int | None = Field( + None, description="Override language ID from global languages table" + ) + + @field_validator("providers") + @classmethod + def validate_providers(cls, v: list[str]) -> list[str]: + """Validate that all providers are supported.""" + if not v: + raise ValueError("At least one provider must be specified") + unsupported = [p for p in v if p not in SUPPORTED_STT_PROVIDERS] + if unsupported: + raise ValueError( + 
f"Unsupported provider(s): {', '.join(unsupported)}. " + f"Supported providers are: {', '.join(SUPPORTED_STT_PROVIDERS)}" + ) + return v + + +class STTEvaluationRunPublic(BaseModel): + """Public model for STT evaluation runs.""" + + id: int + run_name: str + dataset_name: str + type: str + language_id: int | None + providers: list[str] | None + dataset_id: int + status: str + total_items: int + score: dict[str, Any] | None + error_message: str | None + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTEvaluationRunWithResults(STTEvaluationRunPublic): + """STT evaluation run with embedded results.""" + + results: list[STTResultWithSample] + results_total: int = Field(0, description="Total number of results") + + +class AudioUploadResponse(BaseModel): + """Response model for audio file upload.""" + + file_id: int = Field(..., description="ID of the created file record") + s3_url: str = Field(..., description="S3 URL of the uploaded audio file") + filename: str = Field(..., description="Original filename") + size_bytes: int = Field(..., description="File size in bytes") + content_type: str = Field(..., description="MIME type of the audio file") diff --git a/backend/app/services/stt_evaluations/__init__.py b/backend/app/services/stt_evaluations/__init__.py new file mode 100644 index 000000000..0e22a1c2e --- /dev/null +++ b/backend/app/services/stt_evaluations/__init__.py @@ -0,0 +1 @@ +"""STT Evaluation services.""" diff --git a/backend/app/services/stt_evaluations/audio.py b/backend/app/services/stt_evaluations/audio.py new file mode 100644 index 000000000..e0ab135ee --- /dev/null +++ b/backend/app/services/stt_evaluations/audio.py @@ -0,0 +1,134 @@ +"""Audio file upload service for STT evaluation.""" + +import logging +import uuid +from pathlib import Path + +from fastapi import UploadFile +from sqlmodel import Session + +from app.core.cloud.storage import get_cloud_storage +from app.core.exception_handlers import HTTPException +from app.crud.file import create_file +from app.models.file import FileType +from app.models.stt_evaluation import AudioUploadResponse +from app.services.stt_evaluations.constants import ( + MAX_FILE_SIZE_BYTES, + MIME_TO_EXTENSION, + SUPPORTED_AUDIO_FORMATS, +) + +logger = logging.getLogger(__name__) + + +def _resolve_extension(file: UploadFile) -> str | None: + """Get audio file extension from filename, falling back to content type.""" + if file.filename and "." in file.filename: + return file.filename.rsplit(".", 1)[-1].lower() + if file.content_type: + return MIME_TO_EXTENSION.get(file.content_type.lower()) + return None + + +def _validate_audio_file(file: UploadFile) -> str: + """Validate an uploaded audio file and return its extension. + + Raises: + HTTPException: If file is invalid + """ + if not file.filename: + raise HTTPException(status_code=400, detail="Filename is required") + + extension = _resolve_extension(file) + + if not extension or extension not in SUPPORTED_AUDIO_FORMATS: + supported = ", ".join(sorted(SUPPORTED_AUDIO_FORMATS)) + raise HTTPException( + status_code=400, + detail=( + f"Unsupported audio format: {extension or 'unknown'}. " + f"Supported formats: {supported}" + ), + ) + + if file.size and file.size > MAX_FILE_SIZE_BYTES: + max_mb = MAX_FILE_SIZE_BYTES / (1024 * 1024) + file_mb = file.size / (1024 * 1024) + raise HTTPException( + status_code=400, + detail=f"File too large: {file_mb:.2f} MB. 
Maximum size: {max_mb:.0f} MB", + ) + + return extension + + +def upload_audio_file( + session: Session, + file: UploadFile, + organization_id: int, + project_id: int, +) -> AudioUploadResponse: + """Upload an audio file to S3 and create a file record. + + Raises: + HTTPException: If validation or upload fails + """ + logger.info( + f"[upload_audio_file] Starting audio upload | " + f"project_id: {project_id}, filename: {file.filename}" + ) + + extension = _validate_audio_file(file) + + file_uuid = uuid.uuid4() + new_filename = f"{file_uuid}.{extension}" + file_path = Path("stt") / "audio" / new_filename + + try: + storage = get_cloud_storage(session=session, project_id=project_id) + s3_url = str(storage.put(source=file, file_path=file_path)) + + try: + size_bytes = int(storage.get_file_size_kb(s3_url) * 1024) + except Exception: + size_bytes = file.size or 0 + + original_filename = file.filename or new_filename + content_type = file.content_type or f"audio/{extension}" + + file_record = create_file( + session=session, + object_store_url=s3_url, + filename=original_filename, + size_bytes=size_bytes, + content_type=content_type, + file_type=FileType.AUDIO.value, + organization_id=organization_id, + project_id=project_id, + ) + + logger.info( + f"[upload_audio_file] Audio uploaded successfully | " + f"project_id: {project_id}, file_id: {file_record.id}, " + f"s3_url: {s3_url}, size_bytes: {size_bytes}" + ) + + return AudioUploadResponse( + file_id=file_record.id, + s3_url=s3_url, + filename=original_filename, + size_bytes=size_bytes, + content_type=content_type, + ) + + except HTTPException: + raise + except Exception as e: + logger.error( + f"[upload_audio_file] Failed to upload audio | " + f"project_id: {project_id}, error: {str(e)}" + ) + raise HTTPException( + status_code=500, + detail="Failed to upload audio file. 
Please try again later.", + ) diff --git a/backend/app/services/stt_evaluations/constants.py b/backend/app/services/stt_evaluations/constants.py new file mode 100644 index 000000000..41da0b5e7 --- /dev/null +++ b/backend/app/services/stt_evaluations/constants.py @@ -0,0 +1,32 @@ +"""Shared constants for STT evaluation services.""" + +# Supported audio formats for STT evaluation +SUPPORTED_AUDIO_FORMATS: set[str] = {"mp3", "wav", "flac", "m4a", "ogg", "webm"} + +# Maximum audio file size (200 MB) +MAX_FILE_SIZE_BYTES: int = 200 * 1024 * 1024 + +# Mapping from file extension to MIME type +EXTENSION_TO_MIME: dict[str, str] = { + "mp3": "audio/mp3", + "wav": "audio/wav", + "flac": "audio/flac", + "m4a": "audio/mp4", + "ogg": "audio/ogg", + "webm": "audio/webm", +} + +# Mapping from MIME type to file extension +MIME_TO_EXTENSION: dict[str, str] = { + "audio/mp3": "mp3", + "audio/mpeg": "mp3", + "audio/wav": "wav", + "audio/x-wav": "wav", + "audio/wave": "wav", + "audio/flac": "flac", + "audio/x-flac": "flac", + "audio/mp4": "m4a", + "audio/x-m4a": "m4a", + "audio/ogg": "ogg", + "audio/webm": "webm", +} diff --git a/backend/app/services/stt_evaluations/dataset.py b/backend/app/services/stt_evaluations/dataset.py new file mode 100644 index 000000000..f4c4ae86a --- /dev/null +++ b/backend/app/services/stt_evaluations/dataset.py @@ -0,0 +1,175 @@ +"""Dataset management service for STT evaluations.""" + +import csv +import io +import logging + +from sqlmodel import Session + +from app.core.cloud import get_cloud_storage +from app.core.storage_utils import ( + generate_timestamped_filename, + upload_csv_to_object_store, +) +from app.crud.stt_evaluations import ( + create_stt_dataset, + create_stt_samples, +) +from app.models import EvaluationDataset +from app.models.stt_evaluation import STTSample, STTSampleCreate + +logger = logging.getLogger(__name__) + + +def upload_stt_dataset( + session: Session, + name: str, + samples: list[STTSampleCreate], + organization_id: int, + project_id: int, + description: str | None = None, + language_id: int | None = None, +) -> tuple[EvaluationDataset, list[STTSample]]: + """ + Orchestrate STT dataset upload workflow. + + Steps: + 1. Convert samples to CSV format + 2. Upload CSV to object store + 3. Create dataset record in database + 4. 
Create sample records in database + + Args: + session: Database session + name: Dataset name + samples: List of STT samples to create + organization_id: Organization ID + project_id: Project ID + description: Optional dataset description + language_id: Optional reference to global.languages table + + Returns: + Tuple of (created dataset, created samples) + """ + logger.info( + f"[upload_stt_dataset] Uploading STT dataset | name={name} | " + f"sample_count={len(samples)} | org_id={organization_id} | " + f"project_id={project_id}" + ) + + # Step 1: Convert samples to CSV and upload to object store + object_store_url = _upload_samples_to_object_store( + session=session, + project_id=project_id, + dataset_name=name, + samples=samples, + ) + + # Step 2: Calculate metadata + metadata = { + "sample_count": len(samples), + "has_ground_truth_count": sum(1 for s in samples if s.ground_truth), + } + + # Step 3: Create dataset record + dataset = create_stt_dataset( + session=session, + name=name, + org_id=organization_id, + project_id=project_id, + description=description, + language_id=language_id, + object_store_url=object_store_url, + dataset_metadata=metadata, + ) + + logger.info( + f"[upload_stt_dataset] Created dataset record | " + f"id={dataset.id} | name={name}" + ) + + # Step 4: Create sample records + created_samples = create_stt_samples( + session=session, + dataset=dataset, + samples=samples, + ) + + logger.info( + f"[upload_stt_dataset] Created sample records | " + f"dataset_id={dataset.id} | sample_count={len(created_samples)}" + ) + + return dataset, created_samples + + +def _upload_samples_to_object_store( + session: Session, + project_id: int, + dataset_name: str, + samples: list[STTSampleCreate], +) -> str | None: + """ + Upload STT samples as CSV to object store. + + Args: + session: Database session + project_id: Project ID for storage credentials + dataset_name: Dataset name for filename + samples: List of samples to upload + + Returns: + Object store URL if successful, None otherwise + """ + try: + storage = get_cloud_storage(session=session, project_id=project_id) + + # Convert samples to CSV format + csv_content = _samples_to_csv(samples) + + # Generate filename and upload + filename = generate_timestamped_filename(dataset_name, "csv") + object_store_url = upload_csv_to_object_store( + storage=storage, + csv_content=csv_content, + filename=filename, + subdirectory="stt_datasets", + ) + + if object_store_url: + logger.info( + f"[_upload_samples_to_object_store] Upload successful | " + f"url={object_store_url}" + ) + else: + logger.info( + "[_upload_samples_to_object_store] Upload returned None | " + "continuing without object store storage" + ) + + return object_store_url + + except Exception as e: + logger.warning( + f"[_upload_samples_to_object_store] Failed to upload | {e}", + exc_info=True, + ) + return None + + +def _samples_to_csv(samples: list[STTSampleCreate]) -> bytes: + """ + Convert STT samples to CSV format. 
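+
+    For illustration (hypothetical values), a single sample with
+    file_id=7 and ground_truth="hello" serializes to:
+
+        file_id,ground_truth
+        7,hello
+
+    (csv.writer terminates rows with \r\n by default.)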
+ + Args: + samples: List of samples + + Returns: + CSV content as bytes + """ + output = io.StringIO() + writer = csv.writer(output) + writer.writerow(["file_id", "ground_truth"]) + for sample in samples: + writer.writerow([sample.file_id, sample.ground_truth or ""]) + return output.getvalue().encode("utf-8") diff --git a/backend/app/services/stt_evaluations/gemini/__init__.py b/backend/app/services/stt_evaluations/gemini/__init__.py new file mode 100644 index 000000000..07f7037cf --- /dev/null +++ b/backend/app/services/stt_evaluations/gemini/__init__.py @@ -0,0 +1,5 @@ +"""Gemini integration for STT evaluation.""" + +from .client import GeminiClient + +__all__ = ["GeminiClient"] diff --git a/backend/app/services/stt_evaluations/gemini/client.py b/backend/app/services/stt_evaluations/gemini/client.py new file mode 100644 index 000000000..1ee52d66b --- /dev/null +++ b/backend/app/services/stt_evaluations/gemini/client.py @@ -0,0 +1,113 @@ +"""Gemini client wrapper for STT evaluation.""" + +import logging +from typing import Any + +from google import genai +from sqlmodel import Session + +from app.core.exception_handlers import HTTPException +from app.crud.credentials import get_provider_credential + +logger = logging.getLogger(__name__) + + +class GeminiClientError(Exception): + """Exception raised for Gemini client errors.""" + + pass + + +class GeminiClient: + """Wrapper for Google GenAI client with credential management.""" + + def __init__(self, api_key: str) -> None: + """Initialize Gemini client with API key. + + Args: + api_key: Google AI API key + """ + self._api_key = api_key + self._client = genai.Client(api_key=api_key) + + @property + def client(self) -> genai.Client: + """Get the underlying GenAI client.""" + return self._client + + @classmethod + def from_credentials( + cls, + session: Session, + org_id: int, + project_id: int, + ) -> "GeminiClient": + """Create client from stored credentials. + + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + + Returns: + GeminiClient: Configured Gemini client + + Raises: + HTTPException: If credentials not found + GeminiClientError: If credentials are invalid + """ + logger.info( + f"[from_credentials] Fetching Gemini credentials | " + f"org_id: {org_id}, project_id: {project_id}" + ) + + credentials = get_provider_credential( + session=session, + org_id=org_id, + project_id=project_id, + provider="gemini", + ) + + if not credentials: + logger.error( + f"[from_credentials] Gemini credentials not found | " + f"org_id: {org_id}, project_id: {project_id}" + ) + raise HTTPException( + status_code=404, + detail="Gemini credentials not configured for this project", + ) + + api_key = credentials.get("api_key") + if not api_key: + logger.error( + f"[from_credentials] Invalid Gemini credentials (missing api_key) | " + f"org_id: {org_id}, project_id: {project_id}" + ) + raise GeminiClientError("Invalid Gemini credentials: missing api_key") + + logger.info( + f"[from_credentials] Gemini client created successfully | " + f"org_id: {org_id}, project_id: {project_id}" + ) + return cls(api_key=api_key) + + def validate_connection(self) -> bool: + """Validate that the client can connect to Gemini. 
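+
+    Illustrative usage (a sketch, not taken from this codebase's call sites):
+
+        client = GeminiClient.from_credentials(session, org_id=1, project_id=1)
+        if not client.validate_connection():
+            raise GeminiClientError("Unable to reach the Gemini API")
+
+    Note: this issues a live models.list() call, so it can be slow or can
+    fail without network access.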
+ + Returns: + bool: True if connection is valid + """ + try: + # List models to verify connection + models = list(self._client.models.list()) + logger.info( + f"[validate_connection] Connection validated | " + f"available_models_count: {len(models)}" + ) + return True + except Exception as e: + logger.error( + f"[validate_connection] Connection validation failed | error: {str(e)}" + ) + return False diff --git a/backend/app/tests/api/routes/test_stt_evaluation.py b/backend/app/tests/api/routes/test_stt_evaluation.py new file mode 100644 index 000000000..f751b8d4e --- /dev/null +++ b/backend/app/tests/api/routes/test_stt_evaluation.py @@ -0,0 +1,642 @@ +"""Tests for STT evaluation API routes.""" + +import pytest +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.models import EvaluationDataset, File, FileType +from app.models.stt_evaluation import STTSample, EvaluationType +from app.crud.language import get_language_by_locale +from app.tests.utils.auth import TestAuthContext +from app.core.util import now + + +# Helper functions +def create_test_file( + db: Session, + organization_id: int, + project_id: int, + object_store_url: str = "s3://test-bucket/audio/test.mp3", + filename: str = "test.mp3", + size_bytes: int = 1024, + content_type: str = "audio/mpeg", + file_type: str = FileType.AUDIO.value, +) -> File: + """Create a test file record.""" + file = File( + object_store_url=object_store_url, + filename=filename, + size_bytes=size_bytes, + content_type=content_type, + file_type=file_type, + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(file) + db.commit() + db.refresh(file) + return file + + +def create_test_stt_dataset( + db: Session, + organization_id: int, + project_id: int, + name: str = "test_stt_dataset", + description: str | None = None, + language_id: int | None = None, +) -> EvaluationDataset: + """Create a test STT dataset.""" + dataset = EvaluationDataset( + name=name, + description=description, + type=EvaluationType.STT.value, + language_id=language_id, + dataset_metadata={"sample_count": 0, "has_ground_truth_count": 0}, + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(dataset) + db.commit() + db.refresh(dataset) + return dataset + + +def create_test_stt_sample( + db: Session, + dataset_id: int, + organization_id: int, + project_id: int, + file_id: int | None = None, + ground_truth: str | None = None, +) -> STTSample: + """Create a test STT sample.""" + # If no file_id provided, create a test file first + if file_id is None: + file = create_test_file( + db=db, + organization_id=organization_id, + project_id=project_id, + ) + file_id = file.id + + sample = STTSample( + file_id=file_id, + ground_truth=ground_truth, + dataset_id=dataset_id, + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(sample) + db.commit() + db.refresh(sample) + return sample + + +class TestSTTDatasetCreate: + """Test POST /evaluations/stt/datasets endpoint.""" + + def test_create_stt_dataset_success( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test creating an STT dataset with samples.""" + # Get seeded English language + language = get_language_by_locale(session=db, locale="en") + file1 = create_test_file( + db=db, + organization_id=user_api_key.organization_id, + 
project_id=user_api_key.project_id, + object_store_url="s3://bucket/audio1.mp3", + filename="audio1.mp3", + ) + file2 = create_test_file( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + object_store_url="s3://bucket/audio2.mp3", + filename="audio2.mp3", + ) + + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "test_stt_dataset_create", + "description": "Test STT dataset", + "language_id": language.id, + "samples": [ + {"file_id": file1.id}, + { + "file_id": file2.id, + "ground_truth": "Hello world", + }, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + + assert data["name"] == "test_stt_dataset_create" + assert data["description"] == "Test STT dataset" + assert data["type"] == "stt" + assert data["language_id"] == language.id + assert data["dataset_metadata"]["sample_count"] == 2 + assert data["dataset_metadata"]["has_ground_truth_count"] == 1 + + def test_create_stt_dataset_minimal( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test creating an STT dataset with minimal fields.""" + # Create a test file first + file = create_test_file( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + object_store_url="s3://bucket/audio.mp3", + filename="audio.mp3", + ) + + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "minimal_stt_dataset", + "samples": [ + {"file_id": file.id}, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + + assert data["name"] == "minimal_stt_dataset" + assert data["description"] is None + assert data["language_id"] is None + assert data["dataset_metadata"]["sample_count"] == 1 + + def test_create_stt_dataset_empty_samples( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test creating an STT dataset with empty samples fails.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "empty_samples_dataset", + "samples": [], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 422 + + def test_create_stt_dataset_missing_name( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test creating an STT dataset without name fails.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "samples": [ + {"file_id": 1}, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 422 + + def test_create_stt_dataset_without_authentication( + self, + client: TestClient, + ) -> None: + """Test creating an STT dataset without authentication fails.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "unauthenticated_dataset", + "samples": [ + {"file_id": 1}, + ], + }, + ) + + assert response.status_code == 401 + + def test_create_stt_dataset_duplicate_name( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test creating an STT dataset with duplicate name fails.""" + # Create a test file first + file = create_test_file( + db=db, + 
organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + # Create first dataset + create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="duplicate_name_test", + ) + + # Try to create another with same name + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "duplicate_name_test", + "samples": [ + {"file_id": file.id}, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 400 + response_data = response.json() + error_str = response_data.get("detail", response_data.get("error", "")) + assert "already exists" in error_str.lower() + + def test_create_stt_dataset_invalid_file_id( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test creating an STT dataset with invalid file_id fails.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "invalid_file_dataset", + "samples": [ + {"file_id": 99999}, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 400 + response_data = response.json() + error_str = response_data.get("detail", response_data.get("error", "")) + assert "not found" in error_str.lower() + + +class TestSTTDatasetList: + """Test GET /evaluations/stt/datasets endpoint.""" + + def test_list_stt_datasets_empty( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test listing STT datasets when none exist.""" + response = client.get( + "/api/v1/evaluations/stt/datasets", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert isinstance(response_data["data"], list) + + def test_list_stt_datasets_with_data( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test listing STT datasets with data.""" + # Create test datasets + dataset1 = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="list_test_dataset_1", + ) + create_test_stt_sample( + db=db, + dataset_id=dataset1.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + dataset2 = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="list_test_dataset_2", + ) + create_test_stt_sample( + db=db, + dataset_id=dataset2.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + response = client.get( + "/api/v1/evaluations/stt/datasets", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + assert len(data) >= 2 + + # Check that our datasets are in the list + names = [d["name"] for d in data] + assert "list_test_dataset_1" in names + assert "list_test_dataset_2" in names + + def test_list_stt_datasets_pagination( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test pagination for listing STT datasets.""" + # Create multiple datasets + for i in range(5): + create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name=f"pagination_test_dataset_{i}", + ) + + # Test with 
limit + response = client.get( + "/api/v1/evaluations/stt/datasets", + params={"limit": 2, "offset": 0}, + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["data"]) == 2 + assert response_data["metadata"]["limit"] == 2 + assert response_data["metadata"]["offset"] == 0 + + +class TestSTTDatasetGet: + """Test GET /evaluations/stt/datasets/{dataset_id} endpoint.""" + + def test_get_stt_dataset_success( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test getting an STT dataset by ID.""" + dataset = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="get_test_dataset", + description="Test description", + ) + sample = create_test_stt_sample( + db=db, + dataset_id=dataset.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ground_truth="Test transcription", + ) + + response = client.get( + f"/api/v1/evaluations/stt/datasets/{dataset.id}", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + + assert data["id"] == dataset.id + assert data["name"] == "get_test_dataset" + assert data["description"] == "Test description" + assert data["type"] == "stt" + assert len(data["samples"]) == 1 + assert data["samples"][0]["id"] == sample.id + assert data["samples"][0]["ground_truth"] == "Test transcription" + + def test_get_stt_dataset_not_found( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test getting a non-existent STT dataset.""" + response = client.get( + "/api/v1/evaluations/stt/datasets/99999", + headers=user_api_key_header, + ) + + assert response.status_code == 404 + + def test_get_stt_dataset_without_samples( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test getting an STT dataset without including samples.""" + # Create dataset with sample_count in metadata set correctly + dataset = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="get_no_samples_dataset", + ) + create_test_stt_sample( + db=db, + dataset_id=dataset.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + # Update dataset metadata to reflect the sample count + dataset.dataset_metadata = {"sample_count": 1, "has_ground_truth_count": 0} + db.add(dataset) + db.commit() + + response = client.get( + f"/api/v1/evaluations/stt/datasets/{dataset.id}", + params={"include_samples": False}, + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + data = response_data["data"] + + assert data["id"] == dataset.id + assert data["samples"] == [] + assert data["dataset_metadata"]["sample_count"] == 1 + + +class TestSTTEvaluationRun: + """Test STT evaluation run endpoints.""" + + @pytest.fixture + def test_dataset_with_samples( + self, db: Session, user_api_key: TestAuthContext + ) -> EvaluationDataset: + """Create a test dataset with samples for evaluation.""" + dataset = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="eval_test_dataset", + ) + # Create some 
samples, each linked to an explicitly created file record
+        for i in range(3):
+            file = create_test_file(
+                db=db,
+                organization_id=user_api_key.organization_id,
+                project_id=user_api_key.project_id,
+                object_store_url=f"s3://bucket/audio_{i}.mp3",
+                filename=f"audio_{i}.mp3",
+            )
+            create_test_stt_sample(
+                db=db,
+                dataset_id=dataset.id,
+                organization_id=user_api_key.organization_id,
+                project_id=user_api_key.project_id,
+                file_id=file.id,
+            )
+        return dataset
+
+    def test_start_stt_evaluation_invalid_dataset(
+        self,
+        client: TestClient,
+        user_api_key_header: dict[str, str],
+    ) -> None:
+        """Test starting an STT evaluation with invalid dataset ID."""
+        response = client.post(
+            "/api/v1/evaluations/stt/runs",
+            json={
+                "run_name": "test_run",
+                "dataset_id": 99999,
+                "providers": ["gemini-2.5-pro"],
+            },
+            headers=user_api_key_header,
+        )
+
+        assert response.status_code == 404
+        response_data = response.json()
+        error_str = response_data.get("detail", response_data.get("error", ""))
+        assert "not found" in error_str.lower()
+
+    def test_start_stt_evaluation_without_authentication(
+        self,
+        client: TestClient,
+    ) -> None:
+        """Test starting an STT evaluation without authentication."""
+        response = client.post(
+            "/api/v1/evaluations/stt/runs",
+            json={
+                "run_name": "test_run",
+                "dataset_id": 1,
+                "providers": ["gemini-2.5-pro"],
+            },
+        )
+
+        assert response.status_code == 401
+
+    def test_list_stt_runs_empty(
+        self,
+        client: TestClient,
+        user_api_key_header: dict[str, str],
+    ) -> None:
+        """Test listing STT runs when none exist."""
+        response = client.get(
+            "/api/v1/evaluations/stt/runs",
+            headers=user_api_key_header,
+        )
+
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["success"] is True
+        assert isinstance(response_data["data"], list)
+
+    def test_get_stt_run_not_found(
+        self,
+        client: TestClient,
+        user_api_key_header: dict[str, str],
+    ) -> None:
+        """Test getting a non-existent STT run."""
+        response = client.get(
+            "/api/v1/evaluations/stt/runs/99999",
+            headers=user_api_key_header,
+        )
+
+        assert response.status_code == 404
+
+
+class TestSTTResultFeedback:
+    """Test STT result feedback endpoint."""
+
+    def test_update_feedback_not_found(
+        self,
+        client: TestClient,
+        user_api_key_header: dict[str, str],
+    ) -> None:
+        """Test updating feedback for non-existent result."""
+        response = client.patch(
+            "/api/v1/evaluations/stt/results/99999",
+            json={
+                "is_correct": True,
+                "comment": "Test comment",
+            },
+            headers=user_api_key_header,
+        )
+
+        assert response.status_code == 404
+
+    def test_update_feedback_without_authentication(
+        self,
+        client: TestClient,
+    ) -> None:
+        """Test updating feedback without authentication."""
+        response = client.patch(
+            "/api/v1/evaluations/stt/results/1",
+            json={
+                "is_correct": True,
+            },
+        )
+
+        assert response.status_code == 401
diff --git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py
new file mode 100644
index 000000000..fc62f7b4c
--- /dev/null
+++ b/backend/app/tests/core/batch/test_gemini.py
@@ -0,0 +1,616 @@
+"""Test cases for GeminiBatchProvider."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from app.core.batch.gemini import (
+    BatchJobState,
+    GeminiBatchProvider,
+    create_stt_batch_requests,
+)
+
+
+class TestGeminiBatchProvider:
+    """Test cases for GeminiBatchProvider."""
+
+    @pytest.fixture
+    def mock_genai_client(self):
+        """Create a mock Gemini client."""
+        return MagicMock()
+
+    @pytest.fixture
+    
def provider(self, mock_genai_client): + """Create a GeminiBatchProvider instance with mock client.""" + return GeminiBatchProvider(client=mock_genai_client) + + @pytest.fixture + def provider_with_model(self, mock_genai_client): + """Create a GeminiBatchProvider with custom model.""" + return GeminiBatchProvider( + client=mock_genai_client, model="models/gemini-2.5-flash" + ) + + def test_initialization_default_model(self, mock_genai_client): + """Test that provider initializes with default model.""" + provider = GeminiBatchProvider(client=mock_genai_client) + assert provider._client == mock_genai_client + assert provider._model == "models/gemini-2.5-pro" + + def test_initialization_custom_model(self, mock_genai_client): + """Test that provider initializes with custom model.""" + provider = GeminiBatchProvider( + client=mock_genai_client, model="models/gemini-2.5-flash" + ) + assert provider._model == "models/gemini-2.5-flash" + + def test_create_batch_success(self, provider, mock_genai_client): + """Test successful batch creation.""" + jsonl_data = [ + {"key": "req-1", "request": {"contents": [{"parts": [{"text": "test"}]}]}}, + {"key": "req-2", "request": {"contents": [{"parts": [{"text": "test2"}]}]}}, + ] + config = { + "display_name": "test-batch", + "model": "models/gemini-2.5-pro", + } + + # Mock file upload + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-123" + mock_genai_client.files.upload.return_value = mock_uploaded_file + + # Mock batch creation + mock_batch_job = MagicMock() + mock_batch_job.name = "batches/batch-xyz789" + mock_batch_job.state.name = "JOB_STATE_PENDING" + mock_genai_client.batches.create.return_value = mock_batch_job + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + result = provider.create_batch(jsonl_data, config) + + assert result["provider_batch_id"] == "batches/batch-xyz789" + assert result["provider_file_id"] == "files/uploaded-123" + assert result["provider_status"] == "JOB_STATE_PENDING" + assert result["total_items"] == 2 + + def test_create_batch_with_default_config(self, provider, mock_genai_client): + """Test batch creation with default configuration values.""" + jsonl_data = [{"key": "req-1", "request": {}}] + config = {} + + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-456" + mock_genai_client.files.upload.return_value = mock_uploaded_file + + mock_batch_job = MagicMock() + mock_batch_job.name = "batches/batch-123" + mock_batch_job.state.name = "JOB_STATE_PENDING" + mock_genai_client.batches.create.return_value = mock_batch_job + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + result = provider.create_batch(jsonl_data, config) + + assert result["total_items"] == 1 + mock_genai_client.batches.create.assert_called_once() + + def test_create_batch_file_upload_error(self, provider, mock_genai_client): + """Test handling of file upload error during batch creation.""" + jsonl_data = [{"key": "req-1", "request": {}}] + config = {"display_name": "test"} + + mock_genai_client.files.upload.side_effect = Exception("File upload failed") + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + with pytest.raises(Exception) as exc_info: + provider.create_batch(jsonl_data, config) + + assert "File upload failed" in str(exc_info.value) + + def test_create_batch_batch_creation_error(self, provider, mock_genai_client): + """Test handling of batch creation error.""" + jsonl_data = [{"key": "req-1", "request": {}}] + config = 
{"display_name": "test"} + + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-123" + mock_genai_client.files.upload.return_value = mock_uploaded_file + mock_genai_client.batches.create.side_effect = Exception( + "Batch creation failed" + ) + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + with pytest.raises(Exception) as exc_info: + provider.create_batch(jsonl_data, config) + + assert "Batch creation failed" in str(exc_info.value) + + def test_get_batch_status_pending(self, provider, mock_genai_client): + """Test getting status of a pending batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_PENDING" + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + mock_genai_client.batches.get.assert_called_once_with(name=batch_id) + assert result["provider_status"] == "JOB_STATE_PENDING" + assert result["provider_output_file_id"] == batch_id + assert "error_message" not in result + + def test_get_batch_status_succeeded(self, provider, mock_genai_client): + """Test getting status of a succeeded batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_SUCCEEDED" + assert "error_message" not in result + + def test_get_batch_status_failed(self, provider, mock_genai_client): + """Test getting status of a failed batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_FAILED" + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_FAILED" + assert "error_message" in result + assert "Batch JOB_STATE_FAILED" in result["error_message"] + + def test_get_batch_status_cancelled(self, provider, mock_genai_client): + """Test getting status of a cancelled batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_CANCELLED" + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_CANCELLED" + assert "error_message" in result + + def test_get_batch_status_expired(self, provider, mock_genai_client): + """Test getting status of an expired batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_EXPIRED" + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_EXPIRED" + assert "error_message" in result + + def test_get_batch_status_error(self, provider, mock_genai_client): + """Test handling of error when retrieving batch status.""" + batch_id = "batches/batch-xyz789" + + mock_genai_client.batches.get.side_effect = Exception("API connection failed") + + with pytest.raises(Exception) as exc_info: + provider.get_batch_status(batch_id) + + assert "API connection failed" in str(exc_info.value) + + def test_download_batch_results_success(self, provider, mock_genai_client): + """Test successful download of batch results.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name 
= "JOB_STATE_SUCCEEDED" + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = ( + '{"key":"req-1","response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}\n' + '{"key":"req-2","response":{"candidates":[{"content":{"parts":[{"text":"World"}]}}]}}' + ) + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + mock_genai_client.batches.get.assert_called_once_with(name=batch_id) + mock_genai_client.files.download.assert_called_once_with( + file="files/output-123" + ) + assert len(results) == 2 + assert results[0]["custom_id"] == "req-1" + assert results[0]["response"]["text"] == "Hello" + assert results[1]["custom_id"] == "req-2" + assert results[1]["response"]["text"] == "World" + + def test_download_batch_results_with_direct_text_response( + self, provider, mock_genai_client + ): + """Test downloading results with direct text in response.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = '{"key":"req-1","response":{"text":"Direct text"}}' + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + assert len(results) == 1 + assert results[0]["response"]["text"] == "Direct text" + + def test_download_batch_results_with_errors(self, provider, mock_genai_client): + """Test downloading batch results that contain errors.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = ( + '{"key":"req-1","response":{"text":"Success"}}\n' + '{"key":"req-2","error":{"message":"Invalid request"}}' + ) + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + assert len(results) == 2 + assert results[0]["custom_id"] == "req-1" + assert results[0]["response"] is not None + assert results[0]["error"] is None + assert results[1]["custom_id"] == "req-2" + assert results[1]["response"] is None + assert "Invalid request" in results[1]["error"] + + def test_download_batch_results_batch_not_complete( + self, provider, mock_genai_client + ): + """Test error when trying to download results from incomplete batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_RUNNING" + mock_genai_client.batches.get.return_value = mock_batch_job + + with pytest.raises(ValueError) as exc_info: + provider.download_batch_results(batch_id) + + assert "Batch job not complete" in str(exc_info.value) + + def test_download_batch_results_malformed_json(self, provider, mock_genai_client): + """Test handling of malformed JSON in batch results.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = ( + 
'{"key":"req-1","response":{"text":"Valid"}}\n' + "this is not valid json\n" + '{"key":"req-3","response":{"text":"Also valid"}}' + ) + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + # Should skip the malformed line and process the rest + assert len(results) == 2 + assert results[0]["custom_id"] == "req-1" + assert results[1]["custom_id"] == "req-3" + + def test_download_batch_results_no_dest_file(self, provider, mock_genai_client): + """Test handling when batch has no destination file.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" + mock_batch_job.dest = None + mock_genai_client.batches.get.return_value = mock_batch_job + + results = provider.download_batch_results(batch_id) + + assert len(results) == 0 + + def test_upload_file_success(self, provider, mock_genai_client): + """Test successful file upload.""" + content = '{"key":"req-1","request":{}}\n{"key":"req-2","request":{}}' + + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-abc123" + mock_genai_client.files.upload.return_value = mock_uploaded_file + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_temp_file = MagicMock() + mock_temp_file.name = "/tmp/test.jsonl" + mock_temp.return_value.__enter__.return_value = mock_temp_file + + with patch("os.unlink"): + file_name = provider.upload_file(content, purpose="batch") + + assert file_name == "files/uploaded-abc123" + mock_genai_client.files.upload.assert_called_once() + + def test_upload_file_error(self, provider, mock_genai_client): + """Test handling of error during file upload.""" + content = '{"key":"req-1"}' + + mock_genai_client.files.upload.side_effect = Exception("Upload quota exceeded") + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + with pytest.raises(Exception) as exc_info: + provider.upload_file(content) + + assert "Upload quota exceeded" in str(exc_info.value) + + def test_download_file_success(self, provider, mock_genai_client): + """Test successful file download.""" + file_id = "files/abc123" + expected_content = '{"key":"req-1","response":{"text":"test"}}' + + mock_genai_client.files.download.return_value = expected_content.encode("utf-8") + + content = provider.download_file(file_id) + + mock_genai_client.files.download.assert_called_once_with(file=file_id) + assert content == expected_content + + def test_download_file_unicode_content(self, provider, mock_genai_client): + """Test downloading file with unicode content.""" + file_id = "files/abc123" + expected_content = '{"text":"Hello δΈ–η•Œ 🌍"}' + + mock_genai_client.files.download.return_value = expected_content.encode("utf-8") + + content = provider.download_file(file_id) + + assert content == expected_content + assert "δΈ–η•Œ" in content + assert "🌍" in content + + def test_download_file_error(self, provider, mock_genai_client): + """Test handling of error during file download.""" + file_id = "files/abc123" + + mock_genai_client.files.download.side_effect = Exception("File not found") + + with pytest.raises(Exception) as exc_info: + provider.download_file(file_id) + + assert "File not found" in str(exc_info.value) + + +class TestBatchJobState: + """Test cases for BatchJobState enum.""" + + def test_batch_job_states(self): + """Test that all batch job states have correct values.""" + assert BatchJobState.PENDING.value == "JOB_STATE_PENDING" + assert BatchJobState.RUNNING.value == 
"JOB_STATE_RUNNING" + assert BatchJobState.SUCCEEDED.value == "JOB_STATE_SUCCEEDED" + assert BatchJobState.FAILED.value == "JOB_STATE_FAILED" + assert BatchJobState.CANCELLED.value == "JOB_STATE_CANCELLED" + assert BatchJobState.EXPIRED.value == "JOB_STATE_EXPIRED" + + +class TestCreateSTTBatchRequests: + """Test cases for create_stt_batch_requests function.""" + + def test_create_requests_with_keys(self): + """Test creating batch requests with custom keys.""" + signed_urls = [ + "https://bucket.s3.amazonaws.com/audio1.mp3?signature=abc", + "https://bucket.s3.amazonaws.com/audio2.wav?signature=def", + ] + prompt = "Transcribe this audio file." + keys = ["sample-1", "sample-2"] + + requests = create_stt_batch_requests(signed_urls, prompt, keys=keys) + + assert len(requests) == 2 + assert requests[0]["key"] == "sample-1" + assert requests[1]["key"] == "sample-2" + + # Verify structure + assert "request" in requests[0] + assert "contents" in requests[0]["request"] + contents = requests[0]["request"]["contents"] + assert len(contents) == 1 + assert contents[0]["role"] == "user" + assert len(contents[0]["parts"]) == 2 + assert contents[0]["parts"][0]["text"] == prompt + assert "file_data" in contents[0]["parts"][1] + + def test_create_requests_without_keys(self): + """Test creating batch requests without keys (auto-generated).""" + signed_urls = [ + "https://bucket.s3.amazonaws.com/audio.mp3?signature=xyz", + ] + prompt = "Transcribe." + + requests = create_stt_batch_requests(signed_urls, prompt) + + assert len(requests) == 1 + assert requests[0]["key"] == "0" + + def test_create_requests_mime_type_detection(self): + """Test that MIME types are correctly detected from URLs.""" + signed_urls = [ + "https://bucket.s3.amazonaws.com/audio.mp3?sig=1", + "https://bucket.s3.amazonaws.com/audio.wav?sig=2", + "https://bucket.s3.amazonaws.com/audio.m4a?sig=3", + ] + prompt = "Transcribe." + + requests = create_stt_batch_requests(signed_urls, prompt) + + assert ( + requests[0]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] + == "audio/mpeg" + ) + assert ( + requests[1]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] + == "audio/x-wav" + ) + # .m4a can return different MIME types depending on the system + m4a_mime = requests[2]["request"]["contents"][0]["parts"][1]["file_data"][ + "mime_type" + ] + assert m4a_mime in ("audio/mp4", "audio/mp4a-latm", "audio/x-m4a") + + def test_create_requests_key_length_mismatch(self): + """Test that mismatched keys and URLs raise error.""" + signed_urls = [ + "https://example.com/audio1.mp3", + "https://example.com/audio2.mp3", + ] + keys = ["only-one-key"] + prompt = "Transcribe." + + with pytest.raises(ValueError) as exc_info: + create_stt_batch_requests(signed_urls, prompt, keys=keys) + + assert "Length of keys" in str(exc_info.value) + + def test_create_requests_file_uri_preserved(self): + """Test that signed URLs are preserved in file_uri.""" + signed_url = "https://bucket.s3.amazonaws.com/audio.mp3?X-Amz-Signature=abc123&X-Amz-Expires=3600" + prompt = "Transcribe." 
+ + requests = create_stt_batch_requests([signed_url], prompt) + + file_uri = requests[0]["request"]["contents"][0]["parts"][1]["file_data"][ + "file_uri" + ] + assert file_uri == signed_url + assert "X-Amz-Signature" in file_uri + + +class TestExtractTextFromResponseDict: + """Test cases for _extract_text_from_response_dict static method.""" + + def test_extract_direct_text(self): + """Test extracting text from direct text field.""" + response = {"text": "Hello world"} + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "Hello world" + + def test_extract_from_candidates(self): + """Test extracting text from candidates structure.""" + response = { + "candidates": [ + { + "content": { + "parts": [ + {"text": "Part 1"}, + {"text": " Part 2"}, + ] + } + } + ] + } + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "Part 1 Part 2" + + def test_extract_empty_response(self): + """Test extracting text from empty response.""" + response = {} + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "" + + def test_extract_multiple_candidates(self): + """Test extracting text from multiple candidates.""" + response = { + "candidates": [ + {"content": {"parts": [{"text": "First"}]}}, + {"content": {"parts": [{"text": "Second"}]}}, + ] + } + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "FirstSecond" + + +class TestExtractTextFromResponse: + """Test cases for _extract_text_from_response static method (object version).""" + + def test_extract_direct_text_attribute(self): + """Test extracting text when response has .text attribute.""" + response = MagicMock() + response.text = "Hello from text attribute" + text = GeminiBatchProvider._extract_text_from_response(response) + assert text == "Hello from text attribute" + + def test_extract_from_candidates_structure(self): + """Test extracting text from candidates when no .text attribute.""" + # Create mock without .text attribute + response = MagicMock(spec=[]) + del response.text # Ensure no text attribute + + # Create candidates structure + part1 = MagicMock() + part1.text = "Part 1" + part2 = MagicMock() + part2.text = " Part 2" + + content = MagicMock() + content.parts = [part1, part2] + + candidate = MagicMock() + candidate.content = content + + response.candidates = [candidate] + + text = GeminiBatchProvider._extract_text_from_response(response) + assert text == "Part 1 Part 2" + + def test_extract_empty_response_no_text_no_candidates(self): + """Test extracting text from response with no text and no candidates.""" + response = MagicMock(spec=[]) + del response.text + del response.candidates + + text = GeminiBatchProvider._extract_text_from_response(response) + assert text == "" + + +class TestCreateSttBatchRequestsMimeTypeFallback: + """Test cases for create_stt_batch_requests MIME type fallback.""" + + def test_unknown_mime_type_defaults_to_audio_mpeg(self): + """Test that unknown file extensions default to audio/mpeg.""" + # URL with no recognizable audio extension + signed_urls = ["https://bucket.s3.amazonaws.com/audio.unknown?signature=xyz"] + prompt = "Transcribe this audio." 
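+        # Stub the MIME lookup to return None so the audio/mpeg fallback
+        # path in create_stt_batch_requests is exercised.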
+ + with patch("app.core.batch.gemini.get_mime_from_url", return_value=None): + requests = create_stt_batch_requests(signed_urls, prompt) + + assert len(requests) == 1 + # Check that the request was created with default mime type + # parts[0] is text prompt, parts[1] is file_data + file_data = requests[0]["request"]["contents"][0]["parts"][1]["file_data"] + assert file_data["mime_type"] == "audio/mpeg" diff --git a/backend/app/tests/core/test_storage_utils.py b/backend/app/tests/core/test_storage_utils.py new file mode 100644 index 000000000..e1b702003 --- /dev/null +++ b/backend/app/tests/core/test_storage_utils.py @@ -0,0 +1,293 @@ +"""Test cases for storage utilities.""" + +import json +from datetime import datetime +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest + +from app.core.storage_utils import ( + generate_timestamped_filename, + get_mime_from_url, + upload_csv_to_object_store, + upload_jsonl_to_object_store, + upload_to_object_store, +) + + +class TestGetMimeFromUrl: + """Test cases for get_mime_from_url function.""" + + def test_mp3_url(self): + """Test MIME detection for MP3 files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.mp3" + assert get_mime_from_url(url) == "audio/mpeg" + + def test_wav_url(self): + """Test MIME detection for WAV files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.wav" + assert get_mime_from_url(url) == "audio/x-wav" + + def test_flac_url(self): + """Test MIME detection for FLAC files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.flac" + mime = get_mime_from_url(url) + # FLAC can be detected as audio/flac or audio/x-flac depending on system + assert mime in ("audio/flac", "audio/x-flac") + + def test_ogg_url(self): + """Test MIME detection for OGG files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.ogg" + assert get_mime_from_url(url) == "audio/ogg" + + def test_webm_url(self): + """Test MIME detection for WEBM files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.webm" + mime = get_mime_from_url(url) + # webm can be detected as audio or video depending on system + assert mime in ("audio/webm", "video/webm") + + def test_signed_url_with_query_params(self): + """Test MIME detection for signed URLs with query parameters.""" + url = ( + "https://bucket.s3.amazonaws.com/audio/test.mp3" + "?X-Amz-Signature=abc123&X-Amz-Expires=3600" + ) + assert get_mime_from_url(url) == "audio/mpeg" + + def test_url_encoded_path(self): + """Test MIME detection for URL-encoded paths.""" + url = "https://bucket.s3.amazonaws.com/audio/test%20file.mp3" + assert get_mime_from_url(url) == "audio/mpeg" + + def test_unknown_extension(self): + """Test MIME detection returns None for unknown extensions.""" + url = "https://bucket.s3.amazonaws.com/file.unknown" + assert get_mime_from_url(url) is None + + def test_no_extension(self): + """Test MIME detection returns None for URLs without extension.""" + url = "https://bucket.s3.amazonaws.com/file" + assert get_mime_from_url(url) is None + + def test_csv_url(self): + """Test MIME detection for CSV files.""" + url = "https://bucket.s3.amazonaws.com/data/test.csv" + assert get_mime_from_url(url) == "text/csv" + + def test_json_url(self): + """Test MIME detection for JSON files.""" + url = "https://bucket.s3.amazonaws.com/data/test.json" + assert get_mime_from_url(url) == "application/json" + + +class TestUploadToObjectStore: + """Test cases for upload_to_object_store function.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock 
CloudStorage instance.""" + storage = MagicMock() + storage.put.return_value = "s3://bucket/test/file.txt" + return storage + + def test_successful_upload(self, mock_storage): + """Test successful file upload.""" + content = b"test content" + result = upload_to_object_store( + storage=mock_storage, + content=content, + filename="test.txt", + subdirectory="uploads", + content_type="text/plain", + ) + + assert result == "s3://bucket/test/file.txt" + mock_storage.put.assert_called_once() + + # Verify the UploadFile was created correctly + call_args = mock_storage.put.call_args + upload_file = call_args.kwargs["source"] + assert upload_file.filename == "test.txt" + + def test_upload_with_default_content_type(self, mock_storage): + """Test upload uses default content type.""" + content = b"binary data" + result = upload_to_object_store( + storage=mock_storage, + content=content, + filename="data.bin", + subdirectory="files", + ) + + assert result == "s3://bucket/test/file.txt" + mock_storage.put.assert_called_once() + + def test_upload_returns_none_on_cloud_storage_error(self, mock_storage): + """Test that CloudStorageError returns None gracefully.""" + from app.core.cloud.storage import CloudStorageError + + mock_storage.put.side_effect = CloudStorageError("Connection failed") + + result = upload_to_object_store( + storage=mock_storage, + content=b"data", + filename="test.txt", + subdirectory="uploads", + ) + + assert result is None + + def test_upload_returns_none_on_generic_error(self, mock_storage): + """Test that generic exceptions return None gracefully.""" + mock_storage.put.side_effect = Exception("Unexpected error") + + result = upload_to_object_store( + storage=mock_storage, + content=b"data", + filename="test.txt", + subdirectory="uploads", + ) + + assert result is None + + +class TestUploadCsvToObjectStore: + """Test cases for upload_csv_to_object_store function.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock CloudStorage instance.""" + storage = MagicMock() + storage.put.return_value = "s3://bucket/datasets/data.csv" + return storage + + def test_successful_csv_upload(self, mock_storage): + """Test successful CSV upload.""" + csv_content = b"col1,col2\nval1,val2" + result = upload_csv_to_object_store( + storage=mock_storage, + csv_content=csv_content, + filename="data.csv", + ) + + assert result == "s3://bucket/datasets/data.csv" + mock_storage.put.assert_called_once() + + def test_csv_upload_with_custom_subdirectory(self, mock_storage): + """Test CSV upload with custom subdirectory.""" + csv_content = b"col1,col2\nval1,val2" + result = upload_csv_to_object_store( + storage=mock_storage, + csv_content=csv_content, + filename="data.csv", + subdirectory="stt_datasets", + ) + + assert result is not None + call_args = mock_storage.put.call_args + file_path = call_args.kwargs["file_path"] + assert "stt_datasets" in str(file_path) + + +class TestUploadJsonlToObjectStore: + """Test cases for upload_jsonl_to_object_store function.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock CloudStorage instance.""" + storage = MagicMock() + storage.put.return_value = "s3://bucket/results/output.jsonl" + return storage + + def test_successful_jsonl_upload(self, mock_storage): + """Test successful JSONL upload.""" + results = [ + {"id": 1, "text": "result 1"}, + {"id": 2, "text": "result 2"}, + ] + result = upload_jsonl_to_object_store( + storage=mock_storage, + results=results, + filename="output.jsonl", + subdirectory="results", + ) + + assert result == 
"s3://bucket/results/output.jsonl" + mock_storage.put.assert_called_once() + + def test_empty_results_list(self, mock_storage): + """Test upload with empty results list.""" + results = [] + result = upload_jsonl_to_object_store( + storage=mock_storage, + results=results, + filename="empty.jsonl", + subdirectory="results", + ) + + assert result is not None + + def test_results_with_unicode(self, mock_storage): + """Test upload with unicode content.""" + results = [ + {"text": "Hello δΈ–η•Œ"}, + {"text": "Emoji 🎡"}, + ] + result = upload_jsonl_to_object_store( + storage=mock_storage, + results=results, + filename="unicode.jsonl", + subdirectory="results", + ) + + assert result is not None + + +class TestGenerateTimestampedFilename: + """Test cases for generate_timestamped_filename function.""" + + def test_default_csv_extension(self): + """Test that default extension is CSV.""" + filename = generate_timestamped_filename("dataset") + assert filename.endswith(".csv") + assert filename.startswith("dataset_") + + def test_custom_extension(self): + """Test custom file extension.""" + filename = generate_timestamped_filename("results", extension="jsonl") + assert filename.endswith(".jsonl") + assert filename.startswith("results_") + + def test_timestamp_format(self): + """Test that timestamp is in expected format.""" + filename = generate_timestamped_filename("test") + # Expected format: test_YYYYMMDD_HHMMSS.csv + parts = filename.split("_") + assert len(parts) == 3 + # Date part should be 8 digits + assert len(parts[1]) == 8 + assert parts[1].isdigit() + # Time part should be 6 digits + extension + time_part = parts[2].split(".")[0] + assert len(time_part) == 6 + assert time_part.isdigit() + + def test_with_special_characters_in_base_name(self): + """Test with special characters in base name.""" + filename = generate_timestamped_filename("my-dataset-v1") + assert filename.startswith("my-dataset-v1_") + assert filename.endswith(".csv") + + def test_unique_filenames(self): + """Test that consecutive calls produce different filenames.""" + import time + + filename1 = generate_timestamped_filename("test") + time.sleep(0.01) # Small delay to ensure different timestamp + filename2 = generate_timestamped_filename("test") + # They may be the same if called in the same second + # but the format should be correct + assert filename1.startswith("test_") + assert filename2.startswith("test_") diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 9e1bec372..78bed64a7 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -208,7 +208,7 @@ def test_invalid_provider(db: Session) -> None: """Test handling of invalid provider names.""" project = create_test_project(db) - credentials_data = {"gemini": {"api_key": "test-key"}} + credentials_data = {"invalid_provider": {"api_key": "test-key"}} credentials_create = CredsCreate( is_active=True, credential=credentials_data, diff --git a/backend/app/tests/services/stt_evaluations/__init__.py b/backend/app/tests/services/stt_evaluations/__init__.py new file mode 100644 index 000000000..a206798fe --- /dev/null +++ b/backend/app/tests/services/stt_evaluations/__init__.py @@ -0,0 +1 @@ +"""STT evaluation service tests.""" diff --git a/backend/app/tests/services/stt_evaluations/test_audio.py b/backend/app/tests/services/stt_evaluations/test_audio.py new file mode 100644 index 000000000..a1c1d2342 --- /dev/null +++ b/backend/app/tests/services/stt_evaluations/test_audio.py 
diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py
index 9e1bec372..78bed64a7 100644
--- a/backend/app/tests/crud/test_credentials.py
+++ b/backend/app/tests/crud/test_credentials.py
@@ -208,7 +208,7 @@ def test_invalid_provider(db: Session) -> None:
     """Test handling of invalid provider names."""
     project = create_test_project(db)
 
-    credentials_data = {"gemini": {"api_key": "test-key"}}
+    credentials_data = {"invalid_provider": {"api_key": "test-key"}}
     credentials_create = CredsCreate(
         is_active=True,
         credential=credentials_data,
diff --git a/backend/app/tests/services/stt_evaluations/__init__.py b/backend/app/tests/services/stt_evaluations/__init__.py
new file mode 100644
index 000000000..a206798fe
--- /dev/null
+++ b/backend/app/tests/services/stt_evaluations/__init__.py
@@ -0,0 +1 @@
+"""STT evaluation service tests."""
diff --git a/backend/app/tests/services/stt_evaluations/test_audio.py b/backend/app/tests/services/stt_evaluations/test_audio.py
new file mode 100644
index 000000000..a1c1d2342
--- /dev/null
+++ b/backend/app/tests/services/stt_evaluations/test_audio.py
@@ -0,0 +1,211 @@
+"""Test cases for STT audio upload service."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from fastapi import UploadFile
+
+from app.core.exception_handlers import HTTPException
+from app.services.stt_evaluations.audio import (
+    _resolve_extension,
+    _validate_audio_file,
+    upload_audio_file,
+)
+from app.services.stt_evaluations.constants import MAX_FILE_SIZE_BYTES
+
+
+def _make_upload_file(
+    filename: str | None = "test.mp3",
+    content_type: str | None = "audio/mpeg",
+    size: int | None = 1024,
+) -> UploadFile:
+    """Create a mock UploadFile for testing."""
+    mock_file = MagicMock(spec=UploadFile)
+    mock_file.filename = filename
+    mock_file.content_type = content_type
+    mock_file.size = size
+    return mock_file
+
+
+class TestResolveExtension:
+    """Test cases for _resolve_extension."""
+
+    @pytest.mark.parametrize(
+        "filename, expected",
+        [
+            ("audio.mp3", "mp3"),
+            ("audio.wav", "wav"),
+            ("audio.flac", "flac"),
+            ("audio.m4a", "m4a"),
+            ("audio.ogg", "ogg"),
+            ("audio.webm", "webm"),
+            ("audio.MP3", "mp3"),
+            ("audio.Mp3", "mp3"),
+            ("audio.backup.mp3", "mp3"),
+            (".audio.mp3", "mp3"),
+            ("/path/to/audio.mp3", "mp3"),
+        ],
+    )
+    def test_extension_from_filename(self, filename: str, expected: str) -> None:
+        file = _make_upload_file(filename=filename)
+        assert _resolve_extension(file) == expected
+
+    @pytest.mark.parametrize(
+        "filename",
+        ["", None, "audiofile"],
+    )
+    def test_no_extension_from_filename(self, filename: str | None) -> None:
+        file = _make_upload_file(filename=filename, content_type=None)
+        assert _resolve_extension(file) is None
+
+    @pytest.mark.parametrize(
+        "content_type, expected",
+        [
+            ("audio/mpeg", "mp3"),
+            ("audio/mp3", "mp3"),
+            ("audio/wav", "wav"),
+            ("audio/x-wav", "wav"),
+            ("audio/wave", "wav"),
+            ("audio/flac", "flac"),
+            ("audio/mp4", "m4a"),
+            ("audio/ogg", "ogg"),
+            ("audio/webm", "webm"),
+            ("AUDIO/MPEG", "mp3"),
+        ],
+    )
+    def test_fallback_to_content_type(self, content_type: str, expected: str) -> None:
+        file = _make_upload_file(filename="audiofile", content_type=content_type)
+        assert _resolve_extension(file) == expected
+
+    def test_unknown_content_type(self) -> None:
+        file = _make_upload_file(
+            filename="audiofile", content_type="application/octet-stream"
+        )
+        assert _resolve_extension(file) is None
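+
+# For reference, a minimal sketch of the resolution order exercised above:
+# filename suffix first, then content type (hypothetical shape; the real
+# implementation lives in app/services/stt_evaluations/audio.py):
+#
+#     CONTENT_TYPE_MAP = {"audio/mpeg": "mp3", "audio/x-wav": "wav", ...}
+#
+#     def _resolve_extension(file: UploadFile) -> str | None:
+#         if file.filename and "." in file.filename:
+#             ext = file.filename.rsplit(".", 1)[-1].lower()
+#             if ext:
+#                 return ext
+#         if file.content_type:
+#             return CONTENT_TYPE_MAP.get(file.content_type.lower())
+#         return None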
+
+
+class TestValidateAudioFile:
+    """Test cases for _validate_audio_file."""
+
+    @pytest.mark.parametrize("ext", ["mp3", "wav", "flac", "m4a", "ogg", "webm"])
+    def test_valid_formats(self, ext: str) -> None:
+        file = _make_upload_file(filename=f"test.{ext}")
+        assert _validate_audio_file(file) == ext
+
+    def test_missing_filename(self) -> None:
+        file = _make_upload_file(filename=None)
+        with pytest.raises(HTTPException) as exc_info:
+            _validate_audio_file(file)
+        assert exc_info.value.status_code == 400
+        assert "Filename is required" in str(exc_info.value.detail)
+
+    def test_empty_filename(self) -> None:
+        file = _make_upload_file(filename="")
+        with pytest.raises(HTTPException) as exc_info:
+            _validate_audio_file(file)
+        assert exc_info.value.status_code == 400
+
+    def test_unsupported_format(self) -> None:
+        file = _make_upload_file(filename="test.txt")
+        with pytest.raises(HTTPException) as exc_info:
+            _validate_audio_file(file)
+        assert exc_info.value.status_code == 400
+        assert "Unsupported audio format" in str(exc_info.value.detail)
+
+    def test_content_type_fallback(self) -> None:
+        file = _make_upload_file(filename="audiofile", content_type="audio/mpeg")
+        assert _validate_audio_file(file) == "mp3"
+
+    def test_file_too_large(self) -> None:
+        file = _make_upload_file(size=MAX_FILE_SIZE_BYTES + 1)
+        with pytest.raises(HTTPException) as exc_info:
+            _validate_audio_file(file)
+        assert exc_info.value.status_code == 400
+        assert "File too large" in str(exc_info.value.detail)
+
+    def test_file_at_max_size(self) -> None:
+        file = _make_upload_file(size=MAX_FILE_SIZE_BYTES)
+        assert _validate_audio_file(file) == "mp3"
+
+    def test_file_with_no_size(self) -> None:
+        file = _make_upload_file(size=None)
+        assert _validate_audio_file(file) == "mp3"
+
+
+class TestUploadAudioFile:
+    """Test cases for upload_audio_file."""
+
+    @patch("app.services.stt_evaluations.audio.create_file")
+    @patch("app.services.stt_evaluations.audio.get_cloud_storage")
+    def test_successful_upload(self, mock_get_storage, mock_create_file) -> None:
+        mock_storage = MagicMock()
+        mock_storage.put.return_value = "s3://bucket/stt/audio/test.mp3"
+        mock_storage.get_file_size_kb.return_value = 1.0
+        mock_get_storage.return_value = mock_storage
+
+        mock_file_record = MagicMock()
+        mock_file_record.id = 1
+        mock_create_file.return_value = mock_file_record
+
+        result = upload_audio_file(
+            session=MagicMock(),
+            file=_make_upload_file(),
+            organization_id=1,
+            project_id=1,
+        )
+
+        assert result.file_id == 1
+        assert result.s3_url == "s3://bucket/stt/audio/test.mp3"
+        assert result.filename == "test.mp3"
+        assert result.size_bytes == 1024
+        assert result.content_type == "audio/mpeg"
+
+    @patch("app.services.stt_evaluations.audio.get_cloud_storage")
+    def test_upload_validation_error(self, mock_get_storage) -> None:
+        with pytest.raises(HTTPException) as exc_info:
+            upload_audio_file(
+                session=MagicMock(),
+                file=_make_upload_file(filename="test.txt"),
+                organization_id=1,
+                project_id=1,
+            )
+        assert exc_info.value.status_code == 400
+        assert "Unsupported audio format" in str(exc_info.value.detail)
+
+    @patch("app.services.stt_evaluations.audio.get_cloud_storage")
+    def test_upload_storage_error(self, mock_get_storage) -> None:
+        mock_storage = MagicMock()
+        mock_storage.put.side_effect = Exception("S3 connection failed")
+        mock_get_storage.return_value = mock_storage
+
+        with pytest.raises(HTTPException) as exc_info:
+            upload_audio_file(
+                session=MagicMock(),
+                file=_make_upload_file(),
+                organization_id=1,
+                project_id=1,
+            )
+        assert exc_info.value.status_code == 500
+        assert "Failed to upload audio file" in str(exc_info.value.detail)
+
+    @patch("app.services.stt_evaluations.audio.create_file")
+    @patch("app.services.stt_evaluations.audio.get_cloud_storage")
+    def test_upload_uses_file_size_on_s3_error(
+        self, mock_get_storage, mock_create_file
+    ) -> None:
+        mock_storage = MagicMock()
+        mock_storage.put.return_value = "s3://bucket/stt/audio/test.mp3"
+        mock_storage.get_file_size_kb.side_effect = Exception("Failed to get size")
+        mock_get_storage.return_value = mock_storage
+
+        mock_file_record = MagicMock()
+        mock_file_record.id = 1
+        mock_create_file.return_value = mock_file_record
+
+        result = upload_audio_file(
+            session=MagicMock(),
+            file=_make_upload_file(size=2048),
+            organization_id=1,
+            project_id=1,
+        )
+        assert result.size_bytes == 2048
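+
+# For reference, a minimal sketch of the validation flow exercised above,
+# assuming a SUPPORTED_FORMATS set alongside MAX_FILE_SIZE_BYTES in the
+# constants module (hypothetical; the real implementation lives in
+# app/services/stt_evaluations/audio.py):
+#
+#     SUPPORTED_FORMATS = {"mp3", "wav", "flac", "m4a", "ogg", "webm"}
+#
+#     def _validate_audio_file(file: UploadFile) -> str:
+#         if not file.filename:
+#             raise HTTPException(400, "Filename is required")
+#         ext = _resolve_extension(file)
+#         if ext not in SUPPORTED_FORMATS:
+#             raise HTTPException(400, f"Unsupported audio format: {ext}")
+#         if file.size and file.size > MAX_FILE_SIZE_BYTES:
+#             raise HTTPException(400, "File too large")
+#         return ext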
diff --git a/backend/app/tests/services/stt_evaluations/test_dataset.py b/backend/app/tests/services/stt_evaluations/test_dataset.py
new file mode 100644
index 000000000..eff49a107
--- /dev/null
+++ b/backend/app/tests/services/stt_evaluations/test_dataset.py
@@ -0,0 +1,342 @@
+"""Test cases for STT dataset management service."""
+
+from unittest.mock import MagicMock, patch
+
+from app.models.stt_evaluation import STTSampleCreate
+from app.services.stt_evaluations.dataset import (
+    _samples_to_csv,
+    _upload_samples_to_object_store,
+    upload_stt_dataset,
+)
+
+
+class TestSamplesToCSV:
+    """Test cases for _samples_to_csv function."""
+
+    def test_single_sample_without_ground_truth(self):
+        """Test CSV conversion with single sample without ground truth."""
+        samples = [
+            STTSampleCreate(file_id=1),
+        ]
+        result = _samples_to_csv(samples)
+
+        # Decode and verify - handle both \n and \r\n line endings
+        csv_str = result.decode("utf-8").replace("\r\n", "\n").replace("\r", "\n")
+        lines = csv_str.strip().split("\n")
+
+        assert len(lines) == 2  # Header + 1 sample
+        assert lines[0] == "file_id,ground_truth"
+        assert lines[1] == "1,"
+
+    def test_single_sample_with_ground_truth(self):
+        """Test CSV conversion with single sample with ground truth."""
+        samples = [
+            STTSampleCreate(
+                file_id=1,
+                ground_truth="Hello world",
+            ),
+        ]
+        result = _samples_to_csv(samples)
+
+        csv_str = result.decode("utf-8")
+        lines = csv_str.strip().split("\n")
+
+        assert len(lines) == 2
+        assert "Hello world" in lines[1]
+
+    def test_multiple_samples(self):
+        """Test CSV conversion with multiple samples."""
+        samples = [
+            STTSampleCreate(
+                file_id=1,
+                ground_truth="First transcription",
+            ),
+            STTSampleCreate(
+                file_id=2,
+                ground_truth="Second transcription",
+            ),
+            STTSampleCreate(
+                file_id=3,
+            ),
+        ]
+        result = _samples_to_csv(samples)
+
+        csv_str = result.decode("utf-8")
+        lines = csv_str.strip().split("\n")
+
+        assert len(lines) == 4  # Header + 3 samples
+
+    def test_empty_samples_list(self):
+        """Test CSV conversion with empty samples list."""
+        samples = []
+        result = _samples_to_csv(samples)
+
+        csv_str = result.decode("utf-8")
+        lines = csv_str.strip().split("\n")
+
+        # Should only have header
+        assert len(lines) == 1
+        assert lines[0] == "file_id,ground_truth"
+
+    def test_sample_with_unicode(self):
+        """Test CSV conversion with unicode characters."""
+        samples = [
+            STTSampleCreate(
+                file_id=1,
+                ground_truth="Hello δΈ–η•Œ 🌍",
+            ),
+        ]
+        result = _samples_to_csv(samples)
+
+        csv_str = result.decode("utf-8")
+        assert "Hello δΈ–η•Œ 🌍" in csv_str
+
+    def test_sample_with_comma_in_ground_truth(self):
+        """Test CSV conversion handles commas in ground truth."""
+        samples = [
+            STTSampleCreate(
+                file_id=1,
+                ground_truth="Hello, world",
+            ),
+        ]
+        result = _samples_to_csv(samples)
+
+        csv_str = result.decode("utf-8")
+        # CSV should properly quote fields with commas
+        assert '"Hello, world"' in csv_str or "Hello, world" in csv_str
+
+    def test_sample_with_quotes_in_ground_truth(self):
+        """Test CSV conversion handles quotes in ground truth."""
+        samples = [
+            STTSampleCreate(
+                file_id=1,
+                ground_truth='He said "hello"',
+            ),
+        ]
+        result = _samples_to_csv(samples)
+
+        # Should produce valid CSV bytes
+        assert isinstance(result, bytes)
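+
+# For reference, a minimal sketch of the conversion under test, assuming the
+# stdlib csv module writes into an in-memory buffer (hypothetical shape; the
+# real implementation lives in app/services/stt_evaluations/dataset.py):
+#
+#     import csv
+#     import io
+#
+#     def _samples_to_csv(samples: list[STTSampleCreate]) -> bytes:
+#         buffer = io.StringIO()
+#         writer = csv.writer(buffer)
+#         writer.writerow(["file_id", "ground_truth"])
+#         for sample in samples:
+#             writer.writerow([sample.file_id, sample.ground_truth or ""])
+#         return buffer.getvalue().encode("utf-8")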
+
+
+class TestUploadSamplesToObjectStore:
+    """Test cases for _upload_samples_to_object_store function."""
+
+    @patch("app.services.stt_evaluations.dataset.get_cloud_storage")
+    @patch("app.services.stt_evaluations.dataset.upload_csv_to_object_store")
+    def test_successful_upload(self, mock_upload_csv, mock_get_storage):
+        """Test successful upload to object store."""
+        mock_storage = MagicMock()
+        mock_get_storage.return_value = mock_storage
+        mock_upload_csv.return_value = "s3://bucket/stt_datasets/dataset.csv"
+
+        mock_session = MagicMock()
+        samples = [
+            STTSampleCreate(file_id=1),
+        ]
+
+        result = _upload_samples_to_object_store(
+            session=mock_session,
+            project_id=1,
+            dataset_name="test_dataset",
+            samples=samples,
+        )
+
+        assert result == "s3://bucket/stt_datasets/dataset.csv"
+        mock_upload_csv.assert_called_once()
+
+    @patch("app.services.stt_evaluations.dataset.get_cloud_storage")
+    @patch("app.services.stt_evaluations.dataset.upload_csv_to_object_store")
+    def test_upload_returns_none_on_failure(self, mock_upload_csv, mock_get_storage):
+        """Test upload returns None on failure."""
+        mock_storage = MagicMock()
+        mock_get_storage.return_value = mock_storage
+        mock_upload_csv.return_value = None
+
+        mock_session = MagicMock()
+        samples = [
+            STTSampleCreate(file_id=1),
+        ]
+
+        result = _upload_samples_to_object_store(
+            session=mock_session,
+            project_id=1,
+            dataset_name="test_dataset",
+            samples=samples,
+        )
+
+        assert result is None
+
+    @patch("app.services.stt_evaluations.dataset.get_cloud_storage")
+    def test_upload_handles_exception(self, mock_get_storage):
+        """Test upload handles exceptions gracefully."""
+        mock_get_storage.side_effect = Exception("Storage connection failed")
+
+        mock_session = MagicMock()
+        samples = [
+            STTSampleCreate(file_id=1),
+        ]
+
+        result = _upload_samples_to_object_store(
+            session=mock_session,
+            project_id=1,
+            dataset_name="test_dataset",
+            samples=samples,
+        )
+
+        assert result is None
+
+
+class TestUploadSTTDataset:
+    """Test cases for upload_stt_dataset function."""
+
+    @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store")
+    @patch("app.services.stt_evaluations.dataset.create_stt_samples")
+    @patch("app.services.stt_evaluations.dataset.create_stt_dataset")
+    def test_successful_upload(
+        self, mock_create_dataset, mock_create_samples, mock_upload_samples
+    ):
+        """Test successful dataset upload orchestration."""
+        # Setup mocks
+        mock_upload_samples.return_value = "s3://bucket/stt_datasets/test.csv"
+
+        mock_dataset = MagicMock()
+        mock_dataset.id = 1
+        mock_create_dataset.return_value = mock_dataset
+
+        mock_sample = MagicMock()
+        mock_create_samples.return_value = [mock_sample]
+
+        mock_session = MagicMock()
+        samples = [
+            STTSampleCreate(
+                file_id=1,
+                ground_truth="Test transcription",
+            ),
+        ]
+
+        dataset, created_samples = upload_stt_dataset(
+            session=mock_session,
+            name="test_dataset",
+            samples=samples,
+            organization_id=1,
+            project_id=1,
+        )
+
+        assert dataset == mock_dataset
+        assert len(created_samples) == 1
+
+        # Verify create_stt_dataset was called with correct metadata
+        mock_create_dataset.assert_called_once()
+        call_kwargs = mock_create_dataset.call_args.kwargs
+        assert call_kwargs["name"] == "test_dataset"
+        assert call_kwargs["dataset_metadata"]["sample_count"] == 1
+        assert call_kwargs["dataset_metadata"]["has_ground_truth_count"] == 1
+
+    @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store")
+    @patch("app.services.stt_evaluations.dataset.create_stt_samples")
+    @patch("app.services.stt_evaluations.dataset.create_stt_dataset")
+    def test_upload_with_description_and_language(
+        self, mock_create_dataset, mock_create_samples, mock_upload_samples
+    ):
+        """Test upload with optional description and language."""
+        mock_upload_samples.return_value = None
+        mock_dataset = MagicMock()
+        mock_dataset.id = 1
+        mock_create_dataset.return_value = mock_dataset
+        mock_create_samples.return_value = []
+
+        mock_session = MagicMock()
+        samples = [
+            STTSampleCreate(file_id=1),
+        ]
+
+        upload_stt_dataset(
+            session=mock_session,
name="test_dataset", + samples=samples, + organization_id=1, + project_id=1, + description="Test description", + language="en", + ) + + call_kwargs = mock_create_dataset.call_args.kwargs + assert call_kwargs["description"] == "Test description" + assert call_kwargs["language"] == "en" + + @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store") + @patch("app.services.stt_evaluations.dataset.create_stt_samples") + @patch("app.services.stt_evaluations.dataset.create_stt_dataset") + def test_upload_counts_ground_truth_correctly( + self, mock_create_dataset, mock_create_samples, mock_upload_samples + ): + """Test that ground truth count is calculated correctly.""" + mock_upload_samples.return_value = None + mock_dataset = MagicMock() + mock_dataset.id = 1 + mock_create_dataset.return_value = mock_dataset + mock_create_samples.return_value = [] + + mock_session = MagicMock() + samples = [ + STTSampleCreate( + file_id=1, + ground_truth="Has ground truth", + ), + STTSampleCreate( + file_id=2, + ), + STTSampleCreate( + file_id=3, + ground_truth="Also has ground truth", + ), + STTSampleCreate( + file_id=4, + ground_truth="", # Empty string should not count + ), + ] + + upload_stt_dataset( + session=mock_session, + name="test_dataset", + samples=samples, + organization_id=1, + project_id=1, + ) + + call_kwargs = mock_create_dataset.call_args.kwargs + assert call_kwargs["dataset_metadata"]["sample_count"] == 4 + assert call_kwargs["dataset_metadata"]["has_ground_truth_count"] == 2 + + @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store") + @patch("app.services.stt_evaluations.dataset.create_stt_samples") + @patch("app.services.stt_evaluations.dataset.create_stt_dataset") + def test_upload_continues_without_object_store_url( + self, mock_create_dataset, mock_create_samples, mock_upload_samples + ): + """Test that upload continues even when object store upload fails.""" + mock_upload_samples.return_value = None # Simulates failed upload + + mock_dataset = MagicMock() + mock_dataset.id = 1 + mock_create_dataset.return_value = mock_dataset + mock_create_samples.return_value = [] + + mock_session = MagicMock() + samples = [ + STTSampleCreate(file_id=1), + ] + + dataset, created_samples = upload_stt_dataset( + session=mock_session, + name="test_dataset", + samples=samples, + organization_id=1, + project_id=1, + ) + + # Should still create the dataset + assert dataset is not None + call_kwargs = mock_create_dataset.call_args.kwargs + assert call_kwargs["object_store_url"] is None diff --git a/backend/app/tests/services/stt_evaluations/test_gemini_client.py b/backend/app/tests/services/stt_evaluations/test_gemini_client.py new file mode 100644 index 000000000..36ef02fbb --- /dev/null +++ b/backend/app/tests/services/stt_evaluations/test_gemini_client.py @@ -0,0 +1,182 @@ +"""Test cases for Gemini client wrapper.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.stt_evaluations.gemini.client import ( + GeminiClient, + GeminiClientError, +) + + +class TestGeminiClientInit: + """Test cases for GeminiClient initialization.""" + + @patch("app.services.stt_evaluations.gemini.client.genai.Client") + def test_initialization_with_api_key(self, mock_genai_client): + """Test client initialization with API key.""" + mock_client_instance = MagicMock() + mock_genai_client.return_value = mock_client_instance + + client = GeminiClient(api_key="test-api-key") + + mock_genai_client.assert_called_once_with(api_key="test-api-key") + assert 
+
+    @patch("app.services.stt_evaluations.gemini.client.genai.Client")
+    def test_client_property(self, mock_genai_client):
+        """Test client property returns underlying client."""
+        mock_client_instance = MagicMock()
+        mock_genai_client.return_value = mock_client_instance
+
+        client = GeminiClient(api_key="test-api-key")
+
+        assert client.client == mock_client_instance
+
+
+class TestGeminiClientFromCredentials:
+    """Test cases for GeminiClient.from_credentials class method."""
+
+    @patch("app.services.stt_evaluations.gemini.client.genai.Client")
+    @patch("app.services.stt_evaluations.gemini.client.get_provider_credential")
+    def test_successful_creation(self, mock_get_creds, mock_genai_client):
+        """Test successful client creation from credentials."""
+        mock_get_creds.return_value = {"api_key": "stored-api-key"}
+        mock_client_instance = MagicMock()
+        mock_genai_client.return_value = mock_client_instance
+
+        mock_session = MagicMock()
+
+        client = GeminiClient.from_credentials(
+            session=mock_session,
+            org_id=1,
+            project_id=2,
+        )
+
+        mock_get_creds.assert_called_once_with(
+            session=mock_session,
+            org_id=1,
+            project_id=2,
+            provider="gemini",
+        )
+        assert client._api_key == "stored-api-key"
+
+    @patch("app.services.stt_evaluations.gemini.client.get_provider_credential")
+    def test_credentials_not_found(self, mock_get_creds):
+        """Test error when credentials are not found."""
+        from app.core.exception_handlers import HTTPException
+
+        mock_get_creds.return_value = None
+        mock_session = MagicMock()
+
+        with pytest.raises(HTTPException) as exc_info:
+            GeminiClient.from_credentials(
+                session=mock_session,
+                org_id=1,
+                project_id=2,
+            )
+
+        assert exc_info.value.status_code == 404
+        assert "credentials not configured" in str(exc_info.value.detail)
+
+    @patch("app.services.stt_evaluations.gemini.client.get_provider_credential")
+    def test_missing_api_key(self, mock_get_creds):
+        """Test error when credentials exist but api_key is missing."""
+        mock_get_creds.return_value = {"other_field": "value"}  # No api_key
+        mock_session = MagicMock()
+
+        with pytest.raises(GeminiClientError) as exc_info:
+            GeminiClient.from_credentials(
+                session=mock_session,
+                org_id=1,
+                project_id=2,
+            )
+
+        assert "missing api_key" in str(exc_info.value)
+
+    @patch("app.services.stt_evaluations.gemini.client.get_provider_credential")
+    def test_empty_api_key(self, mock_get_creds):
+        """Test error when api_key is empty."""
+        mock_get_creds.return_value = {"api_key": ""}  # Empty api_key
+        mock_session = MagicMock()
+
+        with pytest.raises(GeminiClientError) as exc_info:
+            GeminiClient.from_credentials(
+                session=mock_session,
+                org_id=1,
+                project_id=2,
+            )
+
+        assert "missing api_key" in str(exc_info.value)
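+
+# For reference, a minimal sketch of the credential lookup exercised above
+# (hypothetical shape; the real implementation lives in
+# app/services/stt_evaluations/gemini/client.py):
+#
+#     @classmethod
+#     def from_credentials(cls, session, org_id, project_id):
+#         creds = get_provider_credential(
+#             session=session, org_id=org_id, project_id=project_id, provider="gemini"
+#         )
+#         if not creds:
+#             raise HTTPException(404, "Gemini credentials not configured")
+#         if not creds.get("api_key"):
+#             raise GeminiClientError("Gemini credentials are missing api_key")
+#         return cls(api_key=creds["api_key"])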
+
+
+class TestGeminiClientValidateConnection:
+    """Test cases for GeminiClient.validate_connection method."""
+
+    @patch("app.services.stt_evaluations.gemini.client.genai.Client")
+    def test_successful_validation(self, mock_genai_client):
+        """Test successful connection validation."""
+        mock_client_instance = MagicMock()
+        mock_client_instance.models.list.return_value = [
+            MagicMock(),
+            MagicMock(),
+        ]
+        mock_genai_client.return_value = mock_client_instance
+
+        client = GeminiClient(api_key="valid-api-key")
+        result = client.validate_connection()
+
+        assert result is True
+        mock_client_instance.models.list.assert_called_once()
+
+    @patch("app.services.stt_evaluations.gemini.client.genai.Client")
+    def test_validation_with_api_error(self, mock_genai_client):
+        """Test validation returns False on API error."""
+        mock_client_instance = MagicMock()
+        mock_client_instance.models.list.side_effect = Exception("API error")
+        mock_genai_client.return_value = mock_client_instance
+
+        client = GeminiClient(api_key="invalid-api-key")
+        result = client.validate_connection()
+
+        assert result is False
+
+    @patch("app.services.stt_evaluations.gemini.client.genai.Client")
+    def test_validation_with_auth_error(self, mock_genai_client):
+        """Test validation returns False on authentication error."""
+        mock_client_instance = MagicMock()
+        mock_client_instance.models.list.side_effect = Exception("Invalid API key")
+        mock_genai_client.return_value = mock_client_instance
+
+        client = GeminiClient(api_key="wrong-api-key")
+        result = client.validate_connection()
+
+        assert result is False
+
+    @patch("app.services.stt_evaluations.gemini.client.genai.Client")
+    def test_validation_with_empty_models_list(self, mock_genai_client):
+        """Test validation succeeds with empty models list."""
+        mock_client_instance = MagicMock()
+        mock_client_instance.models.list.return_value = []
+        mock_genai_client.return_value = mock_client_instance
+
+        client = GeminiClient(api_key="valid-api-key")
+        result = client.validate_connection()
+
+        # Empty list is still a valid response
+        assert result is True
+
+
+class TestGeminiClientError:
+    """Test cases for GeminiClientError exception."""
+
+    def test_error_message(self):
+        """Test error message is preserved."""
+        error = GeminiClientError("Test error message")
+        assert str(error) == "Test error message"
+
+    def test_error_inheritance(self):
+        """Test error inherits from Exception."""
+        error = GeminiClientError("Test")
+        assert isinstance(error, Exception)
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 6030fc0a1..119ee9116 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -36,6 +36,8 @@ dependencies = [
     "celery>=5.3.0,<6.0.0",
     "redis>=5.0.0,<6.0.0",
     "flower>=2.0.1",
+    "google-genai>=1.0.0",
+    "requests>=2.31.0",
 ]
 
 [tool.uv]
diff --git a/backend/uv.lock b/backend/uv.lock
index fb79c631c..667e43f0d 100644
--- a/backend/uv.lock
+++ b/backend/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.12, <4.0"
 resolution-markers = [
     "python_full_version >= '3.14' and sys_platform == 'win32'",
@@ -206,6 +206,7 @@ dependencies = [
     { name = "emails" },
     { name = "fastapi", extra = ["standard"] },
     { name = "flower" },
+    { name = "google-genai" },
     { name = "httpx" },
     { name = "jinja2" },
     { name = "langfuse" },
@@ -224,6 +225,7 @@ dependencies = [
     { name = "pytest" },
     { name = "python-multipart" },
     { name = "redis" },
+    { name = "requests" },
     { name = "scikit-learn" },
     { name = "sentry-sdk", extra = ["fastapi"] },
     { name = "sqlmodel" },
@@ -252,6 +254,7 @@ requires-dist = [
     { name = "emails", specifier = ">=0.6,<1.0" },
     { name = "fastapi", extras = ["standard"], specifier = ">=0.116.0" },
     { name = "flower", specifier = ">=2.0.1" },
+    { name = "google-genai", specifier = ">=1.0.0" },
     { name = "httpx", specifier = ">=0.25.1,<1.0.0" },
     { name = "jinja2", specifier = ">=3.1.4,<4.0.0" },
     { name = "langfuse", specifier = "==2.60.3" },
@@ -270,6 +273,7 @@ requires-dist = [
     { name = "pytest", specifier = ">=7.4.4" },
     { name = "python-multipart", specifier = ">=0.0.22,<1.0.0" },
     { name = "redis", specifier = ">=5.0.0,<6.0.0" },
+    { name = "requests", specifier = ">=2.31.0" },
     { name = "scikit-learn", specifier = ">=1.7.1" },
     { name = "sentry-sdk", extras = ["fastapi"], specifier = ">=2.20.0" },
     { name =
"sqlmodel", specifier = ">=0.0.21,<1.0.0" }, @@ -1108,6 +1112,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/c9/97cc5aae1648dcb851958a3ddf73ccd7dbe5650d95203ecb4d7720b4cdbf/fsspec-2026.1.0-py3-none-any.whl", hash = "sha256:cb76aa913c2285a3b49bdd5fc55b1d7c708d7208126b60f2eb8194fe1b4cbdcc", size = 201838, upload-time = "2026-01-09T15:21:34.041Z" }, ] +[[package]] +name = "google-auth" +version = "2.48.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0c/41/242044323fbd746615884b1c16639749e73665b718209946ebad7ba8a813/google_auth-2.48.0.tar.gz", hash = "sha256:4f7e706b0cd3208a3d940a19a822c37a476ddba5450156c3e6624a71f7c841ce", size = 326522, upload-time = "2026-01-26T19:22:47.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/1d/d6466de3a5249d35e832a52834115ca9d1d0de6abc22065f049707516d47/google_auth-2.48.0-py3-none-any.whl", hash = "sha256:2e2a537873d449434252a9632c28bfc268b0adb1e53f9fb62afc5333a975903f", size = 236499, upload-time = "2026-01-26T19:22:45.099Z" }, +] + +[package.optional-dependencies] +requests = [ + { name = "requests" }, +] + +[[package]] +name = "google-genai" +version = "1.61.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "google-auth", extra = ["requests"] }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "sniffio" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/38/421cd7e70952a536be87a0249409f87297d84f523754a25b08fe94b97e7f/google_genai-1.61.0.tar.gz", hash = "sha256:5773a4e8ad5b2ebcd54a633a67d8e9c4f413032fef07977ee47ffa34a6d3bbdf", size = 489672, upload-time = "2026-01-30T20:50:27.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/87/78dd70cb59f7acf3350f53c5144a7aa7bc39c6f425cd7dc1224b59fcdac3/google_genai-1.61.0-py3-none-any.whl", hash = "sha256:cb073ef8287581476c1c3f4d8e735426ee34478e500a56deef218fa93071e3ca", size = 721948, upload-time = "2026-01-30T20:50:25.551Z" }, +] + [[package]] name = "greenlet" version = "3.3.1" @@ -1117,6 +1161,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/c8/9d76a66421d1ae24340dfae7e79c313957f6e3195c144d2c73333b5bfe34/greenlet-3.3.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:7e806ca53acf6d15a888405880766ec84721aa4181261cd11a457dfe9a7a4975", size = 276443, upload-time = "2026-01-23T15:30:10.066Z" }, { url = "https://files.pythonhosted.org/packages/81/99/401ff34bb3c032d1f10477d199724f5e5f6fbfb59816ad1455c79c1eb8e7/greenlet-3.3.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d842c94b9155f1c9b3058036c24ffb8ff78b428414a19792b2380be9cecf4f36", size = 597359, upload-time = "2026-01-23T16:00:57.394Z" }, { url = "https://files.pythonhosted.org/packages/2b/bc/4dcc0871ed557792d304f50be0f7487a14e017952ec689effe2180a6ff35/greenlet-3.3.1-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20fedaadd422fa02695f82093f9a98bad3dab5fcda793c658b945fcde2ab27ba", size = 607805, upload-time = "2026-01-23T16:05:28.068Z" }, + { url = "https://files.pythonhosted.org/packages/3b/cd/7a7ca57588dac3389e97f7c9521cb6641fd8b6602faf1eaa4188384757df/greenlet-3.3.1-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:c620051669fd04ac6b60ebc70478210119c56e2d5d5df848baec4312e260e4ca", size = 622363, upload-time = "2026-01-23T16:15:54.754Z" }, { url = "https://files.pythonhosted.org/packages/cf/05/821587cf19e2ce1f2b24945d890b164401e5085f9d09cbd969b0c193cd20/greenlet-3.3.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14194f5f4305800ff329cbf02c5fcc88f01886cadd29941b807668a45f0d2336", size = 609947, upload-time = "2026-01-23T15:32:51.004Z" }, { url = "https://files.pythonhosted.org/packages/a4/52/ee8c46ed9f8babaa93a19e577f26e3d28a519feac6350ed6f25f1afee7e9/greenlet-3.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7b2fe4150a0cf59f847a67db8c155ac36aed89080a6a639e9f16df5d6c6096f1", size = 1567487, upload-time = "2026-01-23T16:04:22.125Z" }, { url = "https://files.pythonhosted.org/packages/8f/7c/456a74f07029597626f3a6db71b273a3632aecb9afafeeca452cfa633197/greenlet-3.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:49f4ad195d45f4a66a0eb9c1ba4832bb380570d361912fa3554746830d332149", size = 1636087, upload-time = "2026-01-23T15:33:47.486Z" }, @@ -1125,6 +1170,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/ab/d26750f2b7242c2b90ea2ad71de70cfcd73a948a49513188a0fc0d6fc15a/greenlet-3.3.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:7ab327905cabb0622adca5971e488064e35115430cec2c35a50fd36e72a315b3", size = 275205, upload-time = "2026-01-23T15:30:24.556Z" }, { url = "https://files.pythonhosted.org/packages/10/d3/be7d19e8fad7c5a78eeefb2d896a08cd4643e1e90c605c4be3b46264998f/greenlet-3.3.1-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:65be2f026ca6a176f88fb935ee23c18333ccea97048076aef4db1ef5bc0713ac", size = 599284, upload-time = "2026-01-23T16:00:58.584Z" }, { url = "https://files.pythonhosted.org/packages/ae/21/fe703aaa056fdb0f17e5afd4b5c80195bbdab701208918938bd15b00d39b/greenlet-3.3.1-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7a3ae05b3d225b4155bda56b072ceb09d05e974bc74be6c3fc15463cf69f33fd", size = 610274, upload-time = "2026-01-23T16:05:29.312Z" }, + { url = "https://files.pythonhosted.org/packages/06/00/95df0b6a935103c0452dad2203f5be8377e551b8466a29650c4c5a5af6cc/greenlet-3.3.1-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:12184c61e5d64268a160226fb4818af4df02cfead8379d7f8b99a56c3a54ff3e", size = 624375, upload-time = "2026-01-23T16:15:55.915Z" }, { url = "https://files.pythonhosted.org/packages/cb/86/5c6ab23bb3c28c21ed6bebad006515cfe08b04613eb105ca0041fecca852/greenlet-3.3.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6423481193bbbe871313de5fd06a082f2649e7ce6e08015d2a76c1e9186ca5b3", size = 612904, upload-time = "2026-01-23T15:32:52.317Z" }, { url = "https://files.pythonhosted.org/packages/c2/f3/7949994264e22639e40718c2daf6f6df5169bf48fb038c008a489ec53a50/greenlet-3.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:33a956fe78bbbda82bfc95e128d61129b32d66bcf0a20a1f0c08aa4839ffa951", size = 1567316, upload-time = "2026-01-23T16:04:23.316Z" }, { url = "https://files.pythonhosted.org/packages/8d/6e/d73c94d13b6465e9f7cd6231c68abde838bb22408596c05d9059830b7872/greenlet-3.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b065d3284be43728dd280f6f9a13990b56470b81be20375a207cdc814a983f2", size = 1636549, upload-time = "2026-01-23T15:33:48.643Z" }, @@ -1133,6 +1179,7 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/ae/fb/011c7c717213182caf78084a9bea51c8590b0afda98001f69d9f853a495b/greenlet-3.3.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:bd59acd8529b372775cd0fcbc5f420ae20681c5b045ce25bd453ed8455ab99b5", size = 275737, upload-time = "2026-01-23T15:32:16.889Z" }, { url = "https://files.pythonhosted.org/packages/41/2e/a3a417d620363fdbb08a48b1dd582956a46a61bf8fd27ee8164f9dfe87c2/greenlet-3.3.1-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b31c05dd84ef6871dd47120386aed35323c944d86c3d91a17c4b8d23df62f15b", size = 646422, upload-time = "2026-01-23T16:01:00.354Z" }, { url = "https://files.pythonhosted.org/packages/b4/09/c6c4a0db47defafd2d6bab8ddfe47ad19963b4e30f5bed84d75328059f8c/greenlet-3.3.1-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:02925a0bfffc41e542c70aa14c7eda3593e4d7e274bfcccca1827e6c0875902e", size = 658219, upload-time = "2026-01-23T16:05:30.956Z" }, + { url = "https://files.pythonhosted.org/packages/e2/89/b95f2ddcc5f3c2bc09c8ee8d77be312df7f9e7175703ab780f2014a0e781/greenlet-3.3.1-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3e0f3878ca3a3ff63ab4ea478585942b53df66ddde327b59ecb191b19dbbd62d", size = 671455, upload-time = "2026-01-23T16:15:57.232Z" }, { url = "https://files.pythonhosted.org/packages/80/38/9d42d60dffb04b45f03dbab9430898352dba277758640751dc5cc316c521/greenlet-3.3.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34a729e2e4e4ffe9ae2408d5ecaf12f944853f40ad724929b7585bca808a9d6f", size = 660237, upload-time = "2026-01-23T15:32:53.967Z" }, { url = "https://files.pythonhosted.org/packages/96/61/373c30b7197f9e756e4c81ae90a8d55dc3598c17673f91f4d31c3c689c3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:aec9ab04e82918e623415947921dea15851b152b822661cce3f8e4393c3df683", size = 1615261, upload-time = "2026-01-23T16:04:25.066Z" }, { url = "https://files.pythonhosted.org/packages/fd/d3/ca534310343f5945316f9451e953dcd89b36fe7a19de652a1dc5a0eeef3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:71c767cf281a80d02b6c1bdc41c9468e1f5a494fb11bc8688c360524e273d7b1", size = 1683719, upload-time = "2026-01-23T15:33:50.61Z" }, @@ -1141,6 +1188,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/24/cbbec49bacdcc9ec652a81d3efef7b59f326697e7edf6ed775a5e08e54c2/greenlet-3.3.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:3e63252943c921b90abb035ebe9de832c436401d9c45f262d80e2d06cc659242", size = 282706, upload-time = "2026-01-23T15:33:05.525Z" }, { url = "https://files.pythonhosted.org/packages/86/2e/4f2b9323c144c4fe8842a4e0d92121465485c3c2c5b9e9b30a52e80f523f/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76e39058e68eb125de10c92524573924e827927df5d3891fbc97bd55764a8774", size = 651209, upload-time = "2026-01-23T16:01:01.517Z" }, { url = "https://files.pythonhosted.org/packages/d9/87/50ca60e515f5bb55a2fbc5f0c9b5b156de7d2fc51a0a69abc9d23914a237/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c9f9d5e7a9310b7a2f416dd13d2e3fd8b42d803968ea580b7c0f322ccb389b97", size = 654300, upload-time = "2026-01-23T16:05:32.199Z" }, + { url = "https://files.pythonhosted.org/packages/7c/25/c51a63f3f463171e09cb586eb64db0861eb06667ab01a7968371a24c4f3b/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:4b9721549a95db96689458a1e0ae32412ca18776ed004463df3a9299c1b257ab", size = 662574, upload-time = "2026-01-23T16:15:58.364Z" }, { url = "https://files.pythonhosted.org/packages/1d/94/74310866dfa2b73dd08659a3d18762f83985ad3281901ba0ee9a815194fb/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:92497c78adf3ac703b57f1e3813c2d874f27f71a178f9ea5887855da413cd6d2", size = 653842, upload-time = "2026-01-23T15:32:55.671Z" }, { url = "https://files.pythonhosted.org/packages/97/43/8bf0ffa3d498eeee4c58c212a3905dd6146c01c8dc0b0a046481ca29b18c/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ed6b402bc74d6557a705e197d47f9063733091ed6357b3de33619d8a8d93ac53", size = 1614917, upload-time = "2026-01-23T16:04:26.276Z" }, { url = "https://files.pythonhosted.org/packages/89/90/a3be7a5f378fc6e84abe4dcfb2ba32b07786861172e502388b4c90000d1b/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:59913f1e5ada20fde795ba906916aea25d442abcc0593fba7e26c92b7ad76249", size = 1676092, upload-time = "2026-01-23T15:33:52.176Z" }, @@ -2404,6 +2452,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/60/22c9716033ced1ee1d800457126c4c79652a4ed635b0554c1d93742cc0a1/py_zerox-0.0.7-py3-none-any.whl", hash = "sha256:7b7d92cb6fafec91a94b63ba3c039b643fb3ee83545b15fa330ec07dd52f2058", size = 23347, upload-time = "2024-10-21T16:03:33.406Z" }, ] +[[package]] +name = "pyasn1" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = "2026-01-16T18:04:18.534Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pycparser" version = "3.0" @@ -3013,6 +3082,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" }, ] +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = 
"sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + [[package]] name = "ruff" version = "0.14.14" @@ -3639,47 +3720,33 @@ wheels = [ [[package]] name = "websockets" -version = "16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, - { url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, - { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, - { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, - { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, - { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, - { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, - { url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, - { url = 
"https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, - { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, - { url = "https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, - { url = "https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, - { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, - { url = "https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, - { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, - { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, - { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, - { url = "https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, - { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, - { url = 
"https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, - { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, - { url = "https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, - { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, - { url = "https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, - { url = "https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, - { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, - { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, - { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, - { url = "https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, - { url = 
"https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, - { url = "https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, upload-time = "2026-01-10T09:23:29.449Z" }, - { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, - { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, - { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, - { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, - { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, +version = "15.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/21/e6/26d09fab466b7ca9c7737474c52be4f76a40301b08362eb2dbc19dcc16c1/websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee", size = 177016, upload-time = "2025-03-05T20:03:41.606Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/6b/4545a0d843594f5d0771e86463606a3988b5a09ca5123136f8a76580dd63/websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3", size = 175437, upload-time = "2025-03-05T20:02:16.706Z" }, + { url = "https://files.pythonhosted.org/packages/f4/71/809a0f5f6a06522af902e0f2ea2757f71ead94610010cf570ab5c98e99ed/websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665", size = 173096, upload-time = "2025-03-05T20:02:18.832Z" }, + { url = "https://files.pythonhosted.org/packages/3d/69/1a681dd6f02180916f116894181eab8b2e25b31e484c5d0eae637ec01f7c/websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2", size = 173332, upload-time = "2025-03-05T20:02:20.187Z" 
}, + { url = "https://files.pythonhosted.org/packages/a6/02/0073b3952f5bce97eafbb35757f8d0d54812b6174ed8dd952aa08429bcc3/websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215", size = 183152, upload-time = "2025-03-05T20:02:22.286Z" }, + { url = "https://files.pythonhosted.org/packages/74/45/c205c8480eafd114b428284840da0b1be9ffd0e4f87338dc95dc6ff961a1/websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5", size = 182096, upload-time = "2025-03-05T20:02:24.368Z" }, + { url = "https://files.pythonhosted.org/packages/14/8f/aa61f528fba38578ec553c145857a181384c72b98156f858ca5c8e82d9d3/websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65", size = 182523, upload-time = "2025-03-05T20:02:25.669Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6d/0267396610add5bc0d0d3e77f546d4cd287200804fe02323797de77dbce9/websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe", size = 182790, upload-time = "2025-03-05T20:02:26.99Z" }, + { url = "https://files.pythonhosted.org/packages/02/05/c68c5adbf679cf610ae2f74a9b871ae84564462955d991178f95a1ddb7dd/websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4", size = 182165, upload-time = "2025-03-05T20:02:30.291Z" }, + { url = "https://files.pythonhosted.org/packages/29/93/bb672df7b2f5faac89761cb5fa34f5cec45a4026c383a4b5761c6cea5c16/websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597", size = 182160, upload-time = "2025-03-05T20:02:31.634Z" }, + { url = "https://files.pythonhosted.org/packages/ff/83/de1f7709376dc3ca9b7eeb4b9a07b4526b14876b6d372a4dc62312bebee0/websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9", size = 176395, upload-time = "2025-03-05T20:02:33.017Z" }, + { url = "https://files.pythonhosted.org/packages/7d/71/abf2ebc3bbfa40f391ce1428c7168fb20582d0ff57019b69ea20fa698043/websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7", size = 176841, upload-time = "2025-03-05T20:02:34.498Z" }, + { url = "https://files.pythonhosted.org/packages/cb/9f/51f0cf64471a9d2b4d0fc6c534f323b664e7095640c34562f5182e5a7195/websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931", size = 175440, upload-time = "2025-03-05T20:02:36.695Z" }, + { url = "https://files.pythonhosted.org/packages/8a/05/aa116ec9943c718905997412c5989f7ed671bc0188ee2ba89520e8765d7b/websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675", size = 173098, upload-time = "2025-03-05T20:02:37.985Z" }, + { url = "https://files.pythonhosted.org/packages/ff/0b/33cef55ff24f2d92924923c99926dcce78e7bd922d649467f0eda8368923/websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151", size = 173329, 
upload-time = "2025-03-05T20:02:39.298Z" }, + { url = "https://files.pythonhosted.org/packages/31/1d/063b25dcc01faa8fada1469bdf769de3768b7044eac9d41f734fd7b6ad6d/websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22", size = 183111, upload-time = "2025-03-05T20:02:40.595Z" }, + { url = "https://files.pythonhosted.org/packages/93/53/9a87ee494a51bf63e4ec9241c1ccc4f7c2f45fff85d5bde2ff74fcb68b9e/websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f", size = 182054, upload-time = "2025-03-05T20:02:41.926Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b2/83a6ddf56cdcbad4e3d841fcc55d6ba7d19aeb89c50f24dd7e859ec0805f/websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8", size = 182496, upload-time = "2025-03-05T20:02:43.304Z" }, + { url = "https://files.pythonhosted.org/packages/98/41/e7038944ed0abf34c45aa4635ba28136f06052e08fc2168520bb8b25149f/websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375", size = 182829, upload-time = "2025-03-05T20:02:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/e0/17/de15b6158680c7623c6ef0db361da965ab25d813ae54fcfeae2e5b9ef910/websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d", size = 182217, upload-time = "2025-03-05T20:02:50.14Z" }, + { url = "https://files.pythonhosted.org/packages/33/2b/1f168cb6041853eef0362fb9554c3824367c5560cbdaad89ac40f8c2edfc/websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4", size = 182195, upload-time = "2025-03-05T20:02:51.561Z" }, + { url = "https://files.pythonhosted.org/packages/86/eb/20b6cdf273913d0ad05a6a14aed4b9a85591c18a987a3d47f20fa13dcc47/websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa", size = 176393, upload-time = "2025-03-05T20:02:53.814Z" }, + { url = "https://files.pythonhosted.org/packages/1b/6c/c65773d6cab416a64d191d6ee8a8b1c68a09970ea6909d16965d26bfed1e/websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561", size = 176837, upload-time = "2025-03-05T20:02:55.237Z" }, + { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] [[package]]