From 011d489a2abd433e52701053830a40d1b4c87653 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Fri, 30 Jan 2026 14:15:41 +0530 Subject: [PATCH 01/36] first stab at STT evals --- .../versions/042_add_stt_evaluation_tables.py | 356 +++++++++ .../api/routes/stt_evaluations/__init__.py | 5 + .../app/api/routes/stt_evaluations/dataset.py | 204 +++++ .../api/routes/stt_evaluations/evaluation.py | 227 ++++++ .../app/api/routes/stt_evaluations/files.py | 50 ++ .../app/api/routes/stt_evaluations/result.py | 134 ++++ .../app/api/routes/stt_evaluations/router.py | 13 + backend/app/celery/tasks/stt_evaluation.py | 705 ++++++++++++++++++ backend/app/core/batch/gemini.py | 348 +++++++++ backend/app/crud/stt_evaluations/__init__.py | 42 ++ backend/app/crud/stt_evaluations/dataset.py | 341 +++++++++ backend/app/crud/stt_evaluations/result.py | 413 ++++++++++ backend/app/crud/stt_evaluations/run.py | 302 ++++++++ backend/app/models/stt_evaluation.py | 414 ++++++++++ .../app/services/stt_evaluations/__init__.py | 1 + backend/app/services/stt_evaluations/audio.py | 212 ++++++ .../app/services/stt_evaluations/constants.py | 32 + .../stt_evaluations/gemini/__init__.py | 6 + .../services/stt_evaluations/gemini/client.py | 113 +++ .../services/stt_evaluations/gemini/files.py | 271 +++++++ 20 files changed, 4189 insertions(+) create mode 100644 backend/app/alembic/versions/042_add_stt_evaluation_tables.py create mode 100644 backend/app/api/routes/stt_evaluations/__init__.py create mode 100644 backend/app/api/routes/stt_evaluations/dataset.py create mode 100644 backend/app/api/routes/stt_evaluations/evaluation.py create mode 100644 backend/app/api/routes/stt_evaluations/files.py create mode 100644 backend/app/api/routes/stt_evaluations/result.py create mode 100644 backend/app/api/routes/stt_evaluations/router.py create mode 100644 backend/app/celery/tasks/stt_evaluation.py create mode 100644 backend/app/core/batch/gemini.py create mode 100644 backend/app/crud/stt_evaluations/__init__.py create mode 100644 backend/app/crud/stt_evaluations/dataset.py create mode 100644 backend/app/crud/stt_evaluations/result.py create mode 100644 backend/app/crud/stt_evaluations/run.py create mode 100644 backend/app/models/stt_evaluation.py create mode 100644 backend/app/services/stt_evaluations/__init__.py create mode 100644 backend/app/services/stt_evaluations/audio.py create mode 100644 backend/app/services/stt_evaluations/constants.py create mode 100644 backend/app/services/stt_evaluations/gemini/__init__.py create mode 100644 backend/app/services/stt_evaluations/gemini/client.py create mode 100644 backend/app/services/stt_evaluations/gemini/files.py diff --git a/backend/app/alembic/versions/042_add_stt_evaluation_tables.py b/backend/app/alembic/versions/042_add_stt_evaluation_tables.py new file mode 100644 index 000000000..ff95b078e --- /dev/null +++ b/backend/app/alembic/versions/042_add_stt_evaluation_tables.py @@ -0,0 +1,356 @@ +"""add stt evaluation tables + +Revision ID: 042 +Revises: 041 +Create Date: 2026-01-28 12:00:00.000000 + +""" + +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "042" +down_revision = "041" +branch_labels = None +depends_on = None + + +def upgrade(): + # Add type and language columns to evaluation_dataset table + op.add_column( + "evaluation_dataset", + sa.Column( + "type", + sa.String(length=20), + nullable=False, + server_default="text", + comment="Evaluation type: text, stt, or tts", + ), + ) + op.add_column( + "evaluation_dataset", + sa.Column( + "language", + sa.String(length=10), + nullable=True, + comment="ISO 639-1 language code (e.g., en, hi)", + ), + ) + + # Add type, language, providers, and processed_samples columns to evaluation_run table + op.add_column( + "evaluation_run", + sa.Column( + "type", + sa.String(length=20), + nullable=False, + server_default="text", + comment="Evaluation type: text, stt, or tts", + ), + ) + op.add_column( + "evaluation_run", + sa.Column( + "language", + sa.String(length=10), + nullable=True, + comment="ISO 639-1 language code", + ), + ) + op.add_column( + "evaluation_run", + sa.Column( + "providers", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="List of STT/TTS providers used (e.g., ['gemini-2.5-pro'])", + ), + ) + op.add_column( + "evaluation_run", + sa.Column( + "processed_samples", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + comment="Number of samples processed so far", + ), + ) + + # Create stt_sample table + op.create_table( + "stt_sample", + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the STT sample", + ), + sa.Column( + "object_store_url", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + comment="S3 URL of the audio file", + ), + sa.Column( + "language", + sa.String(length=10), + nullable=True, + comment="ISO 639-1 language code for this sample", + ), + sa.Column( + "ground_truth", + sa.Text(), + nullable=True, + comment="Reference transcription for comparison (optional)", + ), + sa.Column( + "duration_seconds", + sa.Float(), + nullable=True, + comment="Audio duration in seconds", + ), + sa.Column( + "sample_metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + server_default=sa.text("'{}'::jsonb"), + comment="Additional metadata (format, bitrate, original filename, etc.)", + ), + sa.Column( + "dataset_id", + sa.Integer(), + nullable=False, + comment="Reference to the parent evaluation dataset", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the sample was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the sample was last updated", + ), + sa.ForeignKeyConstraint( + ["dataset_id"], + ["evaluation_dataset.id"], + name="fk_stt_sample_dataset_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_stt_sample_dataset_id", + "stt_sample", + ["dataset_id"], + unique=False, + ) + op.create_index( + "idx_stt_sample_org_project", + "stt_sample", + ["organization_id", "project_id"], + unique=False, + ) + + # Create stt_result table + op.create_table( + "stt_result", + sa.Column( + "id", + sa.Integer(), + nullable=False, + 
comment="Unique identifier for the STT result", + ), + sa.Column( + "transcription", + sa.Text(), + nullable=True, + comment="Generated transcription from STT provider", + ), + sa.Column( + "provider", + sa.String(length=50), + nullable=False, + comment="STT provider used (e.g., gemini-2.5-pro)", + ), + sa.Column( + "status", + sa.String(length=20), + nullable=False, + server_default="pending", + comment="Result status: pending, completed, failed", + ), + sa.Column( + "wer", + sa.Float(), + nullable=True, + comment="Word Error Rate (null for Phase 1)", + ), + sa.Column( + "cer", + sa.Float(), + nullable=True, + comment="Character Error Rate (null for Phase 1)", + ), + sa.Column( + "is_correct", + sa.Boolean(), + nullable=True, + comment="Human feedback: transcription correctness (null=not reviewed)", + ), + sa.Column( + "comment", + sa.Text(), + nullable=True, + comment="Human feedback comment", + ), + sa.Column( + "provider_metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + server_default=sa.text("'{}'::jsonb"), + comment="Provider-specific response metadata (tokens, latency, etc.)", + ), + sa.Column( + "error_message", + sa.Text(), + nullable=True, + comment="Error message if transcription failed", + ), + sa.Column( + "stt_sample_id", + sa.Integer(), + nullable=False, + comment="Reference to the STT sample", + ), + sa.Column( + "evaluation_run_id", + sa.Integer(), + nullable=False, + comment="Reference to the evaluation run", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the result was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the result was last updated", + ), + sa.ForeignKeyConstraint( + ["stt_sample_id"], + ["stt_sample.id"], + name="fk_stt_result_sample_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["evaluation_run_id"], + ["evaluation_run.id"], + name="fk_stt_result_run_id", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organization.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["project.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_stt_result_sample_id", + "stt_result", + ["stt_sample_id"], + unique=False, + ) + op.create_index( + "ix_stt_result_run_id", + "stt_result", + ["evaluation_run_id"], + unique=False, + ) + op.create_index( + "idx_stt_result_feedback", + "stt_result", + ["evaluation_run_id", "is_correct"], + unique=False, + ) + op.create_index( + "idx_stt_result_status", + "stt_result", + ["evaluation_run_id", "status"], + unique=False, + ) + + +def downgrade(): + # Drop stt_result table + op.drop_index("idx_stt_result_status", table_name="stt_result") + op.drop_index("idx_stt_result_feedback", table_name="stt_result") + op.drop_index("ix_stt_result_run_id", table_name="stt_result") + op.drop_index("ix_stt_result_sample_id", table_name="stt_result") + op.drop_table("stt_result") + + # Drop stt_sample table + op.drop_index("idx_stt_sample_org_project", table_name="stt_sample") + op.drop_index("ix_stt_sample_dataset_id", table_name="stt_sample") + op.drop_table("stt_sample") + + # Remove columns from evaluation_run table + op.drop_column("evaluation_run", "processed_samples") + 
op.drop_column("evaluation_run", "providers") + op.drop_column("evaluation_run", "language") + op.drop_column("evaluation_run", "type") + + # Remove columns from evaluation_dataset table + op.drop_column("evaluation_dataset", "language") + op.drop_column("evaluation_dataset", "type") diff --git a/backend/app/api/routes/stt_evaluations/__init__.py b/backend/app/api/routes/stt_evaluations/__init__.py new file mode 100644 index 000000000..fa71b5c0d --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/__init__.py @@ -0,0 +1,5 @@ +"""STT Evaluation API routes.""" + +from .router import router + +__all__ = ["router"] diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py new file mode 100644 index 000000000..c38b22f76 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -0,0 +1,204 @@ +"""STT dataset API routes.""" + +import logging +from typing import Any + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.stt_evaluations import ( + create_stt_dataset, + create_stt_samples, + get_stt_dataset_by_id, + list_stt_datasets, + get_samples_by_dataset_id, + get_sample_count_for_dataset, + update_dataset_metadata, +) +from app.models.stt_evaluation import ( + STTDatasetCreate, + STTDatasetPublic, + STTDatasetWithSamples, + STTSampleCreate, + STTSamplePublic, +) +from app.utils import APIResponse + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/datasets", + response_model=APIResponse[STTDatasetPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Create STT dataset", + description=""" +Create a new STT evaluation dataset with audio samples. 
+ +Each sample requires: +- **object_store_url**: S3 URL of the audio file (from /evaluations/stt/files/audio endpoint) +- **language**: ISO 639-1 language code (optional) +- **ground_truth**: Reference transcription (optional, for Phase 2 WER/CER) +""", +) +def create_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_create: STTDatasetCreate = Body(...), +) -> APIResponse[STTDatasetPublic]: + """Create an STT evaluation dataset.""" + logger.info( + f"[create_dataset] Creating STT dataset | " + f"name: {dataset_create.name}, sample_count: {len(dataset_create.samples)}" + ) + + # Create dataset + dataset = create_stt_dataset( + session=_session, + name=dataset_create.name, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + description=dataset_create.description, + language=dataset_create.language, + dataset_metadata={ + "sample_count": len(dataset_create.samples), + "has_ground_truth_count": sum( + 1 for s in dataset_create.samples if s.ground_truth + ), + }, + ) + + # Create samples + samples = create_stt_samples( + session=_session, + dataset_id=dataset.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + samples=dataset_create.samples, + ) + + return APIResponse.success_response( + data=STTDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language=dataset.language, + dataset_metadata=dataset.dataset_metadata, + sample_count=len(samples), + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ) + ) + + +@router.get( + "/datasets", + response_model=APIResponse[list[STTDatasetPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List STT datasets", + description="List all STT evaluation datasets for the current project.", +) +def list_datasets( + _session: SessionDep, + auth_context: AuthContextDep, + limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[STTDatasetPublic]]: + """List STT evaluation datasets.""" + datasets, total = list_stt_datasets( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=datasets, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/datasets/{dataset_id}", + response_model=APIResponse[STTDatasetWithSamples], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get STT dataset", + description="Get an STT dataset with its samples.", +) +def get_dataset( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int, + include_samples: bool = Query(True, description="Include samples in response"), + sample_limit: int = Query(100, ge=1, le=1000, description="Max samples to return"), + sample_offset: int = Query(0, ge=0, description="Sample offset"), +) -> APIResponse[STTDatasetWithSamples]: + """Get an STT evaluation dataset.""" + dataset = get_stt_dataset_by_id( + session=_session, + dataset_id=dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + samples = [] + samples_total = 0 + + if include_samples: + sample_records, samples_total 
= get_samples_by_dataset_id( + session=_session, + dataset_id=dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=sample_limit, + offset=sample_offset, + ) + + samples = [ + STTSamplePublic( + id=s.id, + object_store_url=s.object_store_url, + language=s.language, + ground_truth=s.ground_truth, + duration_seconds=s.duration_seconds, + sample_metadata=s.sample_metadata, + dataset_id=s.dataset_id, + organization_id=s.organization_id, + project_id=s.project_id, + inserted_at=s.inserted_at, + updated_at=s.updated_at, + ) + for s in sample_records + ] + else: + samples_total = get_sample_count_for_dataset( + session=_session, dataset_id=dataset_id + ) + + return APIResponse.success_response( + data=STTDatasetWithSamples( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language=dataset.language, + dataset_metadata=dataset.dataset_metadata, + sample_count=samples_total, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + samples=samples, + ), + metadata={"samples_total": samples_total}, + ) diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py new file mode 100644 index 000000000..1a976cab3 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -0,0 +1,227 @@ +"""STT evaluation run API routes.""" + +import logging + +from asgi_correlation_id import correlation_id +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.celery.tasks.stt_evaluation import process_stt_evaluation +from app.crud.stt_evaluations import ( + create_stt_run, + get_stt_dataset_by_id, + get_stt_run_by_id, + list_stt_runs, + get_sample_count_for_dataset, +) +from app.crud.stt_evaluations.result import get_results_by_run_id +from app.models.stt_evaluation import ( + STTEvaluationRunCreate, + STTEvaluationRunPublic, + STTEvaluationRunWithResults, +) +from app.utils import APIResponse + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/runs", + response_model=APIResponse[STTEvaluationRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Start STT evaluation", + description=""" +Start an STT evaluation run on a dataset. + +The evaluation will: +1. Process each audio sample through the specified providers +2. Generate transcriptions using Gemini Batch API +3. 
Store results for human review + +**Supported providers:** gemini-2.5-pro, gemini-2.5-flash, gemini-2.0-flash +""", +) +def start_stt_evaluation( + _session: SessionDep, + auth_context: AuthContextDep, + run_create: STTEvaluationRunCreate = Body(...), +) -> APIResponse[STTEvaluationRunPublic]: + """Start an STT evaluation run.""" + logger.info( + f"[start_stt_evaluation] Starting STT evaluation | " + f"run_name: {run_create.run_name}, dataset_id: {run_create.dataset_id}, " + f"providers: {run_create.providers}" + ) + + # Validate dataset exists + dataset = get_stt_dataset_by_id( + session=_session, + dataset_id=run_create.dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not dataset: + raise HTTPException(status_code=404, detail="Dataset not found") + + # Get sample count for total items + sample_count = get_sample_count_for_dataset( + session=_session, dataset_id=run_create.dataset_id + ) + + if sample_count == 0: + raise HTTPException(status_code=400, detail="Dataset has no samples") + + # Create run record + run = create_stt_run( + session=_session, + run_name=run_create.run_name, + dataset_id=run_create.dataset_id, + dataset_name=dataset.name, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + providers=run_create.providers, + language=run_create.language or dataset.language, + total_items=sample_count * len(run_create.providers), + ) + + # Enqueue Celery task + trace_id = correlation_id.get() or "" + + process_stt_evaluation.apply_async( + kwargs={ + "evaluation_run_id": run.id, + "org_id": auth_context.organization_.id, + "project_id": auth_context.project_.id, + "trace_id": trace_id, + }, + ) + + logger.info( + f"[start_stt_evaluation] STT evaluation queued | " + f"run_id: {run.id}, task queued" + ) + + return APIResponse.success_response( + data=STTEvaluationRunPublic( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language=run.language, + providers=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + processed_samples=run.processed_samples, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + ) + ) + + +@router.get( + "/runs", + response_model=APIResponse[list[STTEvaluationRunPublic]], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="List STT evaluation runs", + description="List all STT evaluation runs for the current project.", +) +def list_stt_evaluation_runs( + _session: SessionDep, + auth_context: AuthContextDep, + dataset_id: int | None = Query(None, description="Filter by dataset ID"), + status: str | None = Query(None, description="Filter by status"), + limit: int = Query(50, ge=1, le=100, description="Maximum results to return"), + offset: int = Query(0, ge=0, description="Number of results to skip"), +) -> APIResponse[list[STTEvaluationRunPublic]]: + """List STT evaluation runs.""" + runs, total = list_stt_runs( + session=_session, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + dataset_id=dataset_id, + status=status, + limit=limit, + offset=offset, + ) + + return APIResponse.success_response( + data=runs, + metadata={"total": total, "limit": limit, "offset": offset}, + ) + + +@router.get( + "/runs/{run_id}", + response_model=APIResponse[STTEvaluationRunWithResults], + 
dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get STT evaluation run", + description="Get an STT evaluation run with its results.", +) +def get_stt_evaluation_run( + _session: SessionDep, + auth_context: AuthContextDep, + run_id: int, + include_results: bool = Query(True, description="Include results in response"), + result_limit: int = Query(100, ge=1, le=1000, description="Max results to return"), + result_offset: int = Query(0, ge=0, description="Result offset"), + provider: str | None = Query(None, description="Filter results by provider"), + status: str | None = Query(None, description="Filter results by status"), +) -> APIResponse[STTEvaluationRunWithResults]: + """Get an STT evaluation run with results.""" + run = get_stt_run_by_id( + session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not run: + raise HTTPException(status_code=404, detail="Evaluation run not found") + + results = [] + results_total = 0 + + if include_results: + results, results_total = get_results_by_run_id( + session=_session, + run_id=run_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + provider=provider, + status=status, + limit=result_limit, + offset=result_offset, + ) + + return APIResponse.success_response( + data=STTEvaluationRunWithResults( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language=run.language, + providers=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + processed_samples=run.processed_samples, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + results=results, + results_total=results_total, + ), + metadata={"results_total": results_total}, + ) diff --git a/backend/app/api/routes/stt_evaluations/files.py b/backend/app/api/routes/stt_evaluations/files.py new file mode 100644 index 000000000..08decb22f --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/files.py @@ -0,0 +1,50 @@ +"""Audio file upload API routes for STT evaluation.""" + +import logging + +from fastapi import APIRouter, Depends, File, UploadFile + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.models.stt_evaluation import AudioUploadResponse +from app.services.stt_evaluations.audio import upload_audio_file +from app.utils import APIResponse + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/files/audio", + response_model=APIResponse[AudioUploadResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Upload audio file", + description=""" +Upload a single audio file to S3 for STT evaluation. + +**Supported formats:** mp3, wav, flac, m4a, ogg, webm + +**Maximum file size:** 200 MB + +Returns the S3 URL which can be used when creating an STT dataset. 
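For illustration, a client could chain this endpoint with the dataset endpoint
roughly as follows. The base URL, auth header, and the exact envelope of the
upload response are assumptions for this sketch, not part of the API contract:

```python
import requests

BASE_URL = "https://api.example.com/api/v1"  # placeholder
HEADERS = {"X-API-KEY": "<project-api-key>"}  # placeholder auth

# 1. Upload the audio file and capture the returned S3 URL.
with open("sample.wav", "rb") as f:
    upload = requests.post(
        f"{BASE_URL}/evaluations/stt/files/audio",
        headers=HEADERS,
        files={"file": ("sample.wav", f, "audio/wav")},
    )
object_store_url = upload.json()["data"]["object_store_url"]

# 2. Reference that URL in a sample when creating an STT dataset.
requests.post(
    f"{BASE_URL}/evaluations/stt/datasets",
    headers=HEADERS,
    json={
        "name": "demo-stt-dataset",
        "language": "en",
        "samples": [{"object_store_url": object_store_url}],
    },
)
```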
+""", +) +def upload_audio( + _session: SessionDep, + auth_context: AuthContextDep, + file: UploadFile = File(..., description="Audio file to upload"), +) -> APIResponse[AudioUploadResponse]: + """Upload an audio file for STT evaluation.""" + logger.info( + f"[upload_audio] Uploading audio file | " + f"project_id: {auth_context.project_.id}, filename: {file.filename}" + ) + + result = upload_audio_file( + session=_session, + file=file, + project_id=auth_context.project_.id, + ) + + return APIResponse.success_response(data=result) diff --git a/backend/app/api/routes/stt_evaluations/result.py b/backend/app/api/routes/stt_evaluations/result.py new file mode 100644 index 000000000..3fd994e25 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/result.py @@ -0,0 +1,134 @@ +"""STT result feedback API routes.""" + +import logging + +from fastapi import APIRouter, Body, Depends, HTTPException + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.stt_evaluations import ( + get_stt_result_by_id, + update_human_feedback, +) +from app.models.stt_evaluation import ( + STTFeedbackUpdate, + STTResultPublic, +) +from app.utils import APIResponse + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.patch( + "/results/{result_id}", + response_model=APIResponse[STTResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Update human feedback", + description=""" +Update human feedback on an STT transcription result. + +**Fields:** +- **is_correct**: Boolean indicating if the transcription is correct +- **comment**: Optional feedback comment explaining issues or observations +""", +) +def update_result_feedback( + _session: SessionDep, + auth_context: AuthContextDep, + result_id: int, + feedback: STTFeedbackUpdate = Body(...), +) -> APIResponse[STTResultPublic]: + """Update human feedback on an STT result.""" + logger.info( + f"[update_result_feedback] Updating feedback | " + f"result_id: {result_id}, is_correct: {feedback.is_correct}" + ) + + # Verify result exists and belongs to this project + existing = get_stt_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not existing: + raise HTTPException(status_code=404, detail="Result not found") + + # Update feedback + result = update_human_feedback( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + is_correct=feedback.is_correct, + comment=feedback.comment, + ) + + return APIResponse.success_response( + data=STTResultPublic( + id=result.id, + transcription=result.transcription, + provider=result.provider, + status=result.status, + wer=result.wer, + cer=result.cer, + is_correct=result.is_correct, + comment=result.comment, + provider_metadata=result.provider_metadata, + error_message=result.error_message, + stt_sample_id=result.stt_sample_id, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + ) + ) + + +@router.get( + "/results/{result_id}", + response_model=APIResponse[STTResultPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], + summary="Get STT result", + description="Get a single STT transcription result.", +) +def get_result( + _session: SessionDep, + auth_context: 
AuthContextDep, + result_id: int, +) -> APIResponse[STTResultPublic]: + """Get an STT result by ID.""" + result = get_stt_result_by_id( + session=_session, + result_id=result_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + if not result: + raise HTTPException(status_code=404, detail="Result not found") + + return APIResponse.success_response( + data=STTResultPublic( + id=result.id, + transcription=result.transcription, + provider=result.provider, + status=result.status, + wer=result.wer, + cer=result.cer, + is_correct=result.is_correct, + comment=result.comment, + provider_metadata=result.provider_metadata, + error_message=result.error_message, + stt_sample_id=result.stt_sample_id, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + ) + ) diff --git a/backend/app/api/routes/stt_evaluations/router.py b/backend/app/api/routes/stt_evaluations/router.py new file mode 100644 index 000000000..07a958f17 --- /dev/null +++ b/backend/app/api/routes/stt_evaluations/router.py @@ -0,0 +1,13 @@ +"""Main router for STT evaluation API routes.""" + +from fastapi import APIRouter + +from . import dataset, evaluation, files, result + +router = APIRouter(prefix="/evaluations/stt", tags=["STT Evaluation"]) + +# Include all sub-routers +router.include_router(files.router) +router.include_router(dataset.router) +router.include_router(evaluation.router) +router.include_router(result.router) diff --git a/backend/app/celery/tasks/stt_evaluation.py b/backend/app/celery/tasks/stt_evaluation.py new file mode 100644 index 000000000..585431838 --- /dev/null +++ b/backend/app/celery/tasks/stt_evaluation.py @@ -0,0 +1,705 @@ +"""Celery tasks for STT evaluation processing.""" + +import logging +from typing import Any + +from asgi_correlation_id import correlation_id +from celery import current_task +from sqlmodel import Session, select + +from app.celery.celery_app import celery_app +from app.core.batch import GeminiBatchProvider +from app.core.batch.gemini import BatchJobState +from app.core.db import engine +from app.core.cloud.storage import get_cloud_storage +from app.crud.stt_evaluations import ( + get_stt_run_by_id, + get_samples_by_dataset_id, + get_stt_dataset_by_id, + update_stt_run, + create_stt_results, + update_stt_result, + count_results_by_status, +) +from app.models.stt_evaluation import STTResult, STTResultStatus +from app.services.stt_evaluations.gemini import ( + GeminiClient, + GeminiFilesManager, +) + +logger = logging.getLogger(__name__) + +# Default transcription prompt +DEFAULT_TRANSCRIPTION_PROMPT = ( + "Generate a verbatim transcript of the speech in this audio file. " + "Return only the transcription text without any formatting, timestamps, or metadata." +) + +# Provider name to Gemini model mapping +PROVIDER_MODEL_MAPPING: dict[str, str] = { + "gemini-2.5-pro": "models/gemini-2.5-pro", + "gemini-2.5-flash": "models/gemini-2.5-flash", + "gemini-2.0-flash": "models/gemini-2.0-flash", +} + +# Maximum number of polls (24 hours with 30s intervals) +MAX_POLL_COUNT = 2880 + + +def _get_model_for_provider(provider: str) -> str: + """Map provider name to Gemini model. 
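
    Unknown provider names fall back to "models/<provider>", so a newly
    released Gemini model can be used without first adding it to
    PROVIDER_MODEL_MAPPING.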
+ + Args: + provider: Provider name + + Returns: + str: Gemini model name + """ + return PROVIDER_MODEL_MAPPING.get(provider, f"models/{provider}") + + +def _build_batch_requests( + sample_file_mapping: list[tuple[int, int | None, str]], + prompt: str = DEFAULT_TRANSCRIPTION_PROMPT, +) -> list[dict[str, Any]]: + """Build JSONL batch request data from sample-file mappings. + + Each request follows the Gemini GenerateContentRequest format + with a text prompt and file_data reference. + + Args: + sample_file_mapping: List of (sample_id, result_id, google_file_uri) tuples + prompt: Transcription prompt + + Returns: + list[dict]: JSONL-compatible request dicts for GeminiBatchProvider + """ + return [ + { + "contents": [ + { + "parts": [ + {"text": prompt}, + {"file_data": {"file_uri": file_uri}}, + ], + "role": "user", + } + ], + } + for _, _, file_uri in sample_file_mapping + ] + + +@celery_app.task(bind=True, queue="low_priority") +def process_stt_evaluation( + self, + evaluation_run_id: int, + org_id: int, + project_id: int, + trace_id: str, + **kwargs, +) -> dict[str, Any]: + """Process an STT evaluation run. + + This task: + 1. Fetches the evaluation run and dataset samples + 2. Uploads audio files to Google Files API + 3. Creates batch requests via GeminiBatchProvider + 4. Submits batch jobs per provider + 5. Schedules polling task + + Args: + evaluation_run_id: ID of the evaluation run + org_id: Organization ID + project_id: Project ID + trace_id: Correlation/trace ID + + Returns: + dict: Result with batch job information + """ + task_id = current_task.request.id + correlation_id.set(trace_id) + + logger.info( + f"[process_stt_evaluation] Starting STT evaluation | " + f"run_id: {evaluation_run_id}, task_id: {task_id}" + ) + + try: + with Session(engine) as session: + # Get the evaluation run + run = get_stt_run_by_id( + session=session, + run_id=evaluation_run_id, + org_id=org_id, + project_id=project_id, + ) + + if not run: + logger.error( + f"[process_stt_evaluation] Run not found | run_id: {evaluation_run_id}" + ) + return {"success": False, "error": "Evaluation run not found"} + + # Get the dataset + dataset = get_stt_dataset_by_id( + session=session, + dataset_id=run.dataset_id, + org_id=org_id, + project_id=project_id, + ) + + if not dataset: + update_stt_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message="Dataset not found", + ) + return {"success": False, "error": "Dataset not found"} + + # Get all samples + samples, total = get_samples_by_dataset_id( + session=session, + dataset_id=run.dataset_id, + org_id=org_id, + project_id=project_id, + limit=10000, # Get all samples + ) + + if not samples: + update_stt_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message="No samples in dataset", + ) + return {"success": False, "error": "No samples in dataset"} + + # Update run with total items + providers = run.providers or ["gemini-2.5-pro"] + total_items = len(samples) * len(providers) + + update_stt_run( + session=session, + run_id=evaluation_run_id, + status="processing", + ) + + # Create result records for each sample and provider + results = create_stt_results( + session=session, + samples=samples, + evaluation_run_id=evaluation_run_id, + org_id=org_id, + project_id=project_id, + providers=providers, + ) + + # Extract result data before session closes to avoid DetachedInstanceError + result_refs = [ + {"id": r.id, "stt_sample_id": r.stt_sample_id, "provider": r.provider} + for r in results + ] + + # Update total 
items + run.total_items = total_items + session.add(run) + session.commit() + + # Initialize Gemini client + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=project_id, + ) + + # Get cloud storage for S3 access + storage = get_cloud_storage(session=session, project_id=project_id) + + # Upload audio files to Google Files API + files_manager = GeminiFilesManager(gemini_client.client) + + sample_file_mapping = [] # [(sample_id, result_id, google_file_uri)] + + for sample in samples: + try: + # Get signed URL for S3 audio file + signed_url = storage.get_signed_url( + sample.object_store_url, expires_in=3600 + ) + + # Extract filename from URL + filename = sample.object_store_url.split("/")[-1] + + # Upload to Google Files API + google_file_uri = files_manager.upload_from_url( + signed_url=signed_url, + filename=filename, + ) + + # Find the result record for this sample + result_for_sample = next( + (r for r in result_refs if r["stt_sample_id"] == sample.id), + None, + ) + + sample_file_mapping.append( + ( + sample.id, + result_for_sample["id"] if result_for_sample else None, + google_file_uri, + ) + ) + + logger.info( + f"[process_stt_evaluation] Uploaded audio to Google | " + f"sample_id: {sample.id}, file_uri: {google_file_uri}" + ) + + except Exception as e: + logger.error( + f"[process_stt_evaluation] Failed to upload audio | " + f"sample_id: {sample.id}, error: {str(e)}" + ) + # Mark result as failed + for ref in result_refs: + if ref["stt_sample_id"] == sample.id: + update_stt_result( + session=session, + result_id=ref["id"], + status=STTResultStatus.FAILED.value, + error_message=f"Failed to upload audio: {str(e)}", + ) + + # Build batch requests from uploaded files + jsonl_data = _build_batch_requests(sample_file_mapping) + + # Process each provider using GeminiBatchProvider + batch_jobs: dict[str, str] = {} + + for provider in providers: + try: + model = _get_model_for_provider(provider) + batch_provider = GeminiBatchProvider( + client=gemini_client.client, model=model + ) + + batch_result = batch_provider.create_batch( + jsonl_data=jsonl_data, + config={ + "display_name": f"stt-eval-{evaluation_run_id}-{provider}", + "model": model, + }, + ) + + batch_jobs[provider] = batch_result["provider_batch_id"] + + logger.info( + f"[process_stt_evaluation] Batch job submitted | " + f"run_id: {evaluation_run_id}, provider: {provider}, " + f"batch_id: {batch_result['provider_batch_id']}" + ) + + except Exception as e: + logger.error( + f"[process_stt_evaluation] Failed to submit batch | " + f"provider: {provider}, error: {str(e)}" + ) + # Update results for this provider as failed + with Session(engine) as session: + for ref in result_refs: + if ref["provider"] == provider: + update_stt_result( + session=session, + result_id=ref["id"], + status=STTResultStatus.FAILED.value, + error_message=f"Batch submission failed: {str(e)}", + ) + + if not batch_jobs: + with Session(engine) as session: + update_stt_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message="All batch submissions failed", + ) + return {"success": False, "error": "All batch submissions failed"} + + # Store batch job info in run score for polling + with Session(engine) as session: + update_stt_run( + session=session, + run_id=evaluation_run_id, + score={ + "batch_jobs": batch_jobs, + "sample_file_mapping": [ + {"sample_id": s, "result_id": r, "file_uri": f} + for s, r, f in sample_file_mapping + ], + }, + ) + + # Schedule polling task + 
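        # poll_stt_batch_status re-enqueues itself every 30 seconds until all
        # Gemini batch jobs reach a terminal state, giving up after
        # MAX_POLL_COUNT polls (~24 hours).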
poll_stt_batch_status.apply_async( + kwargs={ + "evaluation_run_id": evaluation_run_id, + "org_id": org_id, + "project_id": project_id, + "trace_id": trace_id, + "batch_jobs": batch_jobs, + }, + countdown=30, # Wait 30 seconds before first poll + ) + + return { + "success": True, + "run_id": evaluation_run_id, + "batch_jobs": batch_jobs, + "sample_count": len(samples), + } + + except Exception as e: + logger.error( + f"[process_stt_evaluation] Failed to process evaluation | " + f"run_id: {evaluation_run_id}, error: {str(e)}", + exc_info=True, + ) + + with Session(engine) as session: + update_stt_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message=str(e), + ) + + return {"success": False, "error": str(e)} + + +@celery_app.task(bind=True, queue="low_priority") +def poll_stt_batch_status( + self, + evaluation_run_id: int, + org_id: int, + project_id: int, + trace_id: str, + batch_jobs: dict[str, str], + poll_count: int = 0, + **kwargs, +) -> dict[str, Any]: + """Poll Gemini batch job status using GeminiBatchProvider. + + Args: + evaluation_run_id: ID of the evaluation run + org_id: Organization ID + project_id: Project ID + trace_id: Correlation/trace ID + batch_jobs: Dict of provider -> batch_id + poll_count: Number of times we've polled + + Returns: + dict: Status information + """ + correlation_id.set(trace_id) + + if poll_count >= MAX_POLL_COUNT: + logger.error( + f"[poll_stt_batch_status] Polling timed out | " + f"run_id: {evaluation_run_id}" + ) + with Session(engine) as session: + update_stt_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message="Batch processing timed out after 24 hours", + ) + return {"success": False, "error": "Timeout"} + + logger.info( + f"[poll_stt_batch_status] Polling batch status | " + f"run_id: {evaluation_run_id}, poll_count: {poll_count}" + ) + + try: + with Session(engine) as session: + # Initialize Gemini client and batch provider + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=project_id, + ) + batch_provider = GeminiBatchProvider(client=gemini_client.client) + + all_complete = True + any_success = False + + for provider, batch_id in batch_jobs.items(): + status = batch_provider.get_batch_status(batch_id) + provider_status = status["provider_status"] + + logger.info( + f"[poll_stt_batch_status] Batch status | " + f"provider: {provider}, batch_id: {batch_id}, " + f"state: {provider_status}" + ) + + is_terminal = provider_status in { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, + } + + if not is_terminal: + all_complete = False + elif provider_status == BatchJobState.SUCCEEDED.value: + any_success = True + + if not all_complete: + # Re-schedule polling + poll_stt_batch_status.apply_async( + kwargs={ + "evaluation_run_id": evaluation_run_id, + "org_id": org_id, + "project_id": project_id, + "trace_id": trace_id, + "batch_jobs": batch_jobs, + "poll_count": poll_count + 1, + }, + countdown=30, + ) + return {"success": True, "status": "polling", "poll_count": poll_count} + + # All batches complete - process results + if any_success: + process_stt_batch_results.apply_async( + kwargs={ + "evaluation_run_id": evaluation_run_id, + "org_id": org_id, + "project_id": project_id, + "trace_id": trace_id, + "batch_jobs": batch_jobs, + }, + ) + return {"success": True, "status": "processing_results"} + else: + update_stt_run( + session=session, + run_id=evaluation_run_id, + 
status="failed", + error_message="All batch jobs failed", + ) + return {"success": False, "error": "All batch jobs failed"} + + except Exception as e: + logger.error( + f"[poll_stt_batch_status] Polling failed | " + f"run_id: {evaluation_run_id}, error: {str(e)}", + exc_info=True, + ) + + # Re-schedule polling (might be temporary error) + if poll_count < MAX_POLL_COUNT: + poll_stt_batch_status.apply_async( + kwargs={ + "evaluation_run_id": evaluation_run_id, + "org_id": org_id, + "project_id": project_id, + "trace_id": trace_id, + "batch_jobs": batch_jobs, + "poll_count": poll_count + 1, + }, + countdown=60, # Wait longer on error + ) + + return {"success": False, "error": str(e)} + + +@celery_app.task(bind=True, queue="low_priority") +def process_stt_batch_results( + self, + evaluation_run_id: int, + org_id: int, + project_id: int, + trace_id: str, + batch_jobs: dict[str, str], + **kwargs, +) -> dict[str, Any]: + """Process results from completed Gemini batch jobs using GeminiBatchProvider. + + Args: + evaluation_run_id: ID of the evaluation run + org_id: Organization ID + project_id: Project ID + trace_id: Correlation/trace ID + batch_jobs: Dict of provider -> batch_id + + Returns: + dict: Processing result + """ + correlation_id.set(trace_id) + + logger.info( + f"[process_stt_batch_results] Processing batch results | " + f"run_id: {evaluation_run_id}" + ) + + try: + with Session(engine) as session: + # Get the run to access sample mapping + run = get_stt_run_by_id( + session=session, + run_id=evaluation_run_id, + org_id=org_id, + project_id=project_id, + ) + + if not run or not run.score: + logger.error( + f"[process_stt_batch_results] Run or score not found | " + f"run_id: {evaluation_run_id}" + ) + return {"success": False, "error": "Run data not found"} + + sample_file_mapping = run.score.get("sample_file_mapping", []) + sample_ids = [item["sample_id"] for item in sample_file_mapping] + + # Initialize Gemini client and providers + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=project_id, + ) + batch_provider = GeminiBatchProvider(client=gemini_client.client) + files_manager = GeminiFilesManager(gemini_client.client) + + processed_count = 0 + failed_count = 0 + + for provider, batch_id in batch_jobs.items(): + try: + # Use GeminiBatchProvider to download results + results = batch_provider.download_batch_results(batch_id) + + logger.info( + f"[process_stt_batch_results] Got batch results | " + f"provider: {provider}, result_count: {len(results)}" + ) + + # Match results to samples by index + for batch_result in results: + custom_id = batch_result["custom_id"] + # custom_id is the index as string + try: + index = int(custom_id) + except (ValueError, TypeError): + index = results.index(batch_result) + + if index >= len(sample_ids): + continue + + sample_id = sample_ids[index] + + # Find result record for this sample and provider + stmt = select(STTResult).where( + STTResult.evaluation_run_id == evaluation_run_id, + STTResult.stt_sample_id == sample_id, + STTResult.provider == provider, + ) + result_record = session.exec(stmt).one_or_none() + + if result_record: + if batch_result.get("response"): + text = batch_result["response"].get("text", "") + update_stt_result( + session=session, + result_id=result_record.id, + transcription=text, + status=STTResultStatus.COMPLETED.value, + ) + processed_count += 1 + else: + update_stt_result( + session=session, + result_id=result_record.id, + status=STTResultStatus.FAILED.value, + 
error_message=batch_result.get( + "error", "Unknown error" + ), + ) + failed_count += 1 + + except Exception as e: + logger.error( + f"[process_stt_batch_results] Failed to process provider results | " + f"provider: {provider}, error: {str(e)}" + ) + failed_count += len(sample_file_mapping) + + # Clean up Google Files + for item in sample_file_mapping: + file_uri = item.get("file_uri") + if file_uri: + try: + files_manager.delete_file(file_uri) + except Exception as e: + logger.warning( + f"[process_stt_batch_results] Failed to delete Google file | " + f"file_uri: {file_uri}, error: {str(e)}" + ) + + # Update run status + status_counts = count_results_by_status( + session=session, run_id=evaluation_run_id + ) + + completed = status_counts.get(STTResultStatus.COMPLETED.value, 0) + failed = status_counts.get(STTResultStatus.FAILED.value, 0) + pending = status_counts.get(STTResultStatus.PENDING.value, 0) + + final_status = "completed" if pending == 0 else "processing" + error_message = None + if failed > 0: + error_message = f"{failed} transcription(s) failed" + + update_stt_run( + session=session, + run_id=evaluation_run_id, + status=final_status, + processed_samples=completed + failed, + error_message=error_message, + ) + + logger.info( + f"[process_stt_batch_results] Batch results processed | " + f"run_id: {evaluation_run_id}, completed: {completed}, " + f"failed: {failed}, status: {final_status}" + ) + + return { + "success": True, + "run_id": evaluation_run_id, + "completed": completed, + "failed": failed, + "status": final_status, + } + + except Exception as e: + logger.error( + f"[process_stt_batch_results] Failed to process results | " + f"run_id: {evaluation_run_id}, error: {str(e)}", + exc_info=True, + ) + + with Session(engine) as session: + update_stt_run( + session=session, + run_id=evaluation_run_id, + status="failed", + error_message=f"Result processing failed: {str(e)}", + ) + + return {"success": False, "error": str(e)} diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py new file mode 100644 index 000000000..45d751c87 --- /dev/null +++ b/backend/app/core/batch/gemini.py @@ -0,0 +1,348 @@ +"""Gemini batch provider implementation.""" + +import json +import logging +import os +import tempfile +import time +from enum import Enum +from typing import Any + +from google import genai +from google.genai import types + +from .base import BatchProvider + +logger = logging.getLogger(__name__) + + +class BatchJobState(str, Enum): + """Gemini batch job states.""" + + PENDING = "JOB_STATE_PENDING" + RUNNING = "JOB_STATE_RUNNING" + SUCCEEDED = "JOB_STATE_SUCCEEDED" + FAILED = "JOB_STATE_FAILED" + CANCELLED = "JOB_STATE_CANCELLED" + EXPIRED = "JOB_STATE_EXPIRED" + + +# Terminal states that indicate the batch is done +_TERMINAL_STATES = { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + +# Failed terminal states +_FAILED_STATES = { + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + + +class GeminiBatchProvider(BatchProvider): + """Gemini implementation of the BatchProvider interface. + + Supports both inline requests and JSONL file-based batch submissions. + Each JSONL line follows the Gemini format: + {"key": "request-1", "request": {"contents": [{"parts": [...]}]}} + """ + + DEFAULT_MODEL = "models/gemini-2.5-pro" + + def __init__(self, client: genai.Client, model: str | None = None): + """Initialize the Gemini batch provider. 
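
    Expects an already-configured google.genai Client; the model set here is
    only a default, and create_batch accepts a per-call override via
    config["model"].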
+ + Args: + client: Configured Gemini client + model: Model to use (defaults to gemini-2.5-pro) + """ + self._client = client + self._model = model or self.DEFAULT_MODEL + + def create_batch( + self, jsonl_data: list[dict[str, Any]], config: dict[str, Any] + ) -> dict[str, Any]: + """Upload JSONL data and create a batch job with Gemini. + + Args: + jsonl_data: List of dictionaries representing JSONL lines. + Each dict should be a valid GenerateContentRequest, e.g.: + {"contents": [{"parts": [{"text": "..."}]}]} + config: Provider-specific configuration with: + - display_name: Optional batch display name + - model: Optional model override + + Returns: + Dictionary containing: + - provider_batch_id: Gemini batch job name + - provider_file_id: Uploaded JSONL file name (or None for inline) + - provider_status: Initial status from Gemini + - total_items: Number of items in the batch + """ + model = config.get("model", self._model) + display_name = config.get("display_name", f"batch-{int(time.time())}") + + logger.info( + f"[create_batch] Creating Gemini batch | items={len(jsonl_data)} | " + f"model={model} | display_name={display_name}" + ) + + try: + # Use inline requests for the batch + batch_job = self._client.batches.create( + model=model, + src=jsonl_data, + config={"display_name": display_name}, + ) + + initial_state = batch_job.state.name if batch_job.state else "UNKNOWN" + + result = { + "provider_batch_id": batch_job.name, + "provider_file_id": None, + "provider_status": initial_state, + "total_items": len(jsonl_data), + } + + logger.info( + f"[create_batch] Created Gemini batch | batch_id={batch_job.name} | " + f"status={initial_state} | items={len(jsonl_data)}" + ) + + return result + + except Exception as e: + logger.error(f"[create_batch] Failed to create Gemini batch | {e}") + raise + + def get_batch_status(self, batch_id: str) -> dict[str, Any]: + """Poll Gemini for batch job status. + + Args: + batch_id: Gemini batch job name + + Returns: + Dictionary containing: + - provider_status: Current Gemini state + - provider_output_file_id: batch_id (used to fetch results) + - error_message: Error message (if failed) + """ + logger.info( + f"[get_batch_status] Polling Gemini batch status | batch_id={batch_id}" + ) + + try: + batch_job = self._client.batches.get(name=batch_id) + state = batch_job.state.name if batch_job.state else "UNKNOWN" + + result: dict[str, Any] = { + "provider_status": state, + # Gemini uses the same batch name to fetch results + "provider_output_file_id": batch_id, + } + + if state in _FAILED_STATES: + result["error_message"] = f"Batch {state}" + + logger.info( + f"[get_batch_status] Gemini batch status | batch_id={batch_id} | " + f"status={state}" + ) + + return result + + except Exception as e: + logger.error( + f"[get_batch_status] Failed to poll Gemini batch status | " + f"batch_id={batch_id} | {e}" + ) + raise + + def download_batch_results(self, output_file_id: str) -> list[dict[str, Any]]: + """Download and parse batch results from Gemini. + + Gemini returns results either as inlined responses or as a + downloadable JSONL file. This method handles both formats and + normalizes the output to match the BatchProvider interface. 
+ + Args: + output_file_id: Gemini batch job name (used to fetch the batch) + + Returns: + List of result dictionaries, each containing: + - custom_id: Item key from input (or index as string) + - response: Dict with "text" key containing the generated text + - error: Error info (if item failed), None otherwise + """ + logger.info( + f"[download_batch_results] Downloading Gemini batch results | " + f"batch_id={output_file_id}" + ) + + try: + batch_job = self._client.batches.get(name=output_file_id) + state = batch_job.state.name if batch_job.state else "UNKNOWN" + + if state != BatchJobState.SUCCEEDED.value: + raise ValueError(f"Batch job not complete. Current state: {state}") + + results: list[dict[str, Any]] = [] + + # Handle inline responses + if batch_job.dest and batch_job.dest.inlined_responses: + for i, response in enumerate(batch_job.dest.inlined_responses): + if response.response: + text = self._extract_text_from_response(response.response) + results.append( + { + "custom_id": str(i), + "response": {"text": text}, + "error": None, + } + ) + elif response.error: + results.append( + { + "custom_id": str(i), + "response": None, + "error": str(response.error), + } + ) + + # Handle file-based results + elif ( + batch_job.dest + and hasattr(batch_job.dest, "file_name") + and batch_job.dest.file_name + ): + file_content = self.download_file(batch_job.dest.file_name) + lines = file_content.strip().split("\n") + for i, line in enumerate(lines): + try: + parsed = json.loads(line) + text = parsed.get("response", {}).get("text", "") + custom_id = parsed.get("key", str(i)) + results.append( + { + "custom_id": custom_id, + "response": {"text": text}, + "error": None, + } + ) + except json.JSONDecodeError as e: + logger.error( + f"[download_batch_results] Failed to parse JSON | " + f"line={i + 1} | {e}" + ) + continue + + logger.info( + f"[download_batch_results] Downloaded Gemini batch results | " + f"batch_id={output_file_id} | results={len(results)}" + ) + + return results + + except Exception as e: + logger.error( + f"[download_batch_results] Failed to download Gemini batch results | " + f"batch_id={output_file_id} | {e}" + ) + raise + + def upload_file(self, content: str, purpose: str = "batch") -> str: + """Upload a JSONL file to Gemini Files API. + + Args: + content: File content (JSONL string) + purpose: Purpose of the file (unused for Gemini, kept for interface) + + Returns: + Gemini file name (e.g., "files/xxx") + """ + logger.info(f"[upload_file] Uploading file to Gemini | bytes={len(content)}") + + try: + with tempfile.NamedTemporaryFile( + suffix=".jsonl", delete=False, mode="w", encoding="utf-8" + ) as tmp_file: + tmp_file.write(content) + tmp_path = tmp_file.name + + try: + uploaded_file = self._client.files.upload( + file=tmp_path, + config=types.UploadFileConfig( + display_name=f"batch-input-{int(time.time())}", + mime_type="jsonl", + ), + ) + + logger.info( + f"[upload_file] Uploaded file to Gemini | " + f"file_name={uploaded_file.name}" + ) + + return uploaded_file.name + + finally: + os.unlink(tmp_path) + + except Exception as e: + logger.error(f"[upload_file] Failed to upload file to Gemini | {e}") + raise + + def download_file(self, file_id: str) -> str: + """Download a file from Gemini Files API. 
+ + Args: + file_id: Gemini file name (e.g., "files/xxx") + + Returns: + File content as UTF-8 string + """ + logger.info(f"[download_file] Downloading file from Gemini | file_id={file_id}") + + try: + file_content = self._client.files.download(file=file_id) + content = file_content.decode("utf-8") + + logger.info( + f"[download_file] Downloaded file from Gemini | " + f"file_id={file_id} | bytes={len(content)}" + ) + + return content + + except Exception as e: + logger.error( + f"[download_file] Failed to download file from Gemini | " + f"file_id={file_id} | {e}" + ) + raise + + @staticmethod + def _extract_text_from_response(response: Any) -> str: + """Extract text content from a Gemini response object. + + Args: + response: Gemini GenerateContentResponse + + Returns: + str: Extracted text + """ + if hasattr(response, "text"): + return response.text + + text = "" + if hasattr(response, "candidates"): + for candidate in response.candidates: + if hasattr(candidate, "content"): + for part in candidate.content.parts: + if hasattr(part, "text"): + text += part.text + return text diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py new file mode 100644 index 000000000..6e883936c --- /dev/null +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -0,0 +1,42 @@ +"""STT Evaluation CRUD operations.""" + +from .dataset import ( + create_stt_dataset, + create_stt_samples, + get_stt_dataset_by_id, + list_stt_datasets, + get_samples_by_dataset_id, +) +from .run import ( + create_stt_run, + get_stt_run_by_id, + list_stt_runs, + update_stt_run, +) +from .result import ( + create_stt_results, + get_stt_result_by_id, + get_results_by_run_id, + update_stt_result, + update_human_feedback, +) + +__all__ = [ + # Dataset + "create_stt_dataset", + "create_stt_samples", + "get_stt_dataset_by_id", + "list_stt_datasets", + "get_samples_by_dataset_id", + # Run + "create_stt_run", + "get_stt_run_by_id", + "list_stt_runs", + "update_stt_run", + # Result + "create_stt_results", + "get_stt_result_by_id", + "get_results_by_run_id", + "update_stt_result", + "update_human_feedback", +] diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py new file mode 100644 index 000000000..35b8f6b80 --- /dev/null +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -0,0 +1,341 @@ +"""CRUD operations for STT evaluation datasets and samples.""" + +import logging +from typing import Any + +from sqlalchemy.exc import IntegrityError +from sqlmodel import Session, select, func + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.models import EvaluationDataset +from app.models.stt_evaluation import ( + EvaluationType, + STTSample, + STTSampleCreate, + STTDatasetPublic, + STTSamplePublic, +) + +logger = logging.getLogger(__name__) + + +def create_stt_dataset( + *, + session: Session, + name: str, + org_id: int, + project_id: int, + description: str | None = None, + language: str | None = None, + dataset_metadata: dict[str, Any] | None = None, +) -> EvaluationDataset: + """Create a new STT evaluation dataset. 
+ + Args: + session: Database session + name: Dataset name + org_id: Organization ID + project_id: Project ID + description: Optional description + language: Optional default language code + dataset_metadata: Optional metadata dict + + Returns: + EvaluationDataset: Created dataset + + Raises: + HTTPException: If dataset with same name already exists + """ + logger.info( + f"[create_stt_dataset] Creating STT dataset | " + f"name: {name}, org_id: {org_id}, project_id: {project_id}" + ) + + dataset = EvaluationDataset( + name=name, + description=description, + type=EvaluationType.STT.value, + language=language, + dataset_metadata=dataset_metadata or {}, + organization_id=org_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + try: + session.add(dataset) + session.commit() + session.refresh(dataset) + + logger.info( + f"[create_stt_dataset] STT dataset created | " + f"dataset_id: {dataset.id}, name: {name}" + ) + + return dataset + + except IntegrityError as e: + session.rollback() + if "uq_evaluation_dataset_name_org_project" in str(e): + logger.error( + f"[create_stt_dataset] Dataset name already exists | name: {name}" + ) + raise HTTPException( + status_code=400, + detail=f"Dataset with name '{name}' already exists", + ) + raise + + +def create_stt_samples( + *, + session: Session, + dataset_id: int, + org_id: int, + project_id: int, + samples: list[STTSampleCreate], +) -> list[STTSample]: + """Create STT samples for a dataset. + + Args: + session: Database session + dataset_id: Parent dataset ID + org_id: Organization ID + project_id: Project ID + samples: List of sample data + + Returns: + list[STTSample]: Created samples + """ + logger.info( + f"[create_stt_samples] Creating STT samples | " + f"dataset_id: {dataset_id}, sample_count: {len(samples)}" + ) + + created_samples = [] + + for sample_data in samples: + sample = STTSample( + object_store_url=sample_data.object_store_url, + language=sample_data.language, + ground_truth=sample_data.ground_truth, + duration_seconds=sample_data.duration_seconds, + sample_metadata=sample_data.sample_metadata or {}, + dataset_id=dataset_id, + organization_id=org_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + session.add(sample) + created_samples.append(sample) + + session.commit() + + # Refresh all samples to get IDs + for sample in created_samples: + session.refresh(sample) + + logger.info( + f"[create_stt_samples] STT samples created | " + f"dataset_id: {dataset_id}, created_count: {len(created_samples)}" + ) + + return created_samples + + +def get_stt_dataset_by_id( + *, + session: Session, + dataset_id: int, + org_id: int, + project_id: int, +) -> EvaluationDataset | None: + """Get an STT dataset by ID. + + Args: + session: Database session + dataset_id: Dataset ID + org_id: Organization ID + project_id: Project ID + + Returns: + EvaluationDataset | None: Dataset if found + """ + statement = select(EvaluationDataset).where( + EvaluationDataset.id == dataset_id, + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.STT.value, + ) + + return session.exec(statement).one_or_none() + + +def list_stt_datasets( + *, + session: Session, + org_id: int, + project_id: int, + limit: int = 50, + offset: int = 0, +) -> tuple[list[STTDatasetPublic], int]: + """List STT datasets for a project. 
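+
+    Example (illustrative sketch; ids are placeholders):
+
+        datasets, total = list_stt_datasets(
+            session=session, org_id=1, project_id=1, limit=20, offset=0
+        )
+        # ``total`` is the overall count; offset + len(datasets) < total
+        # means more pages remain.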
+ + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[STTDatasetPublic], int]: Datasets and total count + """ + # Get total count + count_stmt = select(func.count(EvaluationDataset.id)).where( + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.STT.value, + ) + total = session.exec(count_stmt).one() + + # Get datasets + statement = ( + select(EvaluationDataset) + .where( + EvaluationDataset.organization_id == org_id, + EvaluationDataset.project_id == project_id, + EvaluationDataset.type == EvaluationType.STT.value, + ) + .order_by(EvaluationDataset.inserted_at.desc()) + .offset(offset) + .limit(limit) + ) + + datasets = session.exec(statement).all() + + # Convert to public models with sample counts + result = [] + for dataset in datasets: + sample_count = get_sample_count_for_dataset( + session=session, dataset_id=dataset.id + ) + result.append( + STTDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language=dataset.language, + dataset_metadata=dataset.dataset_metadata, + sample_count=sample_count, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, + ) + ) + + return result, total + + +def get_sample_count_for_dataset(*, session: Session, dataset_id: int) -> int: + """Get the number of samples in a dataset. + + Args: + session: Database session + dataset_id: Dataset ID + + Returns: + int: Sample count + """ + statement = select(func.count(STTSample.id)).where( + STTSample.dataset_id == dataset_id + ) + return session.exec(statement).one() + + +def get_samples_by_dataset_id( + *, + session: Session, + dataset_id: int, + org_id: int, + project_id: int, + limit: int = 100, + offset: int = 0, +) -> tuple[list[STTSample], int]: + """Get samples for a dataset. + + Args: + session: Database session + dataset_id: Dataset ID + org_id: Organization ID + project_id: Project ID + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[STTSample], int]: Samples and total count + """ + # Get total count + count_stmt = select(func.count(STTSample.id)).where( + STTSample.dataset_id == dataset_id, + STTSample.organization_id == org_id, + STTSample.project_id == project_id, + ) + total = session.exec(count_stmt).one() + + # Get samples + statement = ( + select(STTSample) + .where( + STTSample.dataset_id == dataset_id, + STTSample.organization_id == org_id, + STTSample.project_id == project_id, + ) + .order_by(STTSample.id) + .offset(offset) + .limit(limit) + ) + + samples = session.exec(statement).all() + + return list(samples), total + + +def update_dataset_metadata( + *, + session: Session, + dataset_id: int, + metadata: dict[str, Any], +) -> EvaluationDataset | None: + """Update dataset metadata. 
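+
+    The update is a shallow merge: keys present in ``metadata`` overwrite the
+    stored values, other keys are left untouched. Illustrative sketch (the
+    dataset id is a placeholder):
+
+        update_dataset_metadata(
+            session=session,
+            dataset_id=42,
+            metadata={"sample_count": 10, "has_ground_truth_count": 4},
+        )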
+ + Args: + session: Database session + dataset_id: Dataset ID + metadata: Metadata to merge + + Returns: + EvaluationDataset | None: Updated dataset + """ + statement = select(EvaluationDataset).where(EvaluationDataset.id == dataset_id) + dataset = session.exec(statement).one_or_none() + + if not dataset: + return None + + # Merge metadata + current_metadata = dataset.dataset_metadata or {} + current_metadata.update(metadata) + dataset.dataset_metadata = current_metadata + dataset.updated_at = now() + + session.add(dataset) + session.commit() + session.refresh(dataset) + + return dataset diff --git a/backend/app/crud/stt_evaluations/result.py b/backend/app/crud/stt_evaluations/result.py new file mode 100644 index 000000000..428f78529 --- /dev/null +++ b/backend/app/crud/stt_evaluations/result.py @@ -0,0 +1,413 @@ +"""CRUD operations for STT evaluation results.""" + +import logging +from typing import Any + +from sqlmodel import Session, select, func + +from app.core.exception_handlers import HTTPException +from app.core.util import now +from app.models.stt_evaluation import ( + STTResult, + STTResultStatus, + STTResultPublic, + STTSample, + STTSamplePublic, + STTResultWithSample, +) + +logger = logging.getLogger(__name__) + + +def create_stt_result( + *, + session: Session, + stt_sample_id: int, + evaluation_run_id: int, + org_id: int, + project_id: int, + provider: str, + status: str = STTResultStatus.PENDING.value, +) -> STTResult: + """Create a single STT result record. + + Args: + session: Database session + stt_sample_id: Sample ID + evaluation_run_id: Run ID + org_id: Organization ID + project_id: Project ID + provider: Provider name + status: Initial status + + Returns: + STTResult: Created result + """ + result = STTResult( + stt_sample_id=stt_sample_id, + evaluation_run_id=evaluation_run_id, + organization_id=org_id, + project_id=project_id, + provider=provider, + status=status, + inserted_at=now(), + updated_at=now(), + ) + + session.add(result) + session.commit() + session.refresh(result) + + return result + + +def create_stt_results( + *, + session: Session, + samples: list[STTSample], + evaluation_run_id: int, + org_id: int, + project_id: int, + providers: list[str], +) -> list[STTResult]: + """Create STT result records for all samples and providers. + + Creates one result per sample per provider. + + Args: + session: Database session + samples: List of samples + evaluation_run_id: Run ID + org_id: Organization ID + project_id: Project ID + providers: List of providers + + Returns: + list[STTResult]: Created results + """ + logger.info( + f"[create_stt_results] Creating STT results | " + f"run_id: {evaluation_run_id}, sample_count: {len(samples)}, " + f"provider_count: {len(providers)}" + ) + + results = [] + + for sample in samples: + for provider in providers: + result = STTResult( + stt_sample_id=sample.id, + evaluation_run_id=evaluation_run_id, + organization_id=org_id, + project_id=project_id, + provider=provider, + status=STTResultStatus.PENDING.value, + inserted_at=now(), + updated_at=now(), + ) + session.add(result) + results.append(result) + + session.commit() + + # Refresh to get IDs + for result in results: + session.refresh(result) + + logger.info( + f"[create_stt_results] STT results created | " + f"run_id: {evaluation_run_id}, result_count: {len(results)}" + ) + + return results + + +def get_stt_result_by_id( + *, + session: Session, + result_id: int, + org_id: int, + project_id: int, +) -> STTResult | None: + """Get an STT result by ID. 
+ + Args: + session: Database session + result_id: Result ID + org_id: Organization ID + project_id: Project ID + + Returns: + STTResult | None: Result if found + """ + statement = select(STTResult).where( + STTResult.id == result_id, + STTResult.organization_id == org_id, + STTResult.project_id == project_id, + ) + + return session.exec(statement).one_or_none() + + +def get_results_by_run_id( + *, + session: Session, + run_id: int, + org_id: int, + project_id: int, + provider: str | None = None, + status: str | None = None, + limit: int = 100, + offset: int = 0, +) -> tuple[list[STTResultWithSample], int]: + """Get results for an evaluation run with sample data. + + Args: + session: Database session + run_id: Run ID + org_id: Organization ID + project_id: Project ID + provider: Optional filter by provider + status: Optional filter by status + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[STTResultWithSample], int]: Results with samples and total count + """ + # Build where clause + where_clauses = [ + STTResult.evaluation_run_id == run_id, + STTResult.organization_id == org_id, + STTResult.project_id == project_id, + ] + + if provider is not None: + where_clauses.append(STTResult.provider == provider) + + if status is not None: + where_clauses.append(STTResult.status == status) + + # Get total count + count_stmt = select(func.count(STTResult.id)).where(*where_clauses) + total = session.exec(count_stmt).one() + + # Get results with samples + statement = ( + select(STTResult, STTSample) + .join(STTSample, STTResult.stt_sample_id == STTSample.id) + .where(*where_clauses) + .order_by(STTResult.id) + .offset(offset) + .limit(limit) + ) + + rows = session.exec(statement).all() + + # Convert to response models + results = [] + for result, sample in rows: + sample_public = STTSamplePublic( + id=sample.id, + object_store_url=sample.object_store_url, + language=sample.language, + ground_truth=sample.ground_truth, + duration_seconds=sample.duration_seconds, + sample_metadata=sample.sample_metadata, + dataset_id=sample.dataset_id, + organization_id=sample.organization_id, + project_id=sample.project_id, + inserted_at=sample.inserted_at, + updated_at=sample.updated_at, + ) + + result_with_sample = STTResultWithSample( + id=result.id, + transcription=result.transcription, + provider=result.provider, + status=result.status, + wer=result.wer, + cer=result.cer, + is_correct=result.is_correct, + comment=result.comment, + provider_metadata=result.provider_metadata, + error_message=result.error_message, + stt_sample_id=result.stt_sample_id, + evaluation_run_id=result.evaluation_run_id, + organization_id=result.organization_id, + project_id=result.project_id, + inserted_at=result.inserted_at, + updated_at=result.updated_at, + sample=sample_public, + ) + results.append(result_with_sample) + + return results, total + + +def update_stt_result( + *, + session: Session, + result_id: int, + transcription: str | None = None, + status: str | None = None, + wer: float | None = None, + cer: float | None = None, + provider_metadata: dict[str, Any] | None = None, + error_message: str | None = None, +) -> STTResult | None: + """Update an STT result with transcription data. 
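+
+    Example (illustrative sketch; the result id and transcription stand in
+    for values taken from a provider response):
+
+        update_stt_result(
+            session=session,
+            result_id=7,
+            transcription="hello world",
+            status=STTResultStatus.COMPLETED.value,
+            provider_metadata={"model": "gemini-2.5-pro"},
+        )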
+ + Args: + session: Database session + result_id: Result ID + transcription: Generated transcription + status: New status + wer: Word Error Rate + cer: Character Error Rate + provider_metadata: Provider response metadata + error_message: Error message if failed + + Returns: + STTResult | None: Updated result + """ + statement = select(STTResult).where(STTResult.id == result_id) + result = session.exec(statement).one_or_none() + + if not result: + return None + + if transcription is not None: + result.transcription = transcription + + if status is not None: + result.status = status + + if wer is not None: + result.wer = wer + + if cer is not None: + result.cer = cer + + if provider_metadata is not None: + result.provider_metadata = provider_metadata + + if error_message is not None: + result.error_message = error_message + + result.updated_at = now() + + session.add(result) + session.commit() + session.refresh(result) + + return result + + +def update_human_feedback( + *, + session: Session, + result_id: int, + org_id: int, + project_id: int, + is_correct: bool | None = None, + comment: str | None = None, +) -> STTResult | None: + """Update human feedback on an STT result. + + Args: + session: Database session + result_id: Result ID + org_id: Organization ID + project_id: Project ID + is_correct: Human verification of correctness + comment: Feedback comment + + Returns: + STTResult | None: Updated result + + Raises: + HTTPException: If result not found + """ + result = get_stt_result_by_id( + session=session, + result_id=result_id, + org_id=org_id, + project_id=project_id, + ) + + if not result: + raise HTTPException(status_code=404, detail="Result not found") + + if is_correct is not None: + result.is_correct = is_correct + + if comment is not None: + result.comment = comment + + result.updated_at = now() + + session.add(result) + session.commit() + session.refresh(result) + + logger.info( + f"[update_human_feedback] Human feedback updated | " + f"result_id: {result_id}, is_correct: {is_correct}" + ) + + return result + + +def get_pending_results_for_run( + *, + session: Session, + run_id: int, + provider: str | None = None, +) -> list[STTResult]: + """Get all pending results for a run. + + Args: + session: Database session + run_id: Run ID + provider: Optional filter by provider + + Returns: + list[STTResult]: Pending results + """ + where_clauses = [ + STTResult.evaluation_run_id == run_id, + STTResult.status == STTResultStatus.PENDING.value, + ] + + if provider is not None: + where_clauses.append(STTResult.provider == provider) + + statement = select(STTResult).where(*where_clauses) + + return list(session.exec(statement).all()) + + +def count_results_by_status( + *, + session: Session, + run_id: int, +) -> dict[str, int]: + """Count results by status for a run. 
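+
+    Example of the returned shape (counts are illustrative; only statuses
+    that actually occur appear as keys):
+
+        >>> count_results_by_status(session=session, run_id=3)
+        {'pending': 2, 'completed': 9, 'failed': 1}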
+ + Args: + session: Database session + run_id: Run ID + + Returns: + dict[str, int]: Counts by status + """ + statement = ( + select(STTResult.status, func.count(STTResult.id)) + .where(STTResult.evaluation_run_id == run_id) + .group_by(STTResult.status) + ) + + rows = session.exec(statement).all() + + return {status: count for status, count in rows} diff --git a/backend/app/crud/stt_evaluations/run.py b/backend/app/crud/stt_evaluations/run.py new file mode 100644 index 000000000..d67ba07f9 --- /dev/null +++ b/backend/app/crud/stt_evaluations/run.py @@ -0,0 +1,302 @@ +"""CRUD operations for STT evaluation runs.""" + +import logging +from typing import Any + +from sqlmodel import Session, select, func + +from app.core.util import now +from app.models import EvaluationDataset, EvaluationRun +from app.models.stt_evaluation import ( + EvaluationType, + STTEvaluationRunPublic, +) + +logger = logging.getLogger(__name__) + + +def create_stt_run( + *, + session: Session, + run_name: str, + dataset_id: int, + dataset_name: str, + org_id: int, + project_id: int, + providers: list[str], + language: str | None = None, + total_items: int = 0, +) -> EvaluationRun: + """Create a new STT evaluation run. + + Args: + session: Database session + run_name: Name for the run + dataset_id: ID of the dataset to evaluate + dataset_name: Name of the dataset + org_id: Organization ID + project_id: Project ID + providers: List of STT providers to use + language: Optional language override + total_items: Total number of items to process + + Returns: + EvaluationRun: Created run + """ + logger.info( + f"[create_stt_run] Creating STT evaluation run | " + f"run_name: {run_name}, dataset_id: {dataset_id}, " + f"providers: {providers}" + ) + + run = EvaluationRun( + run_name=run_name, + dataset_name=dataset_name, + dataset_id=dataset_id, + type=EvaluationType.STT.value, + language=language, + providers=providers, + status="pending", + total_items=total_items, + processed_samples=0, + organization_id=org_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + + session.add(run) + session.commit() + session.refresh(run) + + logger.info( + f"[create_stt_run] STT evaluation run created | " + f"run_id: {run.id}, run_name: {run_name}" + ) + + return run + + +def get_stt_run_by_id( + *, + session: Session, + run_id: int, + org_id: int, + project_id: int, +) -> EvaluationRun | None: + """Get an STT evaluation run by ID. + + Args: + session: Database session + run_id: Run ID + org_id: Organization ID + project_id: Project ID + + Returns: + EvaluationRun | None: Run if found + """ + statement = select(EvaluationRun).where( + EvaluationRun.id == run_id, + EvaluationRun.organization_id == org_id, + EvaluationRun.project_id == project_id, + EvaluationRun.type == EvaluationType.STT.value, + ) + + return session.exec(statement).one_or_none() + + +def list_stt_runs( + *, + session: Session, + org_id: int, + project_id: int, + dataset_id: int | None = None, + status: str | None = None, + limit: int = 50, + offset: int = 0, +) -> tuple[list[STTEvaluationRunPublic], int]: + """List STT evaluation runs for a project. 
+ + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + dataset_id: Optional filter by dataset + status: Optional filter by status + limit: Maximum results to return + offset: Number of results to skip + + Returns: + tuple[list[STTEvaluationRunPublic], int]: Runs and total count + """ + # Build base where clause + where_clauses = [ + EvaluationRun.organization_id == org_id, + EvaluationRun.project_id == project_id, + EvaluationRun.type == EvaluationType.STT.value, + ] + + if dataset_id is not None: + where_clauses.append(EvaluationRun.dataset_id == dataset_id) + + if status is not None: + where_clauses.append(EvaluationRun.status == status) + + # Get total count + count_stmt = select(func.count(EvaluationRun.id)).where(*where_clauses) + total = session.exec(count_stmt).one() + + # Get runs + statement = ( + select(EvaluationRun) + .where(*where_clauses) + .order_by(EvaluationRun.inserted_at.desc()) + .offset(offset) + .limit(limit) + ) + + runs = session.exec(statement).all() + + # Convert to public models + result = [ + STTEvaluationRunPublic( + id=run.id, + run_name=run.run_name, + dataset_name=run.dataset_name, + type=run.type, + language=run.language, + providers=run.providers, + dataset_id=run.dataset_id, + status=run.status, + total_items=run.total_items, + processed_samples=run.processed_samples, + score=run.score, + error_message=run.error_message, + organization_id=run.organization_id, + project_id=run.project_id, + inserted_at=run.inserted_at, + updated_at=run.updated_at, + ) + for run in runs + ] + + return result, total + + +def update_stt_run( + *, + session: Session, + run_id: int, + status: str | None = None, + processed_samples: int | None = None, + score: dict[str, Any] | None = None, + error_message: str | None = None, + object_store_url: str | None = None, +) -> EvaluationRun | None: + """Update an STT evaluation run. + + Args: + session: Database session + run_id: Run ID + status: New status + processed_samples: Number of processed samples + score: Score data + error_message: Error message + object_store_url: URL to stored results + + Returns: + EvaluationRun | None: Updated run + """ + statement = select(EvaluationRun).where(EvaluationRun.id == run_id) + run = session.exec(statement).one_or_none() + + if not run: + return None + + if status is not None: + run.status = status + + if processed_samples is not None: + run.processed_samples = processed_samples + + if score is not None: + run.score = score + + if error_message is not None: + run.error_message = error_message + + if object_store_url is not None: + run.object_store_url = object_store_url + + run.updated_at = now() + + session.add(run) + session.commit() + session.refresh(run) + + logger.info( + f"[update_stt_run] STT run updated | " + f"run_id: {run_id}, status: {run.status}, " + f"processed_samples: {run.processed_samples}" + ) + + return run + + +def increment_processed_samples( + *, + session: Session, + run_id: int, + increment: int = 1, +) -> EvaluationRun | None: + """Increment the processed_samples counter for a run. 
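+
+    Intended to be called from the background task after a sample (or a batch
+    of samples) finishes, e.g. (run id and increment are placeholders):
+
+        increment_processed_samples(session=session, run_id=3, increment=5)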
+ + Args: + session: Database session + run_id: Run ID + increment: Amount to increment by + + Returns: + EvaluationRun | None: Updated run + """ + statement = select(EvaluationRun).where(EvaluationRun.id == run_id) + run = session.exec(statement).one_or_none() + + if not run: + return None + + run.processed_samples = (run.processed_samples or 0) + increment + run.updated_at = now() + + session.add(run) + session.commit() + session.refresh(run) + + return run + + +def get_pending_stt_runs( + *, + session: Session, + org_id: int | None = None, +) -> list[EvaluationRun]: + """Get all pending STT evaluation runs. + + Args: + session: Database session + org_id: Optional filter by organization + + Returns: + list[EvaluationRun]: Pending runs + """ + where_clauses = [ + EvaluationRun.type == EvaluationType.STT.value, + EvaluationRun.status.in_(["pending", "processing"]), + ] + + if org_id is not None: + where_clauses.append(EvaluationRun.organization_id == org_id) + + statement = select(EvaluationRun).where(*where_clauses) + + return list(session.exec(statement).all()) diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py new file mode 100644 index 000000000..24a419af2 --- /dev/null +++ b/backend/app/models/stt_evaluation.py @@ -0,0 +1,414 @@ +"""STT Evaluation models for Speech-to-Text evaluation feature.""" + +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field +from sqlalchemy import Column, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import Field as SQLField +from sqlmodel import Relationship, SQLModel + +from app.core.util import now + +if TYPE_CHECKING: + from .evaluation import EvaluationDataset, EvaluationRun + from .organization import Organization + from .project import Project + + +class EvaluationType(str, Enum): + """Type of evaluation dataset/run.""" + + TEXT = "text" + STT = "stt" + TTS = "tts" + + +class STTResultStatus(str, Enum): + """Status of an STT result.""" + + PENDING = "pending" + COMPLETED = "completed" + FAILED = "failed" + + +# Database Models + + +class STTSample(SQLModel, table=True): + """Database table for STT audio samples within a dataset.""" + + __tablename__ = "stt_sample" + + id: int = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the STT sample"}, + ) + + # Audio file reference + object_store_url: str = SQLField( + description="S3 URL of the audio file", + sa_column_kwargs={"comment": "S3 URL of the audio file"}, + ) + + # Language (can be different per sample within a dataset) + language: str | None = SQLField( + default=None, + max_length=10, + description="ISO 639-1 language code for this sample", + sa_column_kwargs={"comment": "ISO 639-1 language code for this sample"}, + ) + + # Ground truth transcription (optional, for evaluation) + ground_truth: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Reference transcription for comparison (optional)", + ), + description="Reference transcription for comparison", + ) + + # Audio metadata + duration_seconds: float | None = SQLField( + default=None, + description="Audio duration in seconds", + sa_column_kwargs={"comment": "Audio duration in seconds"}, + ) + + sample_metadata: dict[str, Any] | None = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=True, + comment="Additional metadata (format, bitrate, original filename, etc.)", + ), + 
description="Additional metadata about the audio sample", + ) + + # Foreign keys + dataset_id: int = SQLField( + foreign_key="evaluation_dataset.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the parent evaluation dataset"}, + ) + organization_id: int = SQLField( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = SQLField( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + + # Timestamps + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the sample was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the sample was last updated"}, + ) + + # Relationships + dataset: "EvaluationDataset" = Relationship() + organization: "Organization" = Relationship() + project: "Project" = Relationship() + results: list["STTResult"] = Relationship(back_populates="sample") + + +class STTResult(SQLModel, table=True): + """Database table for STT transcription results.""" + + __tablename__ = "stt_result" + + id: int = SQLField( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the STT result"}, + ) + + # Transcription output + transcription: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Generated transcription from STT provider", + ), + description="Generated transcription from STT provider", + ) + + # Provider info + provider: str = SQLField( + max_length=50, + description="STT provider used (e.g., gemini-2.5-pro)", + sa_column_kwargs={"comment": "STT provider used (e.g., gemini-2.5-pro)"}, + ) + + # Status + status: str = SQLField( + default=STTResultStatus.PENDING.value, + max_length=20, + description="Result status: pending, completed, failed", + sa_column_kwargs={"comment": "Result status: pending, completed, failed"}, + ) + + # Metrics (null for Phase 1) + wer: float | None = SQLField( + default=None, + description="Word Error Rate (null for Phase 1)", + sa_column_kwargs={"comment": "Word Error Rate (null for Phase 1)"}, + ) + cer: float | None = SQLField( + default=None, + description="Character Error Rate (null for Phase 1)", + sa_column_kwargs={"comment": "Character Error Rate (null for Phase 1)"}, + ) + + # Human feedback + is_correct: bool | None = SQLField( + default=None, + description="Human feedback: transcription correctness", + sa_column_kwargs={ + "comment": "Human feedback: transcription correctness (null=not reviewed)" + }, + ) + comment: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Human feedback comment", + ), + description="Human feedback comment", + ) + + # Provider response metadata + provider_metadata: dict[str, Any] | None = SQLField( + default_factory=dict, + sa_column=Column( + JSONB, + nullable=True, + comment="Provider-specific response metadata (tokens, latency, etc.)", + ), + description="Provider-specific response metadata", + ) + + # Error message if failed + error_message: str | None = SQLField( + default=None, + sa_column=Column( + Text, + nullable=True, + comment="Error message if transcription failed", + ), + description="Error message if transcription failed", + ) + + # Foreign keys + stt_sample_id: int = SQLField( + foreign_key="stt_sample.id", 
+ nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the STT sample"}, + ) + evaluation_run_id: int = SQLField( + foreign_key="evaluation_run.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the evaluation run"}, + ) + organization_id: int = SQLField( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = SQLField( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + + # Timestamps + inserted_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the result was created"}, + ) + updated_at: datetime = SQLField( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the result was last updated"}, + ) + + # Relationships + sample: "STTSample" = Relationship(back_populates="results") + evaluation_run: "EvaluationRun" = Relationship() + organization: "Organization" = Relationship() + project: "Project" = Relationship() + + +# Pydantic Models for API + + +class STTSampleCreate(BaseModel): + """Request model for creating an STT sample.""" + + object_store_url: str = Field(..., description="S3 URL of the audio file") + language: str | None = Field(None, description="ISO 639-1 language code") + ground_truth: str | None = Field( + None, description="Reference transcription (optional)" + ) + duration_seconds: float | None = Field( + None, description="Audio duration in seconds" + ) + sample_metadata: dict[str, Any] | None = Field( + None, description="Additional metadata" + ) + + +class STTSamplePublic(BaseModel): + """Public model for STT samples.""" + + id: int + object_store_url: str + language: str | None + ground_truth: str | None + duration_seconds: float | None + sample_metadata: dict[str, Any] | None + dataset_id: int + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTResultPublic(BaseModel): + """Public model for STT results.""" + + id: int + transcription: str | None + provider: str + status: str + wer: float | None + cer: float | None + is_correct: bool | None + comment: str | None + provider_metadata: dict[str, Any] | None + error_message: str | None + stt_sample_id: int + evaluation_run_id: int + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTResultWithSample(STTResultPublic): + """STT result with embedded sample data.""" + + sample: STTSamplePublic + + +class STTFeedbackUpdate(BaseModel): + """Request model for updating human feedback on a result.""" + + is_correct: bool | None = Field(None, description="Is the transcription correct?") + comment: str | None = Field(None, description="Feedback comment") + + +class STTDatasetCreate(BaseModel): + """Request model for creating an STT dataset.""" + + name: str = Field(..., description="Dataset name", min_length=1) + description: str | None = Field(None, description="Dataset description") + language: str | None = Field(None, description="Default language for the dataset") + samples: list[STTSampleCreate] = Field( + ..., description="List of audio samples", min_length=1 + ) + + +class STTDatasetPublic(BaseModel): + """Public model for STT datasets.""" + + id: int + name: str + description: str | None + type: str + language: str | None + dataset_metadata: dict[str, Any] + sample_count: int = 
Field(0, description="Number of samples in the dataset") + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTDatasetWithSamples(STTDatasetPublic): + """STT dataset with embedded samples.""" + + samples: list[STTSamplePublic] + + +class STTEvaluationRunCreate(BaseModel): + """Request model for starting an STT evaluation run.""" + + run_name: str = Field(..., description="Name for this evaluation run", min_length=1) + dataset_id: int = Field(..., description="ID of the STT dataset to evaluate") + providers: list[str] = Field( + default=["gemini-2.5-pro"], + description="List of STT providers to use", + ) + language: str | None = Field(None, description="Override language for all samples") + + +class STTEvaluationRunPublic(BaseModel): + """Public model for STT evaluation runs.""" + + id: int + run_name: str + dataset_name: str + type: str + language: str | None + providers: list[str] | None + dataset_id: int + status: str + total_items: int + processed_samples: int + score: dict[str, Any] | None + error_message: str | None + organization_id: int + project_id: int + inserted_at: datetime + updated_at: datetime + + +class STTEvaluationRunWithResults(STTEvaluationRunPublic): + """STT evaluation run with embedded results.""" + + results: list[STTResultWithSample] + results_total: int = Field(0, description="Total number of results") + + +class AudioUploadResponse(BaseModel): + """Response model for audio file upload.""" + + s3_url: str = Field(..., description="S3 URL of the uploaded audio file") + filename: str = Field(..., description="Original filename") + size_bytes: int = Field(..., description="File size in bytes") + content_type: str = Field(..., description="MIME type of the audio file") diff --git a/backend/app/services/stt_evaluations/__init__.py b/backend/app/services/stt_evaluations/__init__.py new file mode 100644 index 000000000..0e22a1c2e --- /dev/null +++ b/backend/app/services/stt_evaluations/__init__.py @@ -0,0 +1 @@ +"""STT Evaluation services.""" diff --git a/backend/app/services/stt_evaluations/audio.py b/backend/app/services/stt_evaluations/audio.py new file mode 100644 index 000000000..b7223fab6 --- /dev/null +++ b/backend/app/services/stt_evaluations/audio.py @@ -0,0 +1,212 @@ +"""Audio file validation and upload service for STT evaluation.""" + +import logging +import uuid +from pathlib import Path + +from fastapi import UploadFile +from sqlmodel import Session + +from app.core.cloud.storage import get_cloud_storage +from app.core.exception_handlers import HTTPException +from app.models.stt_evaluation import AudioUploadResponse +from app.services.stt_evaluations.constants import ( + MAX_FILE_SIZE_BYTES, + MIME_TO_EXTENSION, + SUPPORTED_AUDIO_FORMATS, +) + +logger = logging.getLogger(__name__) + + +class AudioValidationError(Exception): + """Exception raised for audio validation errors.""" + + pass + + +def get_extension_from_filename(filename: str) -> str | None: + """Extract and validate file extension from filename. + + Args: + filename: Original filename + + Returns: + str: Lowercase file extension (without dot) + None: If no valid extension found + """ + if not filename or "." not in filename: + return None + return filename.rsplit(".", 1)[-1].lower() + + +def get_extension_from_content_type(content_type: str) -> str | None: + """Get file extension from MIME content type. 
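+
+    Examples (derived from the MIME_TO_EXTENSION mapping in constants.py):
+
+        >>> get_extension_from_content_type("audio/mpeg")
+        'mp3'
+        >>> get_extension_from_content_type("text/plain") is None
+        True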
+ + Args: + content_type: MIME content type + + Returns: + str: File extension + None: If content type not recognized + """ + if not content_type: + return None + return MIME_TO_EXTENSION.get(content_type.lower()) + + +def validate_audio_file(file: UploadFile) -> str: + """Validate an uploaded audio file. + + Args: + file: FastAPI UploadFile object + + Returns: + str: Validated file extension + + Raises: + AudioValidationError: If file is invalid + """ + # Check filename exists + if not file.filename: + logger.error("[validate_audio_file] No filename provided") + raise AudioValidationError("Filename is required") + + # Get extension from filename + extension = get_extension_from_filename(file.filename) + + # If no extension from filename, try content type + if not extension: + extension = get_extension_from_content_type(file.content_type) + + # Validate extension is supported + if not extension or extension not in SUPPORTED_AUDIO_FORMATS: + supported = ", ".join(sorted(SUPPORTED_AUDIO_FORMATS)) + logger.error( + f"[validate_audio_file] Unsupported audio format | " + f"filename: {file.filename}, extension: {extension}, " + f"content_type: {file.content_type}" + ) + raise AudioValidationError( + f"Unsupported audio format: {extension or 'unknown'}. " + f"Supported formats: {supported}" + ) + + # Check file size (if available) + if file.size and file.size > MAX_FILE_SIZE_BYTES: + max_mb = MAX_FILE_SIZE_BYTES / (1024 * 1024) + file_mb = file.size / (1024 * 1024) + logger.error( + f"[validate_audio_file] File too large | " + f"filename: {file.filename}, size_mb: {file_mb:.2f}, max_mb: {max_mb}" + ) + raise AudioValidationError( + f"File too large: {file_mb:.2f} MB. Maximum size: {max_mb:.0f} MB" + ) + + logger.info( + f"[validate_audio_file] Audio file validated | " + f"filename: {file.filename}, extension: {extension}, " + f"content_type: {file.content_type}" + ) + + return extension + + +def upload_audio_file( + session: Session, + file: UploadFile, + project_id: int, +) -> AudioUploadResponse: + """Upload an audio file to S3. 
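+
+    Typical use from the audio upload route, roughly (sketch; wiring is
+    simplified and names are placeholders):
+
+        response = upload_audio_file(
+            session=session, file=upload, project_id=current_project.id
+        )
+        # response.s3_url is what callers later pass as ``object_store_url``
+        # when creating dataset samples.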
+ + Args: + session: Database session + file: FastAPI UploadFile object + project_id: Project ID + + Returns: + AudioUploadResponse: Upload result with S3 URL + + Raises: + HTTPException: If validation or upload fails + """ + logger.info( + f"[upload_audio_file] Starting audio upload | " + f"project_id: {project_id}, filename: {file.filename}" + ) + + try: + # Validate the audio file + extension = validate_audio_file(file) + except AudioValidationError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Generate unique filename + file_uuid = uuid.uuid4() + new_filename = f"{file_uuid}.{extension}" + + # Construct S3 path: stt/audio/{uuid}.{ext} + file_path = Path("stt") / "audio" / new_filename + + try: + # Get cloud storage for project + storage = get_cloud_storage(session=session, project_id=project_id) + + # Upload file to S3 + destination = storage.put(source=file, file_path=file_path) + s3_url = str(destination) + + # Get file size + try: + size_kb = storage.get_file_size_kb(s3_url) + size_bytes = int(size_kb * 1024) + except Exception: + # If we can't get size from S3, use the upload file size + size_bytes = file.size or 0 + + logger.info( + f"[upload_audio_file] Audio uploaded successfully | " + f"project_id: {project_id}, s3_url: {s3_url}, size_bytes: {size_bytes}" + ) + + return AudioUploadResponse( + s3_url=s3_url, + filename=file.filename or new_filename, + size_bytes=size_bytes, + content_type=file.content_type or f"audio/{extension}", + ) + + except Exception as e: + logger.error( + f"[upload_audio_file] Failed to upload audio | " + f"project_id: {project_id}, error: {str(e)}" + ) + raise HTTPException( + status_code=500, + detail="Failed to upload audio file. Please try again later.", + ) + + +def validate_s3_audio_url(url: str) -> bool: + """Validate that a URL is a valid S3 audio file URL. 
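+
+    Examples:
+
+        >>> validate_s3_audio_url("s3://bucket/stt/audio/sample.mp3")
+        True
+        >>> validate_s3_audio_url("https://example.com/sample.mp3")
+        False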
+ + Args: + url: S3 URL to validate + + Returns: + bool: True if URL appears valid + """ + if not url: + return False + + # Check URL format + if not url.startswith("s3://"): + return False + + # Check file extension + extension = get_extension_from_filename(url) + if not extension or extension not in SUPPORTED_AUDIO_FORMATS: + return False + + return True diff --git a/backend/app/services/stt_evaluations/constants.py b/backend/app/services/stt_evaluations/constants.py new file mode 100644 index 000000000..41da0b5e7 --- /dev/null +++ b/backend/app/services/stt_evaluations/constants.py @@ -0,0 +1,32 @@ +"""Shared constants for STT evaluation services.""" + +# Supported audio formats for STT evaluation +SUPPORTED_AUDIO_FORMATS: set[str] = {"mp3", "wav", "flac", "m4a", "ogg", "webm"} + +# Maximum audio file size (200 MB) +MAX_FILE_SIZE_BYTES: int = 200 * 1024 * 1024 + +# Mapping from file extension to MIME type +EXTENSION_TO_MIME: dict[str, str] = { + "mp3": "audio/mp3", + "wav": "audio/wav", + "flac": "audio/flac", + "m4a": "audio/mp4", + "ogg": "audio/ogg", + "webm": "audio/webm", +} + +# Mapping from MIME type to file extension +MIME_TO_EXTENSION: dict[str, str] = { + "audio/mp3": "mp3", + "audio/mpeg": "mp3", + "audio/wav": "wav", + "audio/x-wav": "wav", + "audio/wave": "wav", + "audio/flac": "flac", + "audio/x-flac": "flac", + "audio/mp4": "m4a", + "audio/x-m4a": "m4a", + "audio/ogg": "ogg", + "audio/webm": "webm", +} diff --git a/backend/app/services/stt_evaluations/gemini/__init__.py b/backend/app/services/stt_evaluations/gemini/__init__.py new file mode 100644 index 000000000..a0b35b1b6 --- /dev/null +++ b/backend/app/services/stt_evaluations/gemini/__init__.py @@ -0,0 +1,6 @@ +"""Gemini integration for STT evaluation.""" + +from .client import GeminiClient +from .files import GeminiFilesManager + +__all__ = ["GeminiClient", "GeminiFilesManager"] diff --git a/backend/app/services/stt_evaluations/gemini/client.py b/backend/app/services/stt_evaluations/gemini/client.py new file mode 100644 index 000000000..5be4b583a --- /dev/null +++ b/backend/app/services/stt_evaluations/gemini/client.py @@ -0,0 +1,113 @@ +"""Gemini client wrapper for STT evaluation.""" + +import logging +from typing import Any + +from google import genai +from sqlmodel import Session + +from app.core.exception_handlers import HTTPException +from app.crud.credentials import get_provider_credential + +logger = logging.getLogger(__name__) + + +class GeminiClientError(Exception): + """Exception raised for Gemini client errors.""" + + pass + + +class GeminiClient: + """Wrapper for Google GenAI client with credential management.""" + + def __init__(self, api_key: str): + """Initialize Gemini client with API key. + + Args: + api_key: Google AI API key + """ + self._api_key = api_key + self._client = genai.Client(api_key=api_key) + + @property + def client(self) -> genai.Client: + """Get the underlying GenAI client.""" + return self._client + + @classmethod + def from_credentials( + cls, + session: Session, + org_id: int, + project_id: int, + ) -> "GeminiClient": + """Create client from stored credentials. 
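+
+        Example (illustrative sketch; ids are placeholders and Gemini
+        credentials are assumed to already be stored for the project):
+
+            client = GeminiClient.from_credentials(
+                session=session, org_id=1, project_id=1
+            )
+            if not client.validate_connection():
+                ...  # surface a configuration error to the caller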
+ + Args: + session: Database session + org_id: Organization ID + project_id: Project ID + + Returns: + GeminiClient: Configured Gemini client + + Raises: + HTTPException: If credentials not found + GeminiClientError: If credentials are invalid + """ + logger.info( + f"[from_credentials] Fetching Gemini credentials | " + f"org_id: {org_id}, project_id: {project_id}" + ) + + credentials = get_provider_credential( + session=session, + org_id=org_id, + project_id=project_id, + provider="gemini", + ) + + if not credentials: + logger.error( + f"[from_credentials] Gemini credentials not found | " + f"org_id: {org_id}, project_id: {project_id}" + ) + raise HTTPException( + status_code=404, + detail="Gemini credentials not configured for this project", + ) + + api_key = credentials.get("api_key") + if not api_key: + logger.error( + f"[from_credentials] Invalid Gemini credentials (missing api_key) | " + f"org_id: {org_id}, project_id: {project_id}" + ) + raise GeminiClientError("Invalid Gemini credentials: missing api_key") + + logger.info( + f"[from_credentials] Gemini client created successfully | " + f"org_id: {org_id}, project_id: {project_id}" + ) + return cls(api_key=api_key) + + def validate_connection(self) -> bool: + """Validate that the client can connect to Gemini. + + Returns: + bool: True if connection is valid + """ + try: + # List models to verify connection + models = list(self._client.models.list()) + logger.info( + f"[validate_connection] Connection validated | " + f"available_models_count: {len(models)}" + ) + return True + except Exception as e: + logger.error( + f"[validate_connection] Connection validation failed | error: {str(e)}" + ) + return False diff --git a/backend/app/services/stt_evaluations/gemini/files.py b/backend/app/services/stt_evaluations/gemini/files.py new file mode 100644 index 000000000..dadbef5ab --- /dev/null +++ b/backend/app/services/stt_evaluations/gemini/files.py @@ -0,0 +1,271 @@ +"""Google Files API operations for STT evaluation.""" + +import logging +import os +import tempfile +from typing import BinaryIO + +import requests +from google import genai +from google.genai import types + +from app.services.stt_evaluations.constants import EXTENSION_TO_MIME + +logger = logging.getLogger(__name__) + + +class GeminiFilesError(Exception): + """Exception raised for Gemini Files API errors.""" + + pass + + +def get_mime_type(file_path: str) -> str: + """Get MIME type based on file extension. + + Args: + file_path: Path or URL of the audio file + + Returns: + str: MIME type string + """ + extension = file_path.lower().split(".")[-1] + return EXTENSION_TO_MIME.get(extension, "audio/mpeg") + + +def get_extension(file_path: str) -> str: + """Get file extension from path or URL. + + Args: + file_path: Path or URL of the audio file + + Returns: + str: File extension (lowercase, without dot) + """ + return file_path.lower().split(".")[-1] + + +class GeminiFilesManager: + """Manage file uploads to Google Files API.""" + + # Base URL for Gemini Files API + FILES_API_BASE = "https://generativelanguage.googleapis.com/v1beta" + + def __init__(self, client: genai.Client): + """Initialize files manager. + + Args: + client: Gemini client instance + """ + self._client = client + + def upload_from_bytes( + self, + content: bytes, + filename: str, + mime_type: str | None = None, + ) -> str: + """Upload audio content directly to Google Files API. 
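+
+        Example (illustrative sketch; ``manager`` wraps an authenticated
+        client and ``audio_bytes`` is a placeholder):
+
+            file_uri = manager.upload_from_bytes(
+                content=audio_bytes, filename="sample.mp3"
+            )
+            # file_uri is the full HTTPS Files API URL the batch API expects.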
+ + Args: + content: Audio file content as bytes + filename: Display name for the file + mime_type: MIME type (auto-detected if not provided) + + Returns: + str: Full Google Files API URI (HTTPS format for batch API) + + Raises: + GeminiFilesError: If upload fails + """ + if mime_type is None: + mime_type = get_mime_type(filename) + + logger.info( + f"[upload_from_bytes] Uploading file to Google Files API | " + f"filename: {filename}, mime_type: {mime_type}, size_bytes: {len(content)}" + ) + + try: + # Write content to a temporary file + extension = get_extension(filename) + with tempfile.NamedTemporaryFile( + suffix=f".{extension}", delete=False + ) as tmp_file: + tmp_file.write(content) + tmp_path = tmp_file.name + + try: + # Upload to Google Files API + uploaded_file = self._client.files.upload( + file=tmp_path, + config=types.UploadFileConfig( + display_name=filename, + mime_type=mime_type, + ), + ) + + # CRITICAL: Return the full HTTPS URI for batch API compatibility + # The batch API requires full URLs, not short form like "files/xxx" + file_uri = f"{self.FILES_API_BASE}/{uploaded_file.name}" + + logger.info( + f"[upload_from_bytes] File uploaded successfully | " + f"filename: {filename}, file_uri: {file_uri}" + ) + + return file_uri + + finally: + # Clean up temporary file + os.unlink(tmp_path) + + except Exception as e: + logger.error( + f"[upload_from_bytes] Failed to upload file | " + f"filename: {filename}, error: {str(e)}" + ) + raise GeminiFilesError(f"Failed to upload file {filename}: {str(e)}") from e + + def upload_from_url( + self, + signed_url: str, + filename: str, + mime_type: str | None = None, + ) -> str: + """Upload audio from a signed URL to Google Files API. + + Args: + signed_url: Signed URL to download the audio file + filename: Display name for the file + mime_type: MIME type (auto-detected if not provided) + + Returns: + str: Full Google Files API URI (HTTPS format for batch API) + + Raises: + GeminiFilesError: If download or upload fails + """ + logger.info( + f"[upload_from_url] Downloading file from URL | filename: {filename}" + ) + + try: + # Download the file from the signed URL + response = requests.get(signed_url, timeout=300) # 5 minute timeout + response.raise_for_status() + + # Upload to Google Files API + return self.upload_from_bytes( + content=response.content, + filename=filename, + mime_type=mime_type, + ) + + except requests.RequestException as e: + logger.error( + f"[upload_from_url] Failed to download file | " + f"filename: {filename}, error: {str(e)}" + ) + raise GeminiFilesError( + f"Failed to download file {filename}: {str(e)}" + ) from e + + def upload_from_stream( + self, + stream: BinaryIO, + filename: str, + mime_type: str | None = None, + ) -> str: + """Upload audio from a stream to Google Files API. + + Args: + stream: Binary stream (e.g., from S3 StreamingBody) + filename: Display name for the file + mime_type: MIME type (auto-detected if not provided) + + Returns: + str: Full Google Files API URI (HTTPS format for batch API) + + Raises: + GeminiFilesError: If upload fails + """ + try: + content = stream.read() + return self.upload_from_bytes( + content=content, + filename=filename, + mime_type=mime_type, + ) + except Exception as e: + logger.error( + f"[upload_from_stream] Failed to read stream | " + f"filename: {filename}, error: {str(e)}" + ) + raise GeminiFilesError( + f"Failed to read stream for {filename}: {str(e)}" + ) from e + + def delete_file(self, file_uri: str) -> bool: + """Delete a file from Google Files API. 
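+
+        Example (illustrative; either URI form is accepted):
+
+            manager.delete_file("files/abc123")
+            manager.delete_file(
+                "https://generativelanguage.googleapis.com/v1beta/files/abc123"
+            )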
+ + Args: + file_uri: Full Google Files API URI or short form (files/xxx) + + Returns: + bool: True if deletion was successful + """ + # Extract the file name from the URI + if file_uri.startswith(self.FILES_API_BASE): + file_name = file_uri.replace(f"{self.FILES_API_BASE}/", "") + else: + file_name = file_uri + + logger.info(f"[delete_file] Deleting file | file_name: {file_name}") + + try: + self._client.files.delete(name=file_name) + logger.info( + f"[delete_file] File deleted successfully | file_name: {file_name}" + ) + return True + except Exception as e: + logger.warning( + f"[delete_file] Failed to delete file | " + f"file_name: {file_name}, error: {str(e)}" + ) + return False + + def get_file_status(self, file_uri: str) -> dict: + """Get the status of a file in Google Files API. + + Args: + file_uri: Full Google Files API URI or short form (files/xxx) + + Returns: + dict: File metadata including state + """ + # Extract the file name from the URI + if file_uri.startswith(self.FILES_API_BASE): + file_name = file_uri.replace(f"{self.FILES_API_BASE}/", "") + else: + file_name = file_uri + + try: + file_info = self._client.files.get(name=file_name) + return { + "name": file_info.name, + "display_name": file_info.display_name, + "mime_type": file_info.mime_type, + "size_bytes": file_info.size_bytes, + "state": file_info.state.name if file_info.state else None, + "uri": file_info.uri, + } + except Exception as e: + logger.error( + f"[get_file_status] Failed to get file status | " + f"file_name: {file_name}, error: {str(e)}" + ) + raise GeminiFilesError( + f"Failed to get file status for {file_name}: {str(e)}" + ) from e From f1df7f97b2001cc07d0fa7a8311ee2d44820c7a8 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 17:12:59 +0530 Subject: [PATCH 02/36] fix migration naming --- ...luation_tables.py => 043_add_stt_evaluation_tables.py} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename backend/app/alembic/versions/{042_add_stt_evaluation_tables.py => 043_add_stt_evaluation_tables.py} (99%) diff --git a/backend/app/alembic/versions/042_add_stt_evaluation_tables.py b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py similarity index 99% rename from backend/app/alembic/versions/042_add_stt_evaluation_tables.py rename to backend/app/alembic/versions/043_add_stt_evaluation_tables.py index ff95b078e..b76f2da9e 100644 --- a/backend/app/alembic/versions/042_add_stt_evaluation_tables.py +++ b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py @@ -1,7 +1,7 @@ """add stt evaluation tables -Revision ID: 042 -Revises: 041 +Revision ID: 043 +Revises: 042 Create Date: 2026-01-28 12:00:00.000000 """ @@ -12,8 +12,8 @@ from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
-revision = "042" -down_revision = "041" +revision = "043" +down_revision = "042" branch_labels = None depends_on = None From cda06118d679f1627716df394a0123fa69415412 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 17:45:54 +0530 Subject: [PATCH 03/36] fixing endpoints --- backend/app/api/main.py | 2 ++ backend/app/api/routes/stt_evaluations/dataset.py | 1 - backend/app/celery/tasks/stt_evaluation.py | 3 +-- backend/app/core/batch/__init__.py | 3 +++ backend/app/crud/stt_evaluations/__init__.py | 4 ++++ backend/pyproject.toml | 2 ++ 6 files changed, 12 insertions(+), 3 deletions(-) diff --git a/backend/app/api/main.py b/backend/app/api/main.py index bcd64eb58..804340fd1 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -25,6 +25,7 @@ collection_job, ) from app.api.routes.evaluations import dataset as evaluation_dataset, evaluation +from app.api.routes import stt_evaluations from app.core.config import settings api_router = APIRouter() @@ -39,6 +40,7 @@ api_router.include_router(doc_transformation_job.router) api_router.include_router(evaluation_dataset.router) api_router.include_router(evaluation.router) +api_router.include_router(stt_evaluations.router) api_router.include_router(llm.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index c38b22f76..b76535c14 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -14,7 +14,6 @@ list_stt_datasets, get_samples_by_dataset_id, get_sample_count_for_dataset, - update_dataset_metadata, ) from app.models.stt_evaluation import ( STTDatasetCreate, diff --git a/backend/app/celery/tasks/stt_evaluation.py b/backend/app/celery/tasks/stt_evaluation.py index 585431838..5a3bab815 100644 --- a/backend/app/celery/tasks/stt_evaluation.py +++ b/backend/app/celery/tasks/stt_evaluation.py @@ -8,8 +8,7 @@ from sqlmodel import Session, select from app.celery.celery_app import celery_app -from app.core.batch import GeminiBatchProvider -from app.core.batch.gemini import BatchJobState +from app.core.batch import BatchJobState, GeminiBatchProvider from app.core.db import engine from app.core.cloud.storage import get_cloud_storage from app.crud.stt_evaluations import ( diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py index 0cd10bd51..25a33b6d7 100644 --- a/backend/app/core/batch/__init__.py +++ b/backend/app/core/batch/__init__.py @@ -1,6 +1,7 @@ """Batch processing infrastructure for LLM providers.""" from .base import BatchProvider +from .gemini import BatchJobState, GeminiBatchProvider from .openai import OpenAIBatchProvider from .operations import ( download_batch_results, @@ -12,6 +13,8 @@ __all__ = [ "BatchProvider", + "BatchJobState", + "GeminiBatchProvider", "OpenAIBatchProvider", "start_batch_job", "download_batch_results", diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py index 6e883936c..24ee863f0 100644 --- a/backend/app/crud/stt_evaluations/__init__.py +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -6,6 +6,7 @@ get_stt_dataset_by_id, list_stt_datasets, get_samples_by_dataset_id, + get_sample_count_for_dataset, ) from .run import ( create_stt_run, @@ -19,6 +20,7 @@ get_results_by_run_id, update_stt_result, update_human_feedback, + count_results_by_status, ) __all__ = [ @@ -28,6 +30,7 @@ 
"get_stt_dataset_by_id", "list_stt_datasets", "get_samples_by_dataset_id", + "get_sample_count_for_dataset", # Run "create_stt_run", "get_stt_run_by_id", @@ -39,4 +42,5 @@ "get_results_by_run_id", "update_stt_result", "update_human_feedback", + "count_results_by_status", ] diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 6030fc0a1..119ee9116 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -36,6 +36,8 @@ dependencies = [ "celery>=5.3.0,<6.0.0", "redis>=5.0.0,<6.0.0", "flower>=2.0.1", + "google-genai>=1.0.0", + "requests>=2.31.0", ] [tool.uv] From ad5779f96372c8d555db092493322900545bd343 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 17:51:41 +0530 Subject: [PATCH 04/36] update dataset endpoint --- backend/app/api/routes/stt_evaluations/dataset.py | 3 +-- backend/app/crud/stt_evaluations/dataset.py | 3 --- backend/app/models/stt_evaluation.py | 7 ------- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index b76535c14..c72b10984 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -39,8 +39,7 @@ Each sample requires: - **object_store_url**: S3 URL of the audio file (from /evaluations/stt/files/audio endpoint) -- **language**: ISO 639-1 language code (optional) -- **ground_truth**: Reference transcription (optional, for Phase 2 WER/CER) +- **ground_truth**: Reference transcription (optional, for WER/CER metrics) """, ) def create_dataset( diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index 35b8f6b80..ae9e4604e 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -119,10 +119,7 @@ def create_stt_samples( for sample_data in samples: sample = STTSample( object_store_url=sample_data.object_store_url, - language=sample_data.language, ground_truth=sample_data.ground_truth, - duration_seconds=sample_data.duration_seconds, - sample_metadata=sample_data.sample_metadata or {}, dataset_id=dataset_id, organization_id=org_id, project_id=project_id, diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 24a419af2..12a3d7719 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -270,16 +270,9 @@ class STTSampleCreate(BaseModel): """Request model for creating an STT sample.""" object_store_url: str = Field(..., description="S3 URL of the audio file") - language: str | None = Field(None, description="ISO 639-1 language code") ground_truth: str | None = Field( None, description="Reference transcription (optional)" ) - duration_seconds: float | None = Field( - None, description="Audio duration in seconds" - ) - sample_metadata: dict[str, Any] | None = Field( - None, description="Additional metadata" - ) class STTSamplePublic(BaseModel): From 01e2beb0d999dfec72cc5282e351e799b24ee440 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 17:55:32 +0530 Subject: [PATCH 05/36] update types --- backend/app/models/evaluation.py | 38 ++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 6ae4542fb..8b24fe9e4 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -101,6 +101,18 @@ class EvaluationDataset(SQLModel, table=True): description="Optional 
description of the dataset", sa_column_kwargs={"comment": "Description of the dataset"}, ) + type: str = SQLField( + default="text", + max_length=20, + description="Evaluation type: text, stt, or tts", + sa_column_kwargs={"comment": "Evaluation type: text, stt, or tts"}, + ) + language: str | None = SQLField( + default=None, + max_length=10, + description="ISO 639-1 language code (e.g., en, hi)", + sa_column_kwargs={"comment": "ISO 639-1 language code (e.g., en, hi)"}, + ) # Dataset metadata stored as JSONB dataset_metadata: dict[str, Any] = SQLField( @@ -193,6 +205,32 @@ class EvaluationRun(SQLModel, table=True): description="Name of the Langfuse dataset", sa_column_kwargs={"comment": "Name of the Langfuse dataset used"}, ) + type: str = SQLField( + default="text", + max_length=20, + description="Evaluation type: text, stt, or tts", + sa_column_kwargs={"comment": "Evaluation type: text, stt, or tts"}, + ) + language: str | None = SQLField( + default=None, + max_length=10, + description="ISO 639-1 language code", + sa_column_kwargs={"comment": "ISO 639-1 language code"}, + ) + providers: list[str] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="List of STT/TTS providers used (e.g., ['gemini-2.5-pro'])", + ), + description="List of STT/TTS providers used", + ) + processed_samples: int = SQLField( + default=0, + description="Number of samples processed so far", + sa_column_kwargs={"comment": "Number of samples processed so far"}, + ) config_id: UUID = SQLField( foreign_key="config.id", From 1637007660581587c76768214cb73968d02e3f37 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 17:59:28 +0530 Subject: [PATCH 06/36] updated dataset with URL --- backend/app/api/routes/stt_evaluations/dataset.py | 2 ++ backend/app/crud/stt_evaluations/dataset.py | 4 ++++ backend/app/models/stt_evaluation.py | 1 + 3 files changed, 7 insertions(+) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index c72b10984..58577e892 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -85,6 +85,7 @@ def create_dataset( description=dataset.description, type=dataset.type, language=dataset.language, + object_store_url=dataset.object_store_url, dataset_metadata=dataset.dataset_metadata, sample_count=len(samples), organization_id=dataset.organization_id, @@ -190,6 +191,7 @@ def get_dataset( description=dataset.description, type=dataset.type, language=dataset.language, + object_store_url=dataset.object_store_url, dataset_metadata=dataset.dataset_metadata, sample_count=samples_total, organization_id=dataset.organization_id, diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index ae9e4604e..8de773879 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -28,6 +28,7 @@ def create_stt_dataset( project_id: int, description: str | None = None, language: str | None = None, + object_store_url: str | None = None, dataset_metadata: dict[str, Any] | None = None, ) -> EvaluationDataset: """Create a new STT evaluation dataset. 
@@ -39,6 +40,7 @@ def create_stt_dataset( project_id: Project ID description: Optional description language: Optional default language code + object_store_url: Optional object store URL dataset_metadata: Optional metadata dict Returns: @@ -57,6 +59,7 @@ def create_stt_dataset( description=description, type=EvaluationType.STT.value, language=language, + object_store_url=object_store_url, dataset_metadata=dataset_metadata or {}, organization_id=org_id, project_id=project_id, @@ -227,6 +230,7 @@ def list_stt_datasets( description=dataset.description, type=dataset.type, language=dataset.language, + object_store_url=dataset.object_store_url, dataset_metadata=dataset.dataset_metadata, sample_count=sample_count, organization_id=dataset.organization_id, diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 12a3d7719..8286ccbbb 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -344,6 +344,7 @@ class STTDatasetPublic(BaseModel): description: str | None type: str language: str | None + object_store_url: str | None dataset_metadata: dict[str, Any] sample_count: int = Field(0, description="Number of samples in the dataset") organization_id: int From 36af7e998c079224100bfd4f3b48f22af976f31b Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 21:28:56 +0530 Subject: [PATCH 07/36] added few more testcases --- .../tests/api/routes/test_stt_evaluation.py | 534 ++++++++++++++++++ 1 file changed, 534 insertions(+) create mode 100644 backend/app/tests/api/routes/test_stt_evaluation.py diff --git a/backend/app/tests/api/routes/test_stt_evaluation.py b/backend/app/tests/api/routes/test_stt_evaluation.py new file mode 100644 index 000000000..9352e0726 --- /dev/null +++ b/backend/app/tests/api/routes/test_stt_evaluation.py @@ -0,0 +1,534 @@ +"""Tests for STT evaluation API routes.""" + +import pytest +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.models import EvaluationDataset +from app.models.stt_evaluation import STTSample, EvaluationType +from app.tests.utils.auth import TestAuthContext +from app.core.util import now + + +# Helper functions +def create_test_stt_dataset( + db: Session, + organization_id: int, + project_id: int, + name: str = "test_stt_dataset", + description: str | None = None, + language: str | None = "en", +) -> EvaluationDataset: + """Create a test STT dataset.""" + dataset = EvaluationDataset( + name=name, + description=description, + type=EvaluationType.STT.value, + language=language, + dataset_metadata={"sample_count": 0, "has_ground_truth_count": 0}, + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(dataset) + db.commit() + db.refresh(dataset) + return dataset + + +def create_test_stt_sample( + db: Session, + dataset_id: int, + organization_id: int, + project_id: int, + object_store_url: str = "s3://test-bucket/audio/test.mp3", + ground_truth: str | None = None, +) -> STTSample: + """Create a test STT sample.""" + sample = STTSample( + object_store_url=object_store_url, + ground_truth=ground_truth, + dataset_id=dataset_id, + organization_id=organization_id, + project_id=project_id, + inserted_at=now(), + updated_at=now(), + ) + db.add(sample) + db.commit() + db.refresh(sample) + return sample + + +class TestSTTDatasetCreate: + """Test POST /evaluations/stt/datasets endpoint.""" + + def test_create_stt_dataset_success( + self, + client: TestClient, + user_api_key_header: dict[str, str], + 
db: Session, + ) -> None: + """Test creating an STT dataset with samples.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "test_stt_dataset_create", + "description": "Test STT dataset", + "language": "en", + "samples": [ + {"object_store_url": "s3://bucket/audio1.mp3"}, + { + "object_store_url": "s3://bucket/audio2.mp3", + "ground_truth": "Hello world", + }, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + + assert data["name"] == "test_stt_dataset_create" + assert data["description"] == "Test STT dataset" + assert data["type"] == "stt" + assert data["language"] == "en" + assert data["sample_count"] == 2 + assert data["dataset_metadata"]["has_ground_truth_count"] == 1 + + def test_create_stt_dataset_minimal( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test creating an STT dataset with minimal fields.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "minimal_stt_dataset", + "samples": [ + {"object_store_url": "s3://bucket/audio.mp3"}, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 200, response.text + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + + assert data["name"] == "minimal_stt_dataset" + assert data["description"] is None + assert data["language"] is None + assert data["sample_count"] == 1 + + def test_create_stt_dataset_empty_samples( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test creating an STT dataset with empty samples fails.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "empty_samples_dataset", + "samples": [], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 422 + + def test_create_stt_dataset_missing_name( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test creating an STT dataset without name fails.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "samples": [ + {"object_store_url": "s3://bucket/audio.mp3"}, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 422 + + def test_create_stt_dataset_without_authentication( + self, + client: TestClient, + ) -> None: + """Test creating an STT dataset without authentication fails.""" + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "unauthenticated_dataset", + "samples": [ + {"object_store_url": "s3://bucket/audio.mp3"}, + ], + }, + ) + + assert response.status_code == 401 + + def test_create_stt_dataset_duplicate_name( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test creating an STT dataset with duplicate name fails.""" + # Create first dataset + create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="duplicate_name_test", + ) + + # Try to create another with same name + response = client.post( + "/api/v1/evaluations/stt/datasets", + json={ + "name": "duplicate_name_test", + "samples": [ + {"object_store_url": "s3://bucket/audio.mp3"}, + ], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 400 + response_data = response.json() + error_str = 
response_data.get("detail", response_data.get("error", "")) + assert "already exists" in error_str.lower() + + +class TestSTTDatasetList: + """Test GET /evaluations/stt/datasets endpoint.""" + + def test_list_stt_datasets_empty( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test listing STT datasets when none exist.""" + response = client.get( + "/api/v1/evaluations/stt/datasets", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert isinstance(response_data["data"], list) + + def test_list_stt_datasets_with_data( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test listing STT datasets with data.""" + # Create test datasets + dataset1 = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="list_test_dataset_1", + ) + create_test_stt_sample( + db=db, + dataset_id=dataset1.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + dataset2 = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="list_test_dataset_2", + ) + create_test_stt_sample( + db=db, + dataset_id=dataset2.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + response = client.get( + "/api/v1/evaluations/stt/datasets", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + assert len(data) >= 2 + + # Check that our datasets are in the list + names = [d["name"] for d in data] + assert "list_test_dataset_1" in names + assert "list_test_dataset_2" in names + + def test_list_stt_datasets_pagination( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test pagination for listing STT datasets.""" + # Create multiple datasets + for i in range(5): + create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name=f"pagination_test_dataset_{i}", + ) + + # Test with limit + response = client.get( + "/api/v1/evaluations/stt/datasets", + params={"limit": 2, "offset": 0}, + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert len(response_data["data"]) == 2 + assert response_data["metadata"]["limit"] == 2 + assert response_data["metadata"]["offset"] == 0 + + +class TestSTTDatasetGet: + """Test GET /evaluations/stt/datasets/{dataset_id} endpoint.""" + + def test_get_stt_dataset_success( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test getting an STT dataset by ID.""" + dataset = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="get_test_dataset", + description="Test description", + ) + sample = create_test_stt_sample( + db=db, + dataset_id=dataset.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ground_truth="Test transcription", + ) + + response = client.get( + f"/api/v1/evaluations/stt/datasets/{dataset.id}", + headers=user_api_key_header, + 
) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + data = response_data["data"] + + assert data["id"] == dataset.id + assert data["name"] == "get_test_dataset" + assert data["description"] == "Test description" + assert data["type"] == "stt" + assert len(data["samples"]) == 1 + assert data["samples"][0]["id"] == sample.id + assert data["samples"][0]["ground_truth"] == "Test transcription" + + def test_get_stt_dataset_not_found( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test getting a non-existent STT dataset.""" + response = client.get( + "/api/v1/evaluations/stt/datasets/99999", + headers=user_api_key_header, + ) + + assert response.status_code == 404 + + def test_get_stt_dataset_without_samples( + self, + client: TestClient, + user_api_key_header: dict[str, str], + db: Session, + user_api_key: TestAuthContext, + ) -> None: + """Test getting an STT dataset without including samples.""" + dataset = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="get_no_samples_dataset", + ) + create_test_stt_sample( + db=db, + dataset_id=dataset.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + ) + + response = client.get( + f"/api/v1/evaluations/stt/datasets/{dataset.id}", + params={"include_samples": False}, + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + data = response_data["data"] + + assert data["id"] == dataset.id + assert data["samples"] == [] + assert data["sample_count"] == 1 # Count should still be correct + + +class TestSTTEvaluationRun: + """Test STT evaluation run endpoints.""" + + @pytest.fixture + def test_dataset_with_samples( + self, db: Session, user_api_key: TestAuthContext + ) -> EvaluationDataset: + """Create a test dataset with samples for evaluation.""" + dataset = create_test_stt_dataset( + db=db, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + name="eval_test_dataset", + ) + # Create some samples + for i in range(3): + create_test_stt_sample( + db=db, + dataset_id=dataset.id, + organization_id=user_api_key.organization_id, + project_id=user_api_key.project_id, + object_store_url=f"s3://bucket/audio_{i}.mp3", + ) + return dataset + + def test_start_stt_evaluation_invalid_dataset( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test starting an STT evaluation with invalid dataset ID.""" + response = client.post( + "/api/v1/evaluations/stt/runs", + json={ + "run_name": "test_run", + "dataset_id": 99999, + "providers": ["gemini-2.5-pro"], + }, + headers=user_api_key_header, + ) + + assert response.status_code == 404 + response_data = response.json() + error_str = response_data.get("detail", response_data.get("error", "")) + assert "not found" in error_str.lower() + + def test_start_stt_evaluation_without_authentication( + self, + client: TestClient, + ) -> None: + """Test starting an STT evaluation without authentication.""" + response = client.post( + "/api/v1/evaluations/stt/runs", + json={ + "run_name": "test_run", + "dataset_id": 1, + "providers": ["gemini-2.5-pro"], + }, + ) + + assert response.status_code == 401 + + def test_list_stt_runs_empty( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test listing STT runs when none exist.""" + response = client.get( + 
"/api/v1/evaluations/stt/runs", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["success"] is True + assert isinstance(response_data["data"], list) + + def test_get_stt_run_not_found( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test getting a non-existent STT run.""" + response = client.get( + "/api/v1/evaluations/stt/runs/99999", + headers=user_api_key_header, + ) + + assert response.status_code == 404 + + +class TestSTTResultFeedback: + """Test STT result feedback endpoint.""" + + def test_update_feedback_not_found( + self, + client: TestClient, + user_api_key_header: dict[str, str], + ) -> None: + """Test updating feedback for non-existent result.""" + response = client.patch( + "/api/v1/evaluations/stt/results/99999", + json={ + "is_correct": True, + "comment": "Test comment", + }, + headers=user_api_key_header, + ) + + assert response.status_code == 404 + + def test_update_feedback_without_authentication( + self, + client: TestClient, + ) -> None: + """Test updating feedback without authentication.""" + response = client.patch( + "/api/v1/evaluations/stt/results/1", + json={ + "is_correct": True, + }, + ) + + assert response.status_code == 401 From 78fd206b4b31c67a5a60b5952649fef2e9af5cce Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 22:03:51 +0530 Subject: [PATCH 08/36] added storage to core for easy reuse --- .../app/api/routes/stt_evaluations/dataset.py | 31 +-- backend/app/core/storage_utils.py | 137 +++++++------- .../app/services/stt_evaluations/dataset.py | 177 ++++++++++++++++++ backend/uv.lock | 151 ++++++++++----- 4 files changed, 356 insertions(+), 140 deletions(-) create mode 100644 backend/app/services/stt_evaluations/dataset.py diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index 58577e892..cd456866d 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -1,15 +1,12 @@ """STT dataset API routes.""" import logging -from typing import Any from fastapi import APIRouter, Body, Depends, HTTPException, Query from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission from app.crud.stt_evaluations import ( - create_stt_dataset, - create_stt_samples, get_stt_dataset_by_id, list_stt_datasets, get_samples_by_dataset_id, @@ -19,9 +16,9 @@ STTDatasetCreate, STTDatasetPublic, STTDatasetWithSamples, - STTSampleCreate, STTSamplePublic, ) +from app.services.stt_evaluations.dataset import upload_stt_dataset from app.utils import APIResponse logger = logging.getLogger(__name__) @@ -48,34 +45,14 @@ def create_dataset( dataset_create: STTDatasetCreate = Body(...), ) -> APIResponse[STTDatasetPublic]: """Create an STT evaluation dataset.""" - logger.info( - f"[create_dataset] Creating STT dataset | " - f"name: {dataset_create.name}, sample_count: {len(dataset_create.samples)}" - ) - - # Create dataset - dataset = create_stt_dataset( + dataset, samples = upload_stt_dataset( session=_session, name=dataset_create.name, - org_id=auth_context.organization_.id, + samples=dataset_create.samples, + organization_id=auth_context.organization_.id, project_id=auth_context.project_.id, description=dataset_create.description, language=dataset_create.language, - dataset_metadata={ - "sample_count": len(dataset_create.samples), - "has_ground_truth_count": sum( - 1 for s in 
dataset_create.samples if s.ground_truth - ), - }, - ) - - # Create samples - samples = create_stt_samples( - session=_session, - dataset_id=dataset.id, - org_id=auth_context.organization_.id, - project_id=auth_context.project_.id, - samples=dataset_create.samples, ) return APIResponse.success_response( diff --git a/backend/app/core/storage_utils.py b/backend/app/core/storage_utils.py index 63830d7d0..6130ca881 100644 --- a/backend/app/core/storage_utils.py +++ b/backend/app/core/storage_utils.py @@ -19,20 +19,25 @@ logger = logging.getLogger(__name__) -def upload_csv_to_object_store( +def upload_to_object_store( storage: CloudStorage, - csv_content: bytes, + content: bytes, filename: str, - subdirectory: str = "datasets", + subdirectory: str, + content_type: str = "application/octet-stream", ) -> str | None: """ - Upload CSV content to object store. + Upload content to object store. + + This is the generic upload function that handles any content type. + Use this directly or through convenience wrappers like upload_csv_to_object_store. Args: storage: CloudStorage instance - csv_content: Raw CSV content as bytes - filename: Name of the file (can include timestamp) - subdirectory: Subdirectory path in object store (default: "datasets") + content: Raw content as bytes + filename: Name of the file + subdirectory: Subdirectory path in object store (e.g., "datasets", "stt_datasets") + content_type: MIME type of the content (default: "application/octet-stream") Returns: Object store URL as string if successful, None if failed @@ -42,47 +47,74 @@ def upload_csv_to_object_store( Callers should continue without object store URL when this returns None. """ logger.info( - f"[upload_csv_to_object_store] Preparing to upload '{filename}' | " - f"size={len(csv_content)} bytes, subdirectory='{subdirectory}'" + f"[upload_to_object_store] Preparing to upload '{filename}' | " + f"size={len(content)} bytes, subdirectory='{subdirectory}', " + f"content_type='{content_type}'" ) try: - # Create file path file_path = Path(subdirectory) / filename - # Create a mock UploadFile-like object for the storage put method - class CSVFile: - def __init__(self, content: bytes): - self.file = io.BytesIO(content) - self.content_type = "text/csv" - - csv_file = CSVFile(csv_content) + headers = Headers({"content-type": content_type}) + upload_file = UploadFile( + filename=filename, + file=BytesIO(content), + headers=headers, + ) - # Upload to object store - destination = storage.put(source=csv_file, file_path=file_path) + destination = storage.put(source=upload_file, file_path=file_path) object_store_url = str(destination) logger.info( - f"[upload_csv_to_object_store] Upload successful | " + f"[upload_to_object_store] Upload successful | " f"filename='{filename}', url='{object_store_url}'" ) return object_store_url except CloudStorageError as e: logger.warning( - f"[upload_csv_to_object_store] Upload failed for '{filename}': {e}. " + f"[upload_to_object_store] Upload failed for '{filename}': {e}. " "Continuing without object store storage." ) return None except Exception as e: logger.warning( - f"[upload_csv_to_object_store] Unexpected error uploading '{filename}': {e}. " + f"[upload_to_object_store] Unexpected error uploading '{filename}': {e}. " "Continuing without object store storage.", exc_info=True, ) return None +def upload_csv_to_object_store( + storage: CloudStorage, + csv_content: bytes, + filename: str, + subdirectory: str = "datasets", +) -> str | None: + """ + Upload CSV content to object store. 
+ + Convenience wrapper around upload_to_object_store for CSV files. + + Args: + storage: CloudStorage instance + csv_content: Raw CSV content as bytes + filename: Name of the file (can include timestamp) + subdirectory: Subdirectory path in object store (default: "datasets") + + Returns: + Object store URL as string if successful, None if failed + """ + return upload_to_object_store( + storage=storage, + content=csv_content, + filename=filename, + subdirectory=subdirectory, + content_type="text/csv", + ) + + def upload_jsonl_to_object_store( storage: CloudStorage, results: list[dict], @@ -92,65 +124,28 @@ def upload_jsonl_to_object_store( """ Upload JSONL (JSON Lines) content to object store. + Convenience wrapper around upload_to_object_store for JSONL files. + Args: storage: CloudStorage instance results: List of dictionaries to be converted to JSONL filename: Name of the file - subdirectory: Subdirectory path in object store (e.g., "evaluation/batch-123") + subdirectory: Subdirectory path in object store Returns: Object store URL as string if successful, None if failed - - Note: - This function handles errors gracefully and returns None on failure. - Callers should continue without object store URL when this returns None. """ - logger.info( - f"[upload_jsonl_to_object_store] Preparing to upload '{filename}' | " - f"items={len(results)}, subdirectory='{subdirectory}'" + jsonl_content = "\n".join([json.dumps(result) for result in results]) + content_bytes = jsonl_content.encode("utf-8") + + return upload_to_object_store( + storage=storage, + content=content_bytes, + filename=filename, + subdirectory=subdirectory, + content_type="application/jsonl", ) - try: - # Create file path - file_path = Path(subdirectory) / filename - - # Convert results to JSONL - jsonl_content = "\n".join([json.dumps(result) for result in results]) - content_bytes = jsonl_content.encode("utf-8") - - # Create UploadFile-like object - headers = Headers({"content-type": "application/jsonl"}) - upload_file = UploadFile( - filename=filename, - file=BytesIO(content_bytes), - headers=headers, - ) - - # Upload to object store - destination = storage.put(source=upload_file, file_path=file_path) - object_store_url = str(destination) - - logger.info( - f"[upload_jsonl_to_object_store] Upload successful | " - f"filename='{filename}', url='{object_store_url}', " - f"size={len(content_bytes)} bytes" - ) - return object_store_url - - except CloudStorageError as e: - logger.warning( - f"[upload_jsonl_to_object_store] Upload failed for '{filename}': {e}. " - "Continuing without object store storage." - ) - return None - except Exception as e: - logger.warning( - f"[upload_jsonl_to_object_store] Unexpected error uploading '{filename}': {e}. 
" - "Continuing without object store storage.", - exc_info=True, - ) - return None - def generate_timestamped_filename(base_name: str, extension: str = "csv") -> str: """ diff --git a/backend/app/services/stt_evaluations/dataset.py b/backend/app/services/stt_evaluations/dataset.py new file mode 100644 index 000000000..205ce9990 --- /dev/null +++ b/backend/app/services/stt_evaluations/dataset.py @@ -0,0 +1,177 @@ +"""Dataset management service for STT evaluations.""" + +import csv +import io +import logging + +from sqlmodel import Session + +from app.core.cloud import get_cloud_storage +from app.core.storage_utils import ( + generate_timestamped_filename, + upload_csv_to_object_store, +) +from app.crud.stt_evaluations import ( + create_stt_dataset, + create_stt_samples, +) +from app.models import EvaluationDataset +from app.models.stt_evaluation import STTSample, STTSampleCreate + +logger = logging.getLogger(__name__) + + +def upload_stt_dataset( + session: Session, + name: str, + samples: list[STTSampleCreate], + organization_id: int, + project_id: int, + description: str | None = None, + language: str | None = None, +) -> tuple[EvaluationDataset, list[STTSample]]: + """ + Orchestrate STT dataset upload workflow. + + Steps: + 1. Convert samples to CSV format + 2. Upload CSV to object store + 3. Create dataset record in database + 4. Create sample records in database + + Args: + session: Database session + name: Dataset name + samples: List of STT samples to create + organization_id: Organization ID + project_id: Project ID + description: Optional dataset description + language: Optional default language code + + Returns: + Tuple of (created dataset, created samples) + """ + logger.info( + f"[upload_stt_dataset] Uploading STT dataset | name={name} | " + f"sample_count={len(samples)} | org_id={organization_id} | " + f"project_id={project_id}" + ) + + # Step 1: Convert samples to CSV and upload to object store + object_store_url = _upload_samples_to_object_store( + session=session, + project_id=project_id, + dataset_name=name, + samples=samples, + ) + + # Step 2: Calculate metadata + metadata = { + "sample_count": len(samples), + "has_ground_truth_count": sum(1 for s in samples if s.ground_truth), + } + + # Step 3: Create dataset record + dataset = create_stt_dataset( + session=session, + name=name, + org_id=organization_id, + project_id=project_id, + description=description, + language=language, + object_store_url=object_store_url, + dataset_metadata=metadata, + ) + + logger.info( + f"[upload_stt_dataset] Created dataset record | " + f"id={dataset.id} | name={name}" + ) + + # Step 4: Create sample records + created_samples = create_stt_samples( + session=session, + dataset_id=dataset.id, + org_id=organization_id, + project_id=project_id, + samples=samples, + ) + + logger.info( + f"[upload_stt_dataset] Created sample records | " + f"dataset_id={dataset.id} | sample_count={len(created_samples)}" + ) + + return dataset, created_samples + + +def _upload_samples_to_object_store( + session: Session, + project_id: int, + dataset_name: str, + samples: list[STTSampleCreate], +) -> str | None: + """ + Upload STT samples as CSV to object store. 
+ + Args: + session: Database session + project_id: Project ID for storage credentials + dataset_name: Dataset name for filename + samples: List of samples to upload + + Returns: + Object store URL if successful, None otherwise + """ + try: + storage = get_cloud_storage(session=session, project_id=project_id) + + # Convert samples to CSV format + csv_content = _samples_to_csv(samples) + + # Generate filename and upload + filename = generate_timestamped_filename(dataset_name, "csv") + object_store_url = upload_csv_to_object_store( + storage=storage, + csv_content=csv_content, + filename=filename, + subdirectory="stt_datasets", + ) + + if object_store_url: + logger.info( + f"[_upload_samples_to_object_store] Upload successful | " + f"url={object_store_url}" + ) + else: + logger.info( + "[_upload_samples_to_object_store] Upload returned None | " + "continuing without object store storage" + ) + + return object_store_url + + except Exception as e: + logger.warning( + f"[_upload_samples_to_object_store] Failed to upload | {e}", + exc_info=True, + ) + return None + + +def _samples_to_csv(samples: list[STTSampleCreate]) -> bytes: + """ + Convert STT samples to CSV format. + + Args: + samples: List of samples + + Returns: + CSV content as bytes + """ + output = io.StringIO() + writer = csv.writer(output) + writer.writerow(["object_store_url", "ground_truth"]) + for sample in samples: + writer.writerow([sample.object_store_url, sample.ground_truth or ""]) + return output.getvalue().encode("utf-8") diff --git a/backend/uv.lock b/backend/uv.lock index fb79c631c..667e43f0d 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12, <4.0" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -206,6 +206,7 @@ dependencies = [ { name = "emails" }, { name = "fastapi", extra = ["standard"] }, { name = "flower" }, + { name = "google-genai" }, { name = "httpx" }, { name = "jinja2" }, { name = "langfuse" }, @@ -224,6 +225,7 @@ dependencies = [ { name = "pytest" }, { name = "python-multipart" }, { name = "redis" }, + { name = "requests" }, { name = "scikit-learn" }, { name = "sentry-sdk", extra = ["fastapi"] }, { name = "sqlmodel" }, @@ -252,6 +254,7 @@ requires-dist = [ { name = "emails", specifier = ">=0.6,<1.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.116.0" }, { name = "flower", specifier = ">=2.0.1" }, + { name = "google-genai", specifier = ">=1.0.0" }, { name = "httpx", specifier = ">=0.25.1,<1.0.0" }, { name = "jinja2", specifier = ">=3.1.4,<4.0.0" }, { name = "langfuse", specifier = "==2.60.3" }, @@ -270,6 +273,7 @@ requires-dist = [ { name = "pytest", specifier = ">=7.4.4" }, { name = "python-multipart", specifier = ">=0.0.22,<1.0.0" }, { name = "redis", specifier = ">=5.0.0,<6.0.0" }, + { name = "requests", specifier = ">=2.31.0" }, { name = "scikit-learn", specifier = ">=1.7.1" }, { name = "sentry-sdk", extras = ["fastapi"], specifier = ">=2.20.0" }, { name = "sqlmodel", specifier = ">=0.0.21,<1.0.0" }, @@ -1108,6 +1112,46 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/c9/97cc5aae1648dcb851958a3ddf73ccd7dbe5650d95203ecb4d7720b4cdbf/fsspec-2026.1.0-py3-none-any.whl", hash = "sha256:cb76aa913c2285a3b49bdd5fc55b1d7c708d7208126b60f2eb8194fe1b4cbdcc", size = 201838, upload-time = "2026-01-09T15:21:34.041Z" }, ] +[[package]] +name = "google-auth" +version = "2.48.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" 
}, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0c/41/242044323fbd746615884b1c16639749e73665b718209946ebad7ba8a813/google_auth-2.48.0.tar.gz", hash = "sha256:4f7e706b0cd3208a3d940a19a822c37a476ddba5450156c3e6624a71f7c841ce", size = 326522, upload-time = "2026-01-26T19:22:47.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/1d/d6466de3a5249d35e832a52834115ca9d1d0de6abc22065f049707516d47/google_auth-2.48.0-py3-none-any.whl", hash = "sha256:2e2a537873d449434252a9632c28bfc268b0adb1e53f9fb62afc5333a975903f", size = 236499, upload-time = "2026-01-26T19:22:45.099Z" }, +] + +[package.optional-dependencies] +requests = [ + { name = "requests" }, +] + +[[package]] +name = "google-genai" +version = "1.61.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "google-auth", extra = ["requests"] }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "sniffio" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/38/421cd7e70952a536be87a0249409f87297d84f523754a25b08fe94b97e7f/google_genai-1.61.0.tar.gz", hash = "sha256:5773a4e8ad5b2ebcd54a633a67d8e9c4f413032fef07977ee47ffa34a6d3bbdf", size = 489672, upload-time = "2026-01-30T20:50:27.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/87/78dd70cb59f7acf3350f53c5144a7aa7bc39c6f425cd7dc1224b59fcdac3/google_genai-1.61.0-py3-none-any.whl", hash = "sha256:cb073ef8287581476c1c3f4d8e735426ee34478e500a56deef218fa93071e3ca", size = 721948, upload-time = "2026-01-30T20:50:25.551Z" }, +] + [[package]] name = "greenlet" version = "3.3.1" @@ -1117,6 +1161,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/c8/9d76a66421d1ae24340dfae7e79c313957f6e3195c144d2c73333b5bfe34/greenlet-3.3.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:7e806ca53acf6d15a888405880766ec84721aa4181261cd11a457dfe9a7a4975", size = 276443, upload-time = "2026-01-23T15:30:10.066Z" }, { url = "https://files.pythonhosted.org/packages/81/99/401ff34bb3c032d1f10477d199724f5e5f6fbfb59816ad1455c79c1eb8e7/greenlet-3.3.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d842c94b9155f1c9b3058036c24ffb8ff78b428414a19792b2380be9cecf4f36", size = 597359, upload-time = "2026-01-23T16:00:57.394Z" }, { url = "https://files.pythonhosted.org/packages/2b/bc/4dcc0871ed557792d304f50be0f7487a14e017952ec689effe2180a6ff35/greenlet-3.3.1-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20fedaadd422fa02695f82093f9a98bad3dab5fcda793c658b945fcde2ab27ba", size = 607805, upload-time = "2026-01-23T16:05:28.068Z" }, + { url = "https://files.pythonhosted.org/packages/3b/cd/7a7ca57588dac3389e97f7c9521cb6641fd8b6602faf1eaa4188384757df/greenlet-3.3.1-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c620051669fd04ac6b60ebc70478210119c56e2d5d5df848baec4312e260e4ca", size = 622363, upload-time = "2026-01-23T16:15:54.754Z" }, { url = "https://files.pythonhosted.org/packages/cf/05/821587cf19e2ce1f2b24945d890b164401e5085f9d09cbd969b0c193cd20/greenlet-3.3.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14194f5f4305800ff329cbf02c5fcc88f01886cadd29941b807668a45f0d2336", size = 609947, upload-time = "2026-01-23T15:32:51.004Z" }, { url = 
"https://files.pythonhosted.org/packages/a4/52/ee8c46ed9f8babaa93a19e577f26e3d28a519feac6350ed6f25f1afee7e9/greenlet-3.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7b2fe4150a0cf59f847a67db8c155ac36aed89080a6a639e9f16df5d6c6096f1", size = 1567487, upload-time = "2026-01-23T16:04:22.125Z" }, { url = "https://files.pythonhosted.org/packages/8f/7c/456a74f07029597626f3a6db71b273a3632aecb9afafeeca452cfa633197/greenlet-3.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:49f4ad195d45f4a66a0eb9c1ba4832bb380570d361912fa3554746830d332149", size = 1636087, upload-time = "2026-01-23T15:33:47.486Z" }, @@ -1125,6 +1170,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/ab/d26750f2b7242c2b90ea2ad71de70cfcd73a948a49513188a0fc0d6fc15a/greenlet-3.3.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:7ab327905cabb0622adca5971e488064e35115430cec2c35a50fd36e72a315b3", size = 275205, upload-time = "2026-01-23T15:30:24.556Z" }, { url = "https://files.pythonhosted.org/packages/10/d3/be7d19e8fad7c5a78eeefb2d896a08cd4643e1e90c605c4be3b46264998f/greenlet-3.3.1-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:65be2f026ca6a176f88fb935ee23c18333ccea97048076aef4db1ef5bc0713ac", size = 599284, upload-time = "2026-01-23T16:00:58.584Z" }, { url = "https://files.pythonhosted.org/packages/ae/21/fe703aaa056fdb0f17e5afd4b5c80195bbdab701208918938bd15b00d39b/greenlet-3.3.1-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7a3ae05b3d225b4155bda56b072ceb09d05e974bc74be6c3fc15463cf69f33fd", size = 610274, upload-time = "2026-01-23T16:05:29.312Z" }, + { url = "https://files.pythonhosted.org/packages/06/00/95df0b6a935103c0452dad2203f5be8377e551b8466a29650c4c5a5af6cc/greenlet-3.3.1-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:12184c61e5d64268a160226fb4818af4df02cfead8379d7f8b99a56c3a54ff3e", size = 624375, upload-time = "2026-01-23T16:15:55.915Z" }, { url = "https://files.pythonhosted.org/packages/cb/86/5c6ab23bb3c28c21ed6bebad006515cfe08b04613eb105ca0041fecca852/greenlet-3.3.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6423481193bbbe871313de5fd06a082f2649e7ce6e08015d2a76c1e9186ca5b3", size = 612904, upload-time = "2026-01-23T15:32:52.317Z" }, { url = "https://files.pythonhosted.org/packages/c2/f3/7949994264e22639e40718c2daf6f6df5169bf48fb038c008a489ec53a50/greenlet-3.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:33a956fe78bbbda82bfc95e128d61129b32d66bcf0a20a1f0c08aa4839ffa951", size = 1567316, upload-time = "2026-01-23T16:04:23.316Z" }, { url = "https://files.pythonhosted.org/packages/8d/6e/d73c94d13b6465e9f7cd6231c68abde838bb22408596c05d9059830b7872/greenlet-3.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b065d3284be43728dd280f6f9a13990b56470b81be20375a207cdc814a983f2", size = 1636549, upload-time = "2026-01-23T15:33:48.643Z" }, @@ -1133,6 +1179,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/fb/011c7c717213182caf78084a9bea51c8590b0afda98001f69d9f853a495b/greenlet-3.3.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:bd59acd8529b372775cd0fcbc5f420ae20681c5b045ce25bd453ed8455ab99b5", size = 275737, upload-time = "2026-01-23T15:32:16.889Z" }, { url = "https://files.pythonhosted.org/packages/41/2e/a3a417d620363fdbb08a48b1dd582956a46a61bf8fd27ee8164f9dfe87c2/greenlet-3.3.1-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b31c05dd84ef6871dd47120386aed35323c944d86c3d91a17c4b8d23df62f15b", 
size = 646422, upload-time = "2026-01-23T16:01:00.354Z" }, { url = "https://files.pythonhosted.org/packages/b4/09/c6c4a0db47defafd2d6bab8ddfe47ad19963b4e30f5bed84d75328059f8c/greenlet-3.3.1-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:02925a0bfffc41e542c70aa14c7eda3593e4d7e274bfcccca1827e6c0875902e", size = 658219, upload-time = "2026-01-23T16:05:30.956Z" }, + { url = "https://files.pythonhosted.org/packages/e2/89/b95f2ddcc5f3c2bc09c8ee8d77be312df7f9e7175703ab780f2014a0e781/greenlet-3.3.1-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3e0f3878ca3a3ff63ab4ea478585942b53df66ddde327b59ecb191b19dbbd62d", size = 671455, upload-time = "2026-01-23T16:15:57.232Z" }, { url = "https://files.pythonhosted.org/packages/80/38/9d42d60dffb04b45f03dbab9430898352dba277758640751dc5cc316c521/greenlet-3.3.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34a729e2e4e4ffe9ae2408d5ecaf12f944853f40ad724929b7585bca808a9d6f", size = 660237, upload-time = "2026-01-23T15:32:53.967Z" }, { url = "https://files.pythonhosted.org/packages/96/61/373c30b7197f9e756e4c81ae90a8d55dc3598c17673f91f4d31c3c689c3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:aec9ab04e82918e623415947921dea15851b152b822661cce3f8e4393c3df683", size = 1615261, upload-time = "2026-01-23T16:04:25.066Z" }, { url = "https://files.pythonhosted.org/packages/fd/d3/ca534310343f5945316f9451e953dcd89b36fe7a19de652a1dc5a0eeef3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:71c767cf281a80d02b6c1bdc41c9468e1f5a494fb11bc8688c360524e273d7b1", size = 1683719, upload-time = "2026-01-23T15:33:50.61Z" }, @@ -1141,6 +1188,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/24/cbbec49bacdcc9ec652a81d3efef7b59f326697e7edf6ed775a5e08e54c2/greenlet-3.3.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:3e63252943c921b90abb035ebe9de832c436401d9c45f262d80e2d06cc659242", size = 282706, upload-time = "2026-01-23T15:33:05.525Z" }, { url = "https://files.pythonhosted.org/packages/86/2e/4f2b9323c144c4fe8842a4e0d92121465485c3c2c5b9e9b30a52e80f523f/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76e39058e68eb125de10c92524573924e827927df5d3891fbc97bd55764a8774", size = 651209, upload-time = "2026-01-23T16:01:01.517Z" }, { url = "https://files.pythonhosted.org/packages/d9/87/50ca60e515f5bb55a2fbc5f0c9b5b156de7d2fc51a0a69abc9d23914a237/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c9f9d5e7a9310b7a2f416dd13d2e3fd8b42d803968ea580b7c0f322ccb389b97", size = 654300, upload-time = "2026-01-23T16:05:32.199Z" }, + { url = "https://files.pythonhosted.org/packages/7c/25/c51a63f3f463171e09cb586eb64db0861eb06667ab01a7968371a24c4f3b/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4b9721549a95db96689458a1e0ae32412ca18776ed004463df3a9299c1b257ab", size = 662574, upload-time = "2026-01-23T16:15:58.364Z" }, { url = "https://files.pythonhosted.org/packages/1d/94/74310866dfa2b73dd08659a3d18762f83985ad3281901ba0ee9a815194fb/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:92497c78adf3ac703b57f1e3813c2d874f27f71a178f9ea5887855da413cd6d2", size = 653842, upload-time = "2026-01-23T15:32:55.671Z" }, { url = "https://files.pythonhosted.org/packages/97/43/8bf0ffa3d498eeee4c58c212a3905dd6146c01c8dc0b0a046481ca29b18c/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash 
= "sha256:ed6b402bc74d6557a705e197d47f9063733091ed6357b3de33619d8a8d93ac53", size = 1614917, upload-time = "2026-01-23T16:04:26.276Z" }, { url = "https://files.pythonhosted.org/packages/89/90/a3be7a5f378fc6e84abe4dcfb2ba32b07786861172e502388b4c90000d1b/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:59913f1e5ada20fde795ba906916aea25d442abcc0593fba7e26c92b7ad76249", size = 1676092, upload-time = "2026-01-23T15:33:52.176Z" }, @@ -2404,6 +2452,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/60/22c9716033ced1ee1d800457126c4c79652a4ed635b0554c1d93742cc0a1/py_zerox-0.0.7-py3-none-any.whl", hash = "sha256:7b7d92cb6fafec91a94b63ba3c039b643fb3ee83545b15fa330ec07dd52f2058", size = 23347, upload-time = "2024-10-21T16:03:33.406Z" }, ] +[[package]] +name = "pyasn1" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = "2026-01-16T18:04:18.534Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pycparser" version = "3.0" @@ -3013,6 +3082,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/02/fa464cdfbe6b26e0600b62c528b72d8608f5cc49f96b8d6e38c95d60c676/rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3", size = 226532, upload-time = "2025-11-30T20:24:14.634Z" }, ] +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + [[package]] name = "ruff" version = "0.14.14" @@ -3639,47 +3720,33 @@ wheels = [ [[package]] name = "websockets" -version = "16.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5", size = 179346, upload-time = "2026-01-10T09:23:47.181Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00", size = 177365, upload-time = "2026-01-10T09:22:46.787Z" }, - { url = "https://files.pythonhosted.org/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79", size = 175038, upload-time = "2026-01-10T09:22:47.999Z" }, - { url = "https://files.pythonhosted.org/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39", size = 175328, upload-time = "2026-01-10T09:22:49.809Z" }, - { url = "https://files.pythonhosted.org/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c", size = 184915, upload-time = "2026-01-10T09:22:51.071Z" }, - { url = "https://files.pythonhosted.org/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f", size = 186152, upload-time = "2026-01-10T09:22:52.224Z" }, - { url = "https://files.pythonhosted.org/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1", size = 185583, upload-time = "2026-01-10T09:22:53.443Z" }, - { url = "https://files.pythonhosted.org/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2", size = 184880, upload-time = "2026-01-10T09:22:55.033Z" }, - { url = "https://files.pythonhosted.org/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89", size = 178261, upload-time = "2026-01-10T09:22:56.251Z" }, - { url = "https://files.pythonhosted.org/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea", size = 178693, upload-time = "2026-01-10T09:22:57.478Z" }, - { url = "https://files.pythonhosted.org/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9", size = 177364, upload-time = "2026-01-10T09:22:59.333Z" }, - { url = 
"https://files.pythonhosted.org/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230", size = 175039, upload-time = "2026-01-10T09:23:01.171Z" }, - { url = "https://files.pythonhosted.org/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c", size = 175323, upload-time = "2026-01-10T09:23:02.341Z" }, - { url = "https://files.pythonhosted.org/packages/bd/28/0a25ee5342eb5d5f297d992a77e56892ecb65e7854c7898fb7d35e9b33bd/websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5", size = 184975, upload-time = "2026-01-10T09:23:03.756Z" }, - { url = "https://files.pythonhosted.org/packages/f9/66/27ea52741752f5107c2e41fda05e8395a682a1e11c4e592a809a90c6a506/websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82", size = 186203, upload-time = "2026-01-10T09:23:05.01Z" }, - { url = "https://files.pythonhosted.org/packages/37/e5/8e32857371406a757816a2b471939d51c463509be73fa538216ea52b792a/websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8", size = 185653, upload-time = "2026-01-10T09:23:06.301Z" }, - { url = "https://files.pythonhosted.org/packages/9b/67/f926bac29882894669368dc73f4da900fcdf47955d0a0185d60103df5737/websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f", size = 184920, upload-time = "2026-01-10T09:23:07.492Z" }, - { url = "https://files.pythonhosted.org/packages/3c/a1/3d6ccdcd125b0a42a311bcd15a7f705d688f73b2a22d8cf1c0875d35d34a/websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a", size = 178255, upload-time = "2026-01-10T09:23:09.245Z" }, - { url = "https://files.pythonhosted.org/packages/6b/ae/90366304d7c2ce80f9b826096a9e9048b4bb760e44d3b873bb272cba696b/websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156", size = 178689, upload-time = "2026-01-10T09:23:10.483Z" }, - { url = "https://files.pythonhosted.org/packages/f3/1d/e88022630271f5bd349ed82417136281931e558d628dd52c4d8621b4a0b2/websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0", size = 177406, upload-time = "2026-01-10T09:23:12.178Z" }, - { url = "https://files.pythonhosted.org/packages/f2/78/e63be1bf0724eeb4616efb1ae1c9044f7c3953b7957799abb5915bffd38e/websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904", size = 175085, upload-time = "2026-01-10T09:23:13.511Z" }, - { url = "https://files.pythonhosted.org/packages/bb/f4/d3c9220d818ee955ae390cf319a7c7a467beceb24f05ee7aaaa2414345ba/websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4", size = 175328, upload-time = "2026-01-10T09:23:14.727Z" }, - { url = 
"https://files.pythonhosted.org/packages/63/bc/d3e208028de777087e6fb2b122051a6ff7bbcca0d6df9d9c2bf1dd869ae9/websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e", size = 185044, upload-time = "2026-01-10T09:23:15.939Z" }, - { url = "https://files.pythonhosted.org/packages/ad/6e/9a0927ac24bd33a0a9af834d89e0abc7cfd8e13bed17a86407a66773cc0e/websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4", size = 186279, upload-time = "2026-01-10T09:23:17.148Z" }, - { url = "https://files.pythonhosted.org/packages/b9/ca/bf1c68440d7a868180e11be653c85959502efd3a709323230314fda6e0b3/websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1", size = 185711, upload-time = "2026-01-10T09:23:18.372Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f8/fdc34643a989561f217bb477cbc47a3a07212cbda91c0e4389c43c296ebf/websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3", size = 184982, upload-time = "2026-01-10T09:23:19.652Z" }, - { url = "https://files.pythonhosted.org/packages/dd/d1/574fa27e233764dbac9c52730d63fcf2823b16f0856b3329fc6268d6ae4f/websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8", size = 177915, upload-time = "2026-01-10T09:23:21.458Z" }, - { url = "https://files.pythonhosted.org/packages/8a/f1/ae6b937bf3126b5134ce1f482365fde31a357c784ac51852978768b5eff4/websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d", size = 178381, upload-time = "2026-01-10T09:23:22.715Z" }, - { url = "https://files.pythonhosted.org/packages/06/9b/f791d1db48403e1f0a27577a6beb37afae94254a8c6f08be4a23e4930bc0/websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244", size = 177737, upload-time = "2026-01-10T09:23:24.523Z" }, - { url = "https://files.pythonhosted.org/packages/bd/40/53ad02341fa33b3ce489023f635367a4ac98b73570102ad2cdd770dacc9a/websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e", size = 175268, upload-time = "2026-01-10T09:23:25.781Z" }, - { url = "https://files.pythonhosted.org/packages/74/9b/6158d4e459b984f949dcbbb0c5d270154c7618e11c01029b9bbd1bb4c4f9/websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641", size = 175486, upload-time = "2026-01-10T09:23:27.033Z" }, - { url = "https://files.pythonhosted.org/packages/e5/2d/7583b30208b639c8090206f95073646c2c9ffd66f44df967981a64f849ad/websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8", size = 185331, upload-time = "2026-01-10T09:23:28.259Z" }, - { url = "https://files.pythonhosted.org/packages/45/b0/cce3784eb519b7b5ad680d14b9673a31ab8dcb7aad8b64d81709d2430aa8/websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e", size = 186501, 
upload-time = "2026-01-10T09:23:29.449Z" }, - { url = "https://files.pythonhosted.org/packages/19/60/b8ebe4c7e89fb5f6cdf080623c9d92789a53636950f7abacfc33fe2b3135/websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944", size = 186062, upload-time = "2026-01-10T09:23:31.368Z" }, - { url = "https://files.pythonhosted.org/packages/88/a8/a080593f89b0138b6cba1b28f8df5673b5506f72879322288b031337c0b8/websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206", size = 185356, upload-time = "2026-01-10T09:23:32.627Z" }, - { url = "https://files.pythonhosted.org/packages/c2/b6/b9afed2afadddaf5ebb2afa801abf4b0868f42f8539bfe4b071b5266c9fe/websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6", size = 178085, upload-time = "2026-01-10T09:23:33.816Z" }, - { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, - { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, +version = "15.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/21/e6/26d09fab466b7ca9c7737474c52be4f76a40301b08362eb2dbc19dcc16c1/websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee", size = 177016, upload-time = "2025-03-05T20:03:41.606Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/6b/4545a0d843594f5d0771e86463606a3988b5a09ca5123136f8a76580dd63/websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3", size = 175437, upload-time = "2025-03-05T20:02:16.706Z" }, + { url = "https://files.pythonhosted.org/packages/f4/71/809a0f5f6a06522af902e0f2ea2757f71ead94610010cf570ab5c98e99ed/websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665", size = 173096, upload-time = "2025-03-05T20:02:18.832Z" }, + { url = "https://files.pythonhosted.org/packages/3d/69/1a681dd6f02180916f116894181eab8b2e25b31e484c5d0eae637ec01f7c/websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2", size = 173332, upload-time = "2025-03-05T20:02:20.187Z" }, + { url = "https://files.pythonhosted.org/packages/a6/02/0073b3952f5bce97eafbb35757f8d0d54812b6174ed8dd952aa08429bcc3/websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215", size = 183152, upload-time = "2025-03-05T20:02:22.286Z" }, + { url = "https://files.pythonhosted.org/packages/74/45/c205c8480eafd114b428284840da0b1be9ffd0e4f87338dc95dc6ff961a1/websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5", size = 
182096, upload-time = "2025-03-05T20:02:24.368Z" }, + { url = "https://files.pythonhosted.org/packages/14/8f/aa61f528fba38578ec553c145857a181384c72b98156f858ca5c8e82d9d3/websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65", size = 182523, upload-time = "2025-03-05T20:02:25.669Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6d/0267396610add5bc0d0d3e77f546d4cd287200804fe02323797de77dbce9/websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe", size = 182790, upload-time = "2025-03-05T20:02:26.99Z" }, + { url = "https://files.pythonhosted.org/packages/02/05/c68c5adbf679cf610ae2f74a9b871ae84564462955d991178f95a1ddb7dd/websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4", size = 182165, upload-time = "2025-03-05T20:02:30.291Z" }, + { url = "https://files.pythonhosted.org/packages/29/93/bb672df7b2f5faac89761cb5fa34f5cec45a4026c383a4b5761c6cea5c16/websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597", size = 182160, upload-time = "2025-03-05T20:02:31.634Z" }, + { url = "https://files.pythonhosted.org/packages/ff/83/de1f7709376dc3ca9b7eeb4b9a07b4526b14876b6d372a4dc62312bebee0/websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9", size = 176395, upload-time = "2025-03-05T20:02:33.017Z" }, + { url = "https://files.pythonhosted.org/packages/7d/71/abf2ebc3bbfa40f391ce1428c7168fb20582d0ff57019b69ea20fa698043/websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7", size = 176841, upload-time = "2025-03-05T20:02:34.498Z" }, + { url = "https://files.pythonhosted.org/packages/cb/9f/51f0cf64471a9d2b4d0fc6c534f323b664e7095640c34562f5182e5a7195/websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931", size = 175440, upload-time = "2025-03-05T20:02:36.695Z" }, + { url = "https://files.pythonhosted.org/packages/8a/05/aa116ec9943c718905997412c5989f7ed671bc0188ee2ba89520e8765d7b/websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675", size = 173098, upload-time = "2025-03-05T20:02:37.985Z" }, + { url = "https://files.pythonhosted.org/packages/ff/0b/33cef55ff24f2d92924923c99926dcce78e7bd922d649467f0eda8368923/websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151", size = 173329, upload-time = "2025-03-05T20:02:39.298Z" }, + { url = "https://files.pythonhosted.org/packages/31/1d/063b25dcc01faa8fada1469bdf769de3768b7044eac9d41f734fd7b6ad6d/websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22", size = 183111, upload-time = "2025-03-05T20:02:40.595Z" }, + { url = "https://files.pythonhosted.org/packages/93/53/9a87ee494a51bf63e4ec9241c1ccc4f7c2f45fff85d5bde2ff74fcb68b9e/websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f", size = 182054, upload-time = "2025-03-05T20:02:41.926Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b2/83a6ddf56cdcbad4e3d841fcc55d6ba7d19aeb89c50f24dd7e859ec0805f/websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8", size = 182496, upload-time = "2025-03-05T20:02:43.304Z" }, + { url = "https://files.pythonhosted.org/packages/98/41/e7038944ed0abf34c45aa4635ba28136f06052e08fc2168520bb8b25149f/websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375", size = 182829, upload-time = "2025-03-05T20:02:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/e0/17/de15b6158680c7623c6ef0db361da965ab25d813ae54fcfeae2e5b9ef910/websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d", size = 182217, upload-time = "2025-03-05T20:02:50.14Z" }, + { url = "https://files.pythonhosted.org/packages/33/2b/1f168cb6041853eef0362fb9554c3824367c5560cbdaad89ac40f8c2edfc/websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4", size = 182195, upload-time = "2025-03-05T20:02:51.561Z" }, + { url = "https://files.pythonhosted.org/packages/86/eb/20b6cdf273913d0ad05a6a14aed4b9a85591c18a987a3d47f20fa13dcc47/websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa", size = 176393, upload-time = "2025-03-05T20:02:53.814Z" }, + { url = "https://files.pythonhosted.org/packages/1b/6c/c65773d6cab416a64d191d6ee8a8b1c68a09970ea6909d16965d26bfed1e/websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561", size = 176837, upload-time = "2025-03-05T20:02:55.237Z" }, + { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] [[package]] From 4ac2ca61395255e6be76bb32f27321b298c2f300 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 22:30:08 +0530 Subject: [PATCH 09/36] cleanup for audio duration --- .../versions/043_add_stt_evaluation_tables.py | 6 --- .../app/api/routes/stt_evaluations/dataset.py | 1 - backend/app/crud/stt_evaluations/dataset.py | 37 ++++++++++++++++++- backend/app/crud/stt_evaluations/result.py | 1 - backend/app/models/stt_evaluation.py | 7 ---- .../app/services/stt_evaluations/dataset.py | 1 + 6 files changed, 37 insertions(+), 16 deletions(-) diff --git a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py index b76f2da9e..9263d1501 100644 --- a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py +++ b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py @@ -107,12 +107,6 @@ def upgrade(): nullable=True, comment="Reference transcription for comparison (optional)", ), - sa.Column( - "duration_seconds", - sa.Float(), - nullable=True, - comment="Audio duration in seconds", - ), sa.Column( "sample_metadata", postgresql.JSONB(astext_type=sa.Text()), diff --git 
a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index cd456866d..f3d048b93 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -146,7 +146,6 @@ def get_dataset( object_store_url=s.object_store_url, language=s.language, ground_truth=s.ground_truth, - duration_seconds=s.duration_seconds, sample_metadata=s.sample_metadata, dataset_id=s.dataset_id, organization_id=s.organization_id, diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index 8de773879..3014dcc0b 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -2,6 +2,8 @@ import logging from typing import Any +from urllib.parse import urlparse +import os from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select, func @@ -92,6 +94,32 @@ def create_stt_dataset( raise +def _extract_metadata_from_url(url: str) -> dict[str, Any]: + """Extract filename and extension from S3 URL. + + Args: + url: S3 URL of the audio file + + Returns: + dict with original_filename and file_extension + """ + try: + parsed = urlparse(url) + path = parsed.path + filename = os.path.basename(path) + _, extension = os.path.splitext(filename) + # Remove leading dot from extension + extension = extension.lstrip(".").lower() if extension else None + + return { + "original_filename": filename if filename else None, + "file_extension": extension, + } + except Exception as e: + logger.warning(f"[_extract_metadata_from_url] Failed to extract metadata: {e}") + return {} + + def create_stt_samples( *, session: Session, @@ -99,6 +127,7 @@ def create_stt_samples( org_id: int, project_id: int, samples: list[STTSampleCreate], + language: str | None = None, ) -> list[STTSample]: """Create STT samples for a dataset. 
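For reference, a minimal self-contained sketch of what the URL-based metadata extraction added in this hunk does; the bucket, key, and query string below are illustrative values, not taken from the patch:

    import os
    from urllib.parse import urlparse

    def extract_metadata_from_url(url: str) -> dict:
        """Mirrors the helper above: derive filename/extension from the URL path."""
        path = urlparse(url).path          # signed-URL query parameters are ignored
        filename = os.path.basename(path)
        _, ext = os.path.splitext(filename)
        return {
            "original_filename": filename or None,
            "file_extension": ext.lstrip(".").lower() or None,
        }

    # extract_metadata_from_url("https://bucket.s3.amazonaws.com/stt/clip-001.mp3?X-Amz-Signature=abc")
    # -> {"original_filename": "clip-001.mp3", "file_extension": "mp3"}
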
@@ -108,21 +137,27 @@ def create_stt_samples( org_id: Organization ID project_id: Project ID samples: List of sample data + language: Language code from parent dataset Returns: list[STTSample]: Created samples """ logger.info( f"[create_stt_samples] Creating STT samples | " - f"dataset_id: {dataset_id}, sample_count: {len(samples)}" + f"dataset_id: {dataset_id}, sample_count: {len(samples)}, language: {language}" ) created_samples = [] for sample_data in samples: + # Extract metadata from URL + sample_metadata = _extract_metadata_from_url(sample_data.object_store_url) + sample = STTSample( object_store_url=sample_data.object_store_url, ground_truth=sample_data.ground_truth, + language=language, + sample_metadata=sample_metadata, dataset_id=dataset_id, organization_id=org_id, project_id=project_id, diff --git a/backend/app/crud/stt_evaluations/result.py b/backend/app/crud/stt_evaluations/result.py index 428f78529..4c3fd8a9a 100644 --- a/backend/app/crud/stt_evaluations/result.py +++ b/backend/app/crud/stt_evaluations/result.py @@ -212,7 +212,6 @@ def get_results_by_run_id( object_store_url=sample.object_store_url, language=sample.language, ground_truth=sample.ground_truth, - duration_seconds=sample.duration_seconds, sample_metadata=sample.sample_metadata, dataset_id=sample.dataset_id, organization_id=sample.organization_id, diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 8286ccbbb..a5749f5d3 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -74,12 +74,6 @@ class STTSample(SQLModel, table=True): ) # Audio metadata - duration_seconds: float | None = SQLField( - default=None, - description="Audio duration in seconds", - sa_column_kwargs={"comment": "Audio duration in seconds"}, - ) - sample_metadata: dict[str, Any] | None = SQLField( default_factory=dict, sa_column=Column( @@ -282,7 +276,6 @@ class STTSamplePublic(BaseModel): object_store_url: str language: str | None ground_truth: str | None - duration_seconds: float | None sample_metadata: dict[str, Any] | None dataset_id: int organization_id: int diff --git a/backend/app/services/stt_evaluations/dataset.py b/backend/app/services/stt_evaluations/dataset.py index 205ce9990..5fda1ec63 100644 --- a/backend/app/services/stt_evaluations/dataset.py +++ b/backend/app/services/stt_evaluations/dataset.py @@ -95,6 +95,7 @@ def upload_stt_dataset( org_id=organization_id, project_id=project_id, samples=samples, + language=language, ) logger.info( From d8b531c498d2bb6cc7a031fc1ce2581a70d1bbc8 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Sat, 31 Jan 2026 22:51:14 +0530 Subject: [PATCH 10/36] first stab at fixing celery task to cron --- .../api/routes/stt_evaluations/evaluation.py | 73 +- backend/app/celery/tasks/stt_evaluation.py | 704 ------------------ backend/app/crud/evaluations/cron.py | 27 +- backend/app/crud/stt_evaluations/__init__.py | 6 + backend/app/crud/stt_evaluations/batch.py | 255 +++++++ backend/app/crud/stt_evaluations/cron.py | 345 +++++++++ 6 files changed, 687 insertions(+), 723 deletions(-) delete mode 100644 backend/app/celery/tasks/stt_evaluation.py create mode 100644 backend/app/crud/stt_evaluations/batch.py create mode 100644 backend/app/crud/stt_evaluations/cron.py diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py index 1a976cab3..2dcff4ca2 100644 --- a/backend/app/api/routes/stt_evaluations/evaluation.py +++ b/backend/app/api/routes/stt_evaluations/evaluation.py 
@@ -2,18 +2,20 @@ import logging -from asgi_correlation_id import correlation_id from fastapi import APIRouter, Body, Depends, HTTPException, Query from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission -from app.celery.tasks.stt_evaluation import process_stt_evaluation from app.crud.stt_evaluations import ( create_stt_run, + create_stt_results, + get_samples_by_dataset_id, get_stt_dataset_by_id, get_stt_run_by_id, list_stt_runs, get_sample_count_for_dataset, + start_stt_evaluation_batch, + update_stt_run, ) from app.crud.stt_evaluations.result import get_results_by_run_id from app.models.stt_evaluation import ( @@ -88,21 +90,64 @@ def start_stt_evaluation( total_items=sample_count * len(run_create.providers), ) - # Enqueue Celery task - trace_id = correlation_id.get() or "" + # Get samples for the dataset + samples, _ = get_samples_by_dataset_id( + session=_session, + dataset_id=run_create.dataset_id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + limit=10000, # Get all samples + ) - process_stt_evaluation.apply_async( - kwargs={ - "evaluation_run_id": run.id, - "org_id": auth_context.organization_.id, - "project_id": auth_context.project_.id, - "trace_id": trace_id, - }, + # Create result records for each sample and provider + results = create_stt_results( + session=_session, + samples=samples, + evaluation_run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + providers=run_create.providers, ) - logger.info( - f"[start_stt_evaluation] STT evaluation queued | " - f"run_id: {run.id}, task queued" + # Extract result data for batch processing + result_refs = [ + {"id": r.id, "stt_sample_id": r.stt_sample_id, "provider": r.provider} + for r in results + ] + + # Submit batch synchronously + try: + batch_result = start_stt_evaluation_batch( + session=_session, + run=run, + samples=samples, + result_refs=result_refs, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + logger.info( + f"[start_stt_evaluation] STT evaluation batch submitted | " + f"run_id: {run.id}, batch_jobs: {batch_result.get('batch_jobs', {}).keys()}" + ) + except Exception as e: + logger.error( + f"[start_stt_evaluation] Batch submission failed | " + f"run_id: {run.id}, error: {str(e)}" + ) + update_stt_run( + session=_session, + run_id=run.id, + status="failed", + error_message=str(e), + ) + raise HTTPException(status_code=500, detail=f"Batch submission failed: {e}") + + # Refresh run to get updated status + run = get_stt_run_by_id( + session=_session, + run_id=run.id, + org_id=auth_context.organization_.id, + project_id=auth_context.project_.id, ) return APIResponse.success_response( diff --git a/backend/app/celery/tasks/stt_evaluation.py b/backend/app/celery/tasks/stt_evaluation.py deleted file mode 100644 index 5a3bab815..000000000 --- a/backend/app/celery/tasks/stt_evaluation.py +++ /dev/null @@ -1,704 +0,0 @@ -"""Celery tasks for STT evaluation processing.""" - -import logging -from typing import Any - -from asgi_correlation_id import correlation_id -from celery import current_task -from sqlmodel import Session, select - -from app.celery.celery_app import celery_app -from app.core.batch import BatchJobState, GeminiBatchProvider -from app.core.db import engine -from app.core.cloud.storage import get_cloud_storage -from app.crud.stt_evaluations import ( - get_stt_run_by_id, - get_samples_by_dataset_id, - get_stt_dataset_by_id, - update_stt_run, - 
create_stt_results, - update_stt_result, - count_results_by_status, -) -from app.models.stt_evaluation import STTResult, STTResultStatus -from app.services.stt_evaluations.gemini import ( - GeminiClient, - GeminiFilesManager, -) - -logger = logging.getLogger(__name__) - -# Default transcription prompt -DEFAULT_TRANSCRIPTION_PROMPT = ( - "Generate a verbatim transcript of the speech in this audio file. " - "Return only the transcription text without any formatting, timestamps, or metadata." -) - -# Provider name to Gemini model mapping -PROVIDER_MODEL_MAPPING: dict[str, str] = { - "gemini-2.5-pro": "models/gemini-2.5-pro", - "gemini-2.5-flash": "models/gemini-2.5-flash", - "gemini-2.0-flash": "models/gemini-2.0-flash", -} - -# Maximum number of polls (24 hours with 30s intervals) -MAX_POLL_COUNT = 2880 - - -def _get_model_for_provider(provider: str) -> str: - """Map provider name to Gemini model. - - Args: - provider: Provider name - - Returns: - str: Gemini model name - """ - return PROVIDER_MODEL_MAPPING.get(provider, f"models/{provider}") - - -def _build_batch_requests( - sample_file_mapping: list[tuple[int, int | None, str]], - prompt: str = DEFAULT_TRANSCRIPTION_PROMPT, -) -> list[dict[str, Any]]: - """Build JSONL batch request data from sample-file mappings. - - Each request follows the Gemini GenerateContentRequest format - with a text prompt and file_data reference. - - Args: - sample_file_mapping: List of (sample_id, result_id, google_file_uri) tuples - prompt: Transcription prompt - - Returns: - list[dict]: JSONL-compatible request dicts for GeminiBatchProvider - """ - return [ - { - "contents": [ - { - "parts": [ - {"text": prompt}, - {"file_data": {"file_uri": file_uri}}, - ], - "role": "user", - } - ], - } - for _, _, file_uri in sample_file_mapping - ] - - -@celery_app.task(bind=True, queue="low_priority") -def process_stt_evaluation( - self, - evaluation_run_id: int, - org_id: int, - project_id: int, - trace_id: str, - **kwargs, -) -> dict[str, Any]: - """Process an STT evaluation run. - - This task: - 1. Fetches the evaluation run and dataset samples - 2. Uploads audio files to Google Files API - 3. Creates batch requests via GeminiBatchProvider - 4. Submits batch jobs per provider - 5. 
Schedules polling task - - Args: - evaluation_run_id: ID of the evaluation run - org_id: Organization ID - project_id: Project ID - trace_id: Correlation/trace ID - - Returns: - dict: Result with batch job information - """ - task_id = current_task.request.id - correlation_id.set(trace_id) - - logger.info( - f"[process_stt_evaluation] Starting STT evaluation | " - f"run_id: {evaluation_run_id}, task_id: {task_id}" - ) - - try: - with Session(engine) as session: - # Get the evaluation run - run = get_stt_run_by_id( - session=session, - run_id=evaluation_run_id, - org_id=org_id, - project_id=project_id, - ) - - if not run: - logger.error( - f"[process_stt_evaluation] Run not found | run_id: {evaluation_run_id}" - ) - return {"success": False, "error": "Evaluation run not found"} - - # Get the dataset - dataset = get_stt_dataset_by_id( - session=session, - dataset_id=run.dataset_id, - org_id=org_id, - project_id=project_id, - ) - - if not dataset: - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="failed", - error_message="Dataset not found", - ) - return {"success": False, "error": "Dataset not found"} - - # Get all samples - samples, total = get_samples_by_dataset_id( - session=session, - dataset_id=run.dataset_id, - org_id=org_id, - project_id=project_id, - limit=10000, # Get all samples - ) - - if not samples: - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="failed", - error_message="No samples in dataset", - ) - return {"success": False, "error": "No samples in dataset"} - - # Update run with total items - providers = run.providers or ["gemini-2.5-pro"] - total_items = len(samples) * len(providers) - - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="processing", - ) - - # Create result records for each sample and provider - results = create_stt_results( - session=session, - samples=samples, - evaluation_run_id=evaluation_run_id, - org_id=org_id, - project_id=project_id, - providers=providers, - ) - - # Extract result data before session closes to avoid DetachedInstanceError - result_refs = [ - {"id": r.id, "stt_sample_id": r.stt_sample_id, "provider": r.provider} - for r in results - ] - - # Update total items - run.total_items = total_items - session.add(run) - session.commit() - - # Initialize Gemini client - gemini_client = GeminiClient.from_credentials( - session=session, - org_id=org_id, - project_id=project_id, - ) - - # Get cloud storage for S3 access - storage = get_cloud_storage(session=session, project_id=project_id) - - # Upload audio files to Google Files API - files_manager = GeminiFilesManager(gemini_client.client) - - sample_file_mapping = [] # [(sample_id, result_id, google_file_uri)] - - for sample in samples: - try: - # Get signed URL for S3 audio file - signed_url = storage.get_signed_url( - sample.object_store_url, expires_in=3600 - ) - - # Extract filename from URL - filename = sample.object_store_url.split("/")[-1] - - # Upload to Google Files API - google_file_uri = files_manager.upload_from_url( - signed_url=signed_url, - filename=filename, - ) - - # Find the result record for this sample - result_for_sample = next( - (r for r in result_refs if r["stt_sample_id"] == sample.id), - None, - ) - - sample_file_mapping.append( - ( - sample.id, - result_for_sample["id"] if result_for_sample else None, - google_file_uri, - ) - ) - - logger.info( - f"[process_stt_evaluation] Uploaded audio to Google | " - f"sample_id: {sample.id}, file_uri: {google_file_uri}" - ) - - except Exception as e: - 
logger.error( - f"[process_stt_evaluation] Failed to upload audio | " - f"sample_id: {sample.id}, error: {str(e)}" - ) - # Mark result as failed - for ref in result_refs: - if ref["stt_sample_id"] == sample.id: - update_stt_result( - session=session, - result_id=ref["id"], - status=STTResultStatus.FAILED.value, - error_message=f"Failed to upload audio: {str(e)}", - ) - - # Build batch requests from uploaded files - jsonl_data = _build_batch_requests(sample_file_mapping) - - # Process each provider using GeminiBatchProvider - batch_jobs: dict[str, str] = {} - - for provider in providers: - try: - model = _get_model_for_provider(provider) - batch_provider = GeminiBatchProvider( - client=gemini_client.client, model=model - ) - - batch_result = batch_provider.create_batch( - jsonl_data=jsonl_data, - config={ - "display_name": f"stt-eval-{evaluation_run_id}-{provider}", - "model": model, - }, - ) - - batch_jobs[provider] = batch_result["provider_batch_id"] - - logger.info( - f"[process_stt_evaluation] Batch job submitted | " - f"run_id: {evaluation_run_id}, provider: {provider}, " - f"batch_id: {batch_result['provider_batch_id']}" - ) - - except Exception as e: - logger.error( - f"[process_stt_evaluation] Failed to submit batch | " - f"provider: {provider}, error: {str(e)}" - ) - # Update results for this provider as failed - with Session(engine) as session: - for ref in result_refs: - if ref["provider"] == provider: - update_stt_result( - session=session, - result_id=ref["id"], - status=STTResultStatus.FAILED.value, - error_message=f"Batch submission failed: {str(e)}", - ) - - if not batch_jobs: - with Session(engine) as session: - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="failed", - error_message="All batch submissions failed", - ) - return {"success": False, "error": "All batch submissions failed"} - - # Store batch job info in run score for polling - with Session(engine) as session: - update_stt_run( - session=session, - run_id=evaluation_run_id, - score={ - "batch_jobs": batch_jobs, - "sample_file_mapping": [ - {"sample_id": s, "result_id": r, "file_uri": f} - for s, r, f in sample_file_mapping - ], - }, - ) - - # Schedule polling task - poll_stt_batch_status.apply_async( - kwargs={ - "evaluation_run_id": evaluation_run_id, - "org_id": org_id, - "project_id": project_id, - "trace_id": trace_id, - "batch_jobs": batch_jobs, - }, - countdown=30, # Wait 30 seconds before first poll - ) - - return { - "success": True, - "run_id": evaluation_run_id, - "batch_jobs": batch_jobs, - "sample_count": len(samples), - } - - except Exception as e: - logger.error( - f"[process_stt_evaluation] Failed to process evaluation | " - f"run_id: {evaluation_run_id}, error: {str(e)}", - exc_info=True, - ) - - with Session(engine) as session: - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="failed", - error_message=str(e), - ) - - return {"success": False, "error": str(e)} - - -@celery_app.task(bind=True, queue="low_priority") -def poll_stt_batch_status( - self, - evaluation_run_id: int, - org_id: int, - project_id: int, - trace_id: str, - batch_jobs: dict[str, str], - poll_count: int = 0, - **kwargs, -) -> dict[str, Any]: - """Poll Gemini batch job status using GeminiBatchProvider. 
- - Args: - evaluation_run_id: ID of the evaluation run - org_id: Organization ID - project_id: Project ID - trace_id: Correlation/trace ID - batch_jobs: Dict of provider -> batch_id - poll_count: Number of times we've polled - - Returns: - dict: Status information - """ - correlation_id.set(trace_id) - - if poll_count >= MAX_POLL_COUNT: - logger.error( - f"[poll_stt_batch_status] Polling timed out | " - f"run_id: {evaluation_run_id}" - ) - with Session(engine) as session: - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="failed", - error_message="Batch processing timed out after 24 hours", - ) - return {"success": False, "error": "Timeout"} - - logger.info( - f"[poll_stt_batch_status] Polling batch status | " - f"run_id: {evaluation_run_id}, poll_count: {poll_count}" - ) - - try: - with Session(engine) as session: - # Initialize Gemini client and batch provider - gemini_client = GeminiClient.from_credentials( - session=session, - org_id=org_id, - project_id=project_id, - ) - batch_provider = GeminiBatchProvider(client=gemini_client.client) - - all_complete = True - any_success = False - - for provider, batch_id in batch_jobs.items(): - status = batch_provider.get_batch_status(batch_id) - provider_status = status["provider_status"] - - logger.info( - f"[poll_stt_batch_status] Batch status | " - f"provider: {provider}, batch_id: {batch_id}, " - f"state: {provider_status}" - ) - - is_terminal = provider_status in { - BatchJobState.SUCCEEDED.value, - BatchJobState.FAILED.value, - BatchJobState.CANCELLED.value, - BatchJobState.EXPIRED.value, - } - - if not is_terminal: - all_complete = False - elif provider_status == BatchJobState.SUCCEEDED.value: - any_success = True - - if not all_complete: - # Re-schedule polling - poll_stt_batch_status.apply_async( - kwargs={ - "evaluation_run_id": evaluation_run_id, - "org_id": org_id, - "project_id": project_id, - "trace_id": trace_id, - "batch_jobs": batch_jobs, - "poll_count": poll_count + 1, - }, - countdown=30, - ) - return {"success": True, "status": "polling", "poll_count": poll_count} - - # All batches complete - process results - if any_success: - process_stt_batch_results.apply_async( - kwargs={ - "evaluation_run_id": evaluation_run_id, - "org_id": org_id, - "project_id": project_id, - "trace_id": trace_id, - "batch_jobs": batch_jobs, - }, - ) - return {"success": True, "status": "processing_results"} - else: - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="failed", - error_message="All batch jobs failed", - ) - return {"success": False, "error": "All batch jobs failed"} - - except Exception as e: - logger.error( - f"[poll_stt_batch_status] Polling failed | " - f"run_id: {evaluation_run_id}, error: {str(e)}", - exc_info=True, - ) - - # Re-schedule polling (might be temporary error) - if poll_count < MAX_POLL_COUNT: - poll_stt_batch_status.apply_async( - kwargs={ - "evaluation_run_id": evaluation_run_id, - "org_id": org_id, - "project_id": project_id, - "trace_id": trace_id, - "batch_jobs": batch_jobs, - "poll_count": poll_count + 1, - }, - countdown=60, # Wait longer on error - ) - - return {"success": False, "error": str(e)} - - -@celery_app.task(bind=True, queue="low_priority") -def process_stt_batch_results( - self, - evaluation_run_id: int, - org_id: int, - project_id: int, - trace_id: str, - batch_jobs: dict[str, str], - **kwargs, -) -> dict[str, Any]: - """Process results from completed Gemini batch jobs using GeminiBatchProvider. 
- - Args: - evaluation_run_id: ID of the evaluation run - org_id: Organization ID - project_id: Project ID - trace_id: Correlation/trace ID - batch_jobs: Dict of provider -> batch_id - - Returns: - dict: Processing result - """ - correlation_id.set(trace_id) - - logger.info( - f"[process_stt_batch_results] Processing batch results | " - f"run_id: {evaluation_run_id}" - ) - - try: - with Session(engine) as session: - # Get the run to access sample mapping - run = get_stt_run_by_id( - session=session, - run_id=evaluation_run_id, - org_id=org_id, - project_id=project_id, - ) - - if not run or not run.score: - logger.error( - f"[process_stt_batch_results] Run or score not found | " - f"run_id: {evaluation_run_id}" - ) - return {"success": False, "error": "Run data not found"} - - sample_file_mapping = run.score.get("sample_file_mapping", []) - sample_ids = [item["sample_id"] for item in sample_file_mapping] - - # Initialize Gemini client and providers - gemini_client = GeminiClient.from_credentials( - session=session, - org_id=org_id, - project_id=project_id, - ) - batch_provider = GeminiBatchProvider(client=gemini_client.client) - files_manager = GeminiFilesManager(gemini_client.client) - - processed_count = 0 - failed_count = 0 - - for provider, batch_id in batch_jobs.items(): - try: - # Use GeminiBatchProvider to download results - results = batch_provider.download_batch_results(batch_id) - - logger.info( - f"[process_stt_batch_results] Got batch results | " - f"provider: {provider}, result_count: {len(results)}" - ) - - # Match results to samples by index - for batch_result in results: - custom_id = batch_result["custom_id"] - # custom_id is the index as string - try: - index = int(custom_id) - except (ValueError, TypeError): - index = results.index(batch_result) - - if index >= len(sample_ids): - continue - - sample_id = sample_ids[index] - - # Find result record for this sample and provider - stmt = select(STTResult).where( - STTResult.evaluation_run_id == evaluation_run_id, - STTResult.stt_sample_id == sample_id, - STTResult.provider == provider, - ) - result_record = session.exec(stmt).one_or_none() - - if result_record: - if batch_result.get("response"): - text = batch_result["response"].get("text", "") - update_stt_result( - session=session, - result_id=result_record.id, - transcription=text, - status=STTResultStatus.COMPLETED.value, - ) - processed_count += 1 - else: - update_stt_result( - session=session, - result_id=result_record.id, - status=STTResultStatus.FAILED.value, - error_message=batch_result.get( - "error", "Unknown error" - ), - ) - failed_count += 1 - - except Exception as e: - logger.error( - f"[process_stt_batch_results] Failed to process provider results | " - f"provider: {provider}, error: {str(e)}" - ) - failed_count += len(sample_file_mapping) - - # Clean up Google Files - for item in sample_file_mapping: - file_uri = item.get("file_uri") - if file_uri: - try: - files_manager.delete_file(file_uri) - except Exception as e: - logger.warning( - f"[process_stt_batch_results] Failed to delete Google file | " - f"file_uri: {file_uri}, error: {str(e)}" - ) - - # Update run status - status_counts = count_results_by_status( - session=session, run_id=evaluation_run_id - ) - - completed = status_counts.get(STTResultStatus.COMPLETED.value, 0) - failed = status_counts.get(STTResultStatus.FAILED.value, 0) - pending = status_counts.get(STTResultStatus.PENDING.value, 0) - - final_status = "completed" if pending == 0 else "processing" - error_message = None - if failed > 0: - 
error_message = f"{failed} transcription(s) failed" - - update_stt_run( - session=session, - run_id=evaluation_run_id, - status=final_status, - processed_samples=completed + failed, - error_message=error_message, - ) - - logger.info( - f"[process_stt_batch_results] Batch results processed | " - f"run_id: {evaluation_run_id}, completed: {completed}, " - f"failed: {failed}, status: {final_status}" - ) - - return { - "success": True, - "run_id": evaluation_run_id, - "completed": completed, - "failed": failed, - "status": final_status, - } - - except Exception as e: - logger.error( - f"[process_stt_batch_results] Failed to process results | " - f"run_id: {evaluation_run_id}, error: {str(e)}", - exc_info=True, - ) - - with Session(engine) as session: - update_stt_run( - session=session, - run_id=evaluation_run_id, - status="failed", - error_message=f"Result processing failed: {str(e)}", - ) - - return {"success": False, "error": str(e)} diff --git a/backend/app/crud/evaluations/cron.py b/backend/app/crud/evaluations/cron.py index ca6bd2af2..72b22ec69 100644 --- a/backend/app/crud/evaluations/cron.py +++ b/backend/app/crud/evaluations/cron.py @@ -12,6 +12,7 @@ from sqlmodel import Session, select from app.crud.evaluations.processing import poll_all_pending_evaluations +from app.crud.stt_evaluations.cron import poll_all_pending_stt_evaluations from app.models import Organization logger = logging.getLogger(__name__) @@ -84,22 +85,38 @@ async def process_all_pending_evaluations(session: Session) -> dict[str, Any]: f"[process_all_pending_evaluations] Processing org_id={org.id} ({org.name})" ) - # Poll all pending evaluations for this org + # Poll all pending text evaluations for this org summary = await poll_all_pending_evaluations( session=session, org_id=org.id ) + # Poll all pending STT evaluations for this org + stt_summary = await poll_all_pending_stt_evaluations( + session=session, org_id=org.id + ) + + # Merge summaries + combined_summary = { + "text": summary, + "stt": stt_summary, + "processed": summary.get("processed", 0) + + stt_summary.get("processed", 0), + "failed": summary.get("failed", 0) + stt_summary.get("failed", 0), + "still_processing": summary.get("still_processing", 0) + + stt_summary.get("still_processing", 0), + } + results.append( { "org_id": org.id, "org_name": org.name, - "summary": summary, + "summary": combined_summary, } ) - total_processed += summary.get("processed", 0) - total_failed += summary.get("failed", 0) - total_still_processing += summary.get("still_processing", 0) + total_processed += combined_summary["processed"] + total_failed += combined_summary["failed"] + total_still_processing += combined_summary["still_processing"] except Exception as e: logger.error( diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py index 24ee863f0..196b8aec2 100644 --- a/backend/app/crud/stt_evaluations/__init__.py +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -1,5 +1,7 @@ """STT Evaluation CRUD operations.""" +from .batch import start_stt_evaluation_batch +from .cron import poll_all_pending_stt_evaluations from .dataset import ( create_stt_dataset, create_stt_samples, @@ -24,6 +26,10 @@ ) __all__ = [ + # Batch + "start_stt_evaluation_batch", + # Cron + "poll_all_pending_stt_evaluations", # Dataset "create_stt_dataset", "create_stt_samples", diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py new file mode 100644 index 000000000..078fa4e44 --- /dev/null +++ 
b/backend/app/crud/stt_evaluations/batch.py @@ -0,0 +1,255 @@ +"""Batch submission functions for STT evaluation processing.""" + +import logging +from typing import Any + +from sqlmodel import Session + +from app.core.batch import GeminiBatchProvider +from app.core.cloud.storage import get_cloud_storage +from app.crud.stt_evaluations.result import update_stt_result +from app.crud.stt_evaluations.run import update_stt_run +from app.models import EvaluationRun +from app.models.stt_evaluation import STTResultStatus, STTSample +from app.services.stt_evaluations.gemini import GeminiClient, GeminiFilesManager + +logger = logging.getLogger(__name__) + +# Default transcription prompt +DEFAULT_TRANSCRIPTION_PROMPT = ( + "Generate a verbatim transcript of the speech in this audio file. " + "Return only the transcription text without any formatting, timestamps, or metadata." +) + +# Provider name to Gemini model mapping +PROVIDER_MODEL_MAPPING: dict[str, str] = { + "gemini-2.5-pro": "models/gemini-2.5-pro", + "gemini-2.5-flash": "models/gemini-2.5-flash", + "gemini-2.0-flash": "models/gemini-2.0-flash", +} + + +def _get_model_for_provider(provider: str) -> str: + """Map provider name to Gemini model. + + Args: + provider: Provider name + + Returns: + str: Gemini model name + """ + return PROVIDER_MODEL_MAPPING.get(provider, f"models/{provider}") + + +def _build_batch_requests( + sample_file_mapping: list[tuple[int, int | None, str]], + prompt: str = DEFAULT_TRANSCRIPTION_PROMPT, +) -> list[dict[str, Any]]: + """Build JSONL batch request data from sample-file mappings. + + Each request follows the Gemini GenerateContentRequest format + with a text prompt and file_data reference. + + Args: + sample_file_mapping: List of (sample_id, result_id, google_file_uri) tuples + prompt: Transcription prompt + + Returns: + list[dict]: JSONL-compatible request dicts for GeminiBatchProvider + """ + return [ + { + "contents": [ + { + "parts": [ + {"text": prompt}, + {"file_data": {"file_uri": file_uri}}, + ], + "role": "user", + } + ], + } + for _, _, file_uri in sample_file_mapping + ] + + +def start_stt_evaluation_batch( + *, + session: Session, + run: EvaluationRun, + samples: list[STTSample], + result_refs: list[dict[str, Any]], + org_id: int, + project_id: int, +) -> dict[str, Any]: + """Upload audio files to Google and submit Gemini batch jobs. + + This function runs synchronously during the API request: + 1. Initializes GeminiClient + 2. Uploads audio files to Google Files API + 3. Builds batch requests + 4. Submits batch jobs per provider + 5. Stores batch_jobs and sample_file_mapping in run.score + 6. 
Updates run status to "processing" + + Args: + session: Database session + run: The evaluation run record + samples: List of STT samples to process + result_refs: List of result reference dicts with id, stt_sample_id, provider + org_id: Organization ID + project_id: Project ID + + Returns: + dict: Result with batch job information + + Raises: + Exception: If all batch submissions fail + """ + logger.info( + f"[start_stt_evaluation_batch] Starting batch submission | " + f"run_id: {run.id}, sample_count: {len(samples)}" + ) + + # Initialize Gemini client + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=project_id, + ) + + # Get cloud storage for S3 access + storage = get_cloud_storage(session=session, project_id=project_id) + + # Upload audio files to Google Files API + files_manager = GeminiFilesManager(gemini_client.client) + + sample_file_mapping: list[tuple[int, int | None, str]] = [] + + for sample in samples: + try: + # Get signed URL for S3 audio file + signed_url = storage.get_signed_url( + sample.object_store_url, expires_in=3600 + ) + + # Extract filename from URL + filename = sample.object_store_url.split("/")[-1] + + # Upload to Google Files API + google_file_uri = files_manager.upload_from_url( + signed_url=signed_url, + filename=filename, + ) + + # Find the result record for this sample + result_for_sample = next( + (r for r in result_refs if r["stt_sample_id"] == sample.id), + None, + ) + + sample_file_mapping.append( + ( + sample.id, + result_for_sample["id"] if result_for_sample else None, + google_file_uri, + ) + ) + + logger.info( + f"[start_stt_evaluation_batch] Uploaded audio to Google | " + f"sample_id: {sample.id}, file_uri: {google_file_uri}" + ) + + except Exception as e: + logger.error( + f"[start_stt_evaluation_batch] Failed to upload audio | " + f"sample_id: {sample.id}, error: {str(e)}" + ) + # Mark result as failed + for ref in result_refs: + if ref["stt_sample_id"] == sample.id: + update_stt_result( + session=session, + result_id=ref["id"], + status=STTResultStatus.FAILED.value, + error_message=f"Failed to upload audio: {str(e)}", + ) + + if not sample_file_mapping: + raise Exception("Failed to upload any audio files") + + # Build batch requests from uploaded files + jsonl_data = _build_batch_requests(sample_file_mapping) + + # Process each provider using GeminiBatchProvider + providers = run.providers or ["gemini-2.5-pro"] + batch_jobs: dict[str, str] = {} + + for provider in providers: + try: + model = _get_model_for_provider(provider) + batch_provider = GeminiBatchProvider( + client=gemini_client.client, model=model + ) + + batch_result = batch_provider.create_batch( + jsonl_data=jsonl_data, + config={ + "display_name": f"stt-eval-{run.id}-{provider}", + "model": model, + }, + ) + + batch_jobs[provider] = batch_result["provider_batch_id"] + + logger.info( + f"[start_stt_evaluation_batch] Batch job submitted | " + f"run_id: {run.id}, provider: {provider}, " + f"batch_id: {batch_result['provider_batch_id']}" + ) + + except Exception as e: + logger.error( + f"[start_stt_evaluation_batch] Failed to submit batch | " + f"provider: {provider}, error: {str(e)}" + ) + # Update results for this provider as failed + for ref in result_refs: + if ref["provider"] == provider: + update_stt_result( + session=session, + result_id=ref["id"], + status=STTResultStatus.FAILED.value, + error_message=f"Batch submission failed: {str(e)}", + ) + + if not batch_jobs: + raise Exception("All batch submissions failed") + + # Store batch 
job info in run score for polling + update_stt_run( + session=session, + run_id=run.id, + status="processing", + score={ + "batch_jobs": batch_jobs, + "sample_file_mapping": [ + {"sample_id": s, "result_id": r, "file_uri": f} + for s, r, f in sample_file_mapping + ], + }, + ) + + logger.info( + f"[start_stt_evaluation_batch] Batch submission complete | " + f"run_id: {run.id}, batch_jobs: {list(batch_jobs.keys())}, " + f"sample_count: {len(sample_file_mapping)}" + ) + + return { + "success": True, + "run_id": run.id, + "batch_jobs": batch_jobs, + "sample_count": len(sample_file_mapping), + } diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py new file mode 100644 index 000000000..c3911c6a6 --- /dev/null +++ b/backend/app/crud/stt_evaluations/cron.py @@ -0,0 +1,345 @@ +"""Cron processing functions for STT evaluations. + +This module provides functions that are called periodically to process +pending STT evaluations - polling batch status and processing completed batches. +""" + +import logging +from typing import Any + +from sqlmodel import Session, select + +from app.core.batch import BatchJobState, GeminiBatchProvider +from app.crud.stt_evaluations.result import count_results_by_status, update_stt_result +from app.crud.stt_evaluations.run import get_pending_stt_runs, update_stt_run +from app.models import EvaluationRun +from app.models.stt_evaluation import STTResult, STTResultStatus +from app.services.stt_evaluations.gemini import GeminiClient, GeminiFilesManager + +logger = logging.getLogger(__name__) + +# Terminal states that indicate batch processing is complete +TERMINAL_STATES = { + BatchJobState.SUCCEEDED.value, + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} + + +async def poll_all_pending_stt_evaluations( + session: Session, + org_id: int, +) -> dict[str, Any]: + """Poll all pending STT evaluations for an organization. + + This function: + 1. Gets all STT runs in "processing" status + 2. For each run, polls the Gemini batch status + 3. If completed, processes the results + 4. 
Returns a summary of what was processed + + Args: + session: Database session + org_id: Organization ID + + Returns: + dict: Summary with processed, failed, still_processing counts + """ + logger.info( + f"[poll_all_pending_stt_evaluations] Starting STT evaluation polling | " + f"org_id: {org_id}" + ) + + # Get all pending STT runs for this organization + pending_runs = get_pending_stt_runs(session=session, org_id=org_id) + + if not pending_runs: + logger.info( + f"[poll_all_pending_stt_evaluations] No pending STT runs | org_id: {org_id}" + ) + return {"processed": 0, "failed": 0, "still_processing": 0} + + logger.info( + f"[poll_all_pending_stt_evaluations] Found {len(pending_runs)} pending STT runs | " + f"org_id: {org_id}" + ) + + processed = 0 + failed = 0 + still_processing = 0 + + for run in pending_runs: + try: + result = await poll_stt_run(session=session, run=run, org_id=org_id) + + if result["status"] == "completed": + processed += 1 + elif result["status"] == "failed": + failed += 1 + else: # still_processing + still_processing += 1 + + except Exception as e: + logger.error( + f"[poll_all_pending_stt_evaluations] Error polling run | " + f"run_id: {run.id}, error: {str(e)}", + exc_info=True, + ) + failed += 1 + + logger.info( + f"[poll_all_pending_stt_evaluations] Polling complete | " + f"org_id: {org_id}, processed: {processed}, failed: {failed}, " + f"still_processing: {still_processing}" + ) + + return { + "processed": processed, + "failed": failed, + "still_processing": still_processing, + } + + +async def poll_stt_run( + session: Session, + run: EvaluationRun, + org_id: int, +) -> dict[str, Any]: + """Poll a single STT evaluation run's batch status. + + Args: + session: Database session + run: The evaluation run to poll + org_id: Organization ID + + Returns: + dict: Status result with "status" key (completed/failed/still_processing) + """ + logger.info(f"[poll_stt_run] Polling run | run_id: {run.id}") + + # Check if run has batch jobs stored + if not run.score or "batch_jobs" not in run.score: + logger.warning( + f"[poll_stt_run] Run has no batch_jobs in score | run_id: {run.id}" + ) + return {"status": "failed", "error": "No batch jobs found"} + + batch_jobs: dict[str, str] = run.score.get("batch_jobs", {}) + + if not batch_jobs: + logger.warning(f"[poll_stt_run] Empty batch_jobs | run_id: {run.id}") + return {"status": "failed", "error": "Empty batch jobs"} + + try: + # Initialize Gemini client + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=run.project_id, + ) + batch_provider = GeminiBatchProvider(client=gemini_client.client) + + all_complete = True + any_success = False + + for provider, batch_id in batch_jobs.items(): + status = batch_provider.get_batch_status(batch_id) + provider_status = status["provider_status"] + + logger.info( + f"[poll_stt_run] Batch status | " + f"run_id: {run.id}, provider: {provider}, " + f"batch_id: {batch_id}, state: {provider_status}" + ) + + is_terminal = provider_status in TERMINAL_STATES + + if not is_terminal: + all_complete = False + elif provider_status == BatchJobState.SUCCEEDED.value: + any_success = True + + if not all_complete: + return {"status": "still_processing"} + + # All batches complete - process results + if any_success: + await process_completed_stt_batch( + session=session, + run=run, + batch_jobs=batch_jobs, + org_id=org_id, + ) + return {"status": "completed"} + else: + update_stt_run( + session=session, + run_id=run.id, + status="failed", + error_message="All batch 
jobs failed", + ) + return {"status": "failed", "error": "All batch jobs failed"} + + except Exception as e: + logger.error( + f"[poll_stt_run] Error polling run | run_id: {run.id}, error: {str(e)}", + exc_info=True, + ) + return {"status": "failed", "error": str(e)} + + +async def process_completed_stt_batch( + session: Session, + run: EvaluationRun, + batch_jobs: dict[str, str], + org_id: int, +) -> None: + """Process completed Gemini batch - download results and update STT result records. + + Args: + session: Database session + run: The evaluation run + batch_jobs: Dict of provider -> batch_id + org_id: Organization ID + """ + logger.info( + f"[process_completed_stt_batch] Processing batch results | run_id: {run.id}" + ) + + sample_file_mapping = run.score.get("sample_file_mapping", []) + sample_ids = [item["sample_id"] for item in sample_file_mapping] + + # Initialize Gemini client and providers + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=org_id, + project_id=run.project_id, + ) + batch_provider = GeminiBatchProvider(client=gemini_client.client) + files_manager = GeminiFilesManager(gemini_client.client) + + processed_count = 0 + failed_count = 0 + + for provider, batch_id in batch_jobs.items(): + try: + # Check if batch succeeded before downloading + status = batch_provider.get_batch_status(batch_id) + if status["provider_status"] != BatchJobState.SUCCEEDED.value: + logger.warning( + f"[process_completed_stt_batch] Batch not succeeded | " + f"provider: {provider}, status: {status['provider_status']}" + ) + # Mark results for this provider as failed + for item in sample_file_mapping: + sample_id = item["sample_id"] + stmt = select(STTResult).where( + STTResult.evaluation_run_id == run.id, + STTResult.stt_sample_id == sample_id, + STTResult.provider == provider, + ) + result_record = session.exec(stmt).one_or_none() + if result_record: + update_stt_result( + session=session, + result_id=result_record.id, + status=STTResultStatus.FAILED.value, + error_message=f"Batch {status['provider_status']}", + ) + failed_count += 1 + continue + + # Download results using GeminiBatchProvider + results = batch_provider.download_batch_results(batch_id) + + logger.info( + f"[process_completed_stt_batch] Got batch results | " + f"provider: {provider}, result_count: {len(results)}" + ) + + # Match results to samples by index + for batch_result in results: + custom_id = batch_result["custom_id"] + # custom_id is the index as string + try: + index = int(custom_id) + except (ValueError, TypeError): + index = results.index(batch_result) + + if index >= len(sample_ids): + continue + + sample_id = sample_ids[index] + + # Find result record for this sample and provider + stmt = select(STTResult).where( + STTResult.evaluation_run_id == run.id, + STTResult.stt_sample_id == sample_id, + STTResult.provider == provider, + ) + result_record = session.exec(stmt).one_or_none() + + if result_record: + if batch_result.get("response"): + text = batch_result["response"].get("text", "") + update_stt_result( + session=session, + result_id=result_record.id, + transcription=text, + status=STTResultStatus.COMPLETED.value, + ) + processed_count += 1 + else: + update_stt_result( + session=session, + result_id=result_record.id, + status=STTResultStatus.FAILED.value, + error_message=batch_result.get("error", "Unknown error"), + ) + failed_count += 1 + + except Exception as e: + logger.error( + f"[process_completed_stt_batch] Failed to process provider results | " + f"provider: {provider}, error: 
{str(e)}" + ) + failed_count += len(sample_file_mapping) + + # Clean up Google Files + for item in sample_file_mapping: + file_uri = item.get("file_uri") + if file_uri: + try: + files_manager.delete_file(file_uri) + except Exception as e: + logger.warning( + f"[process_completed_stt_batch] Failed to delete Google file | " + f"file_uri: {file_uri}, error: {str(e)}" + ) + + # Update run status + status_counts = count_results_by_status(session=session, run_id=run.id) + + completed = status_counts.get(STTResultStatus.COMPLETED.value, 0) + failed = status_counts.get(STTResultStatus.FAILED.value, 0) + pending = status_counts.get(STTResultStatus.PENDING.value, 0) + + final_status = "completed" if pending == 0 else "processing" + error_message = None + if failed > 0: + error_message = f"{failed} transcription(s) failed" + + update_stt_run( + session=session, + run_id=run.id, + status=final_status, + processed_samples=completed + failed, + error_message=error_message, + ) + + logger.info( + f"[process_completed_stt_batch] Batch results processed | " + f"run_id: {run.id}, completed: {completed}, " + f"failed: {failed}, status: {final_status}" + ) From 2295da55fdbdcc9ed4b1ed70ad27ec3739fae4f2 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Mon, 2 Feb 2026 20:54:42 +0530 Subject: [PATCH 11/36] added gemini as provider --- backend/app/core/providers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index dfaae233a..01769e8c8 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -12,6 +12,7 @@ class Provider(str, Enum): OPENAI = "openai" AWS = "aws" LANGFUSE = "langfuse" + GEMINI = "gemini" @dataclass @@ -30,6 +31,7 @@ class ProviderConfig: Provider.LANGFUSE: ProviderConfig( required_fields=["secret_key", "public_key", "host"] ), + Provider.GEMINI: ProviderConfig(required_fields=["api_key"]), } From 25e60027d6988866176584c96f6e9d5446d19521 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Mon, 2 Feb 2026 20:55:00 +0530 Subject: [PATCH 12/36] moving to batch job in gemini --- backend/app/core/batch/__init__.py | 3 +- backend/app/core/batch/gemini.py | 67 +++++ backend/app/core/storage_utils.py | 22 +- backend/app/crud/evaluations/processing.py | 2 + backend/app/crud/stt_evaluations/batch.py | 210 +++++++--------- backend/app/crud/stt_evaluations/cron.py | 229 ++++++++---------- backend/app/crud/stt_evaluations/run.py | 15 +- .../stt_evaluations/gemini/__init__.py | 3 +- 8 files changed, 292 insertions(+), 259 deletions(-) diff --git a/backend/app/core/batch/__init__.py b/backend/app/core/batch/__init__.py index 25a33b6d7..382fcc4ff 100644 --- a/backend/app/core/batch/__init__.py +++ b/backend/app/core/batch/__init__.py @@ -1,7 +1,7 @@ """Batch processing infrastructure for LLM providers.""" from .base import BatchProvider -from .gemini import BatchJobState, GeminiBatchProvider +from .gemini import BatchJobState, GeminiBatchProvider, create_stt_batch_requests from .openai import OpenAIBatchProvider from .operations import ( download_batch_results, @@ -16,6 +16,7 @@ "BatchJobState", "GeminiBatchProvider", "OpenAIBatchProvider", + "create_stt_batch_requests", "start_batch_job", "download_batch_results", "process_completed_batch", diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 45d751c87..6c4d56b49 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -11,6 +11,8 @@ from google import genai from google.genai import types +from 
app.core.storage_utils import get_mime_from_url + from .base import BatchProvider logger = logging.getLogger(__name__) @@ -346,3 +348,68 @@ def _extract_text_from_response(response: Any) -> str: if hasattr(part, "text"): text += part.text return text + + +def create_stt_batch_requests( + signed_urls: list[str], + prompt: str, + keys: list[str] | None = None, +) -> list[dict[str, Any]]: + """ + Create batch API requests for Gemini STT using signed URLs. + + This function generates request payloads suitable for Gemini's batch API + using signed URLs directly (no file upload required). MIME types are + automatically detected from the URL path. + + Args: + signed_urls: List of signed URLs pointing to audio files + prompt: Transcription prompt/instructions for the model + keys: Optional list of custom IDs for each request. If not provided, + uses 0-indexed integers as strings. + + Returns: + List of batch request dictionaries ready for GeminiBatchProvider.create_batch() + + Example: + >>> urls = ["https://bucket.s3.amazonaws.com/audio.mp3?..."] + >>> prompt = "Transcribe this audio file." + >>> requests = create_stt_batch_requests(urls, prompt) + >>> provider.create_batch(requests, {"display_name": "stt-batch"}) + """ + if keys is not None and len(keys) != len(signed_urls): + raise ValueError( + f"Length of keys ({len(keys)}) must match signed_urls ({len(signed_urls)})" + ) + + requests = [] + for i, url in enumerate(signed_urls): + mime_type = get_mime_from_url(url) + if mime_type is None: + logger.warning( + f"[create_stt_batch_requests] Could not determine MIME type for URL | " + f"index={i} | defaulting to audio/mpeg" + ) + mime_type = "audio/mpeg" + + request = { + "contents": [ + { + "parts": [ + {"text": prompt}, + {"file_data": {"mime_type": mime_type, "file_uri": url}}, + ], + "role": "user", + } + ] + } + + # Add key if provided for tracking + if keys is not None: + request["key"] = keys[i] + + requests.append(request) + + logger.info(f"[create_stt_batch_requests] Created {len(requests)} batch requests") + + return requests diff --git a/backend/app/core/storage_utils.py b/backend/app/core/storage_utils.py index 6130ca881..aa1b72199 100644 --- a/backend/app/core/storage_utils.py +++ b/backend/app/core/storage_utils.py @@ -5,12 +5,13 @@ to cloud object storage, abstracting away provider-specific details. """ -import io import json import logging +import mimetypes from datetime import datetime from io import BytesIO from pathlib import Path +from urllib.parse import unquote, urlparse from starlette.datastructures import Headers, UploadFile @@ -19,6 +20,25 @@ logger = logging.getLogger(__name__) +def get_mime_from_url(url: str) -> str | None: + """ + Extract MIME type from a URL by parsing its path component. + + Works with signed URLs by ignoring query parameters and extracting + the file extension from the path. 
+ + Args: + url: URL string (can include query parameters like signed URLs) + + Returns: + MIME type string (e.g., 'audio/mpeg') or None if unable to determine + """ + parsed = urlparse(url) + path = unquote(parsed.path) + mime_type, _ = mimetypes.guess_type(path) + return mime_type + + def upload_to_object_store( storage: CloudStorage, content: bytes, diff --git a/backend/app/crud/evaluations/processing.py b/backend/app/crud/evaluations/processing.py index 99ce39d56..05ca94045 100644 --- a/backend/app/crud/evaluations/processing.py +++ b/backend/app/crud/evaluations/processing.py @@ -664,9 +664,11 @@ async def poll_all_pending_evaluations(session: Session, org_id: int) -> dict[st } """ # Get pending evaluations (status = "processing") + # Filter to only text evaluations - STT/TTS evaluations have their own polling statement = select(EvaluationRun).where( EvaluationRun.status == "processing", EvaluationRun.organization_id == org_id, + EvaluationRun.type == "text", ) pending_runs = session.exec(statement).all() diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py index 078fa4e44..e44009987 100644 --- a/backend/app/crud/stt_evaluations/batch.py +++ b/backend/app/crud/stt_evaluations/batch.py @@ -5,13 +5,17 @@ from sqlmodel import Session -from app.core.batch import GeminiBatchProvider +from app.core.batch import ( + GeminiBatchProvider, + create_stt_batch_requests, + start_batch_job, +) from app.core.cloud.storage import get_cloud_storage from app.crud.stt_evaluations.result import update_stt_result from app.crud.stt_evaluations.run import update_stt_run from app.models import EvaluationRun from app.models.stt_evaluation import STTResultStatus, STTSample -from app.services.stt_evaluations.gemini import GeminiClient, GeminiFilesManager +from app.services.stt_evaluations.gemini import GeminiClient logger = logging.getLogger(__name__) @@ -41,38 +45,6 @@ def _get_model_for_provider(provider: str) -> str: return PROVIDER_MODEL_MAPPING.get(provider, f"models/{provider}") -def _build_batch_requests( - sample_file_mapping: list[tuple[int, int | None, str]], - prompt: str = DEFAULT_TRANSCRIPTION_PROMPT, -) -> list[dict[str, Any]]: - """Build JSONL batch request data from sample-file mappings. - - Each request follows the Gemini GenerateContentRequest format - with a text prompt and file_data reference. - - Args: - sample_file_mapping: List of (sample_id, result_id, google_file_uri) tuples - prompt: Transcription prompt - - Returns: - list[dict]: JSONL-compatible request dicts for GeminiBatchProvider - """ - return [ - { - "contents": [ - { - "parts": [ - {"text": prompt}, - {"file_data": {"file_uri": file_uri}}, - ], - "role": "user", - } - ], - } - for _, _, file_uri in sample_file_mapping - ] - - def start_stt_evaluation_batch( *, session: Session, @@ -81,16 +53,18 @@ def start_stt_evaluation_batch( result_refs: list[dict[str, Any]], org_id: int, project_id: int, + signed_url_expires_in: int = 86400, ) -> dict[str, Any]: - """Upload audio files to Google and submit Gemini batch jobs. + """Generate signed URLs and submit Gemini batch job for STT evaluation. This function runs synchronously during the API request: 1. Initializes GeminiClient - 2. Uploads audio files to Google Files API - 3. Builds batch requests - 4. Submits batch jobs per provider - 5. Stores batch_jobs and sample_file_mapping in run.score - 6. Updates run status to "processing" + 2. Generates signed URLs for audio files (valid for batch processing window) + 3. 
Builds batch requests using signed URLs directly + 4. Submits batch job via start_batch_job (creates BatchJob record) + 5. Links batch_job_id to the evaluation run + 6. Stores sample_file_mapping in run.score + 7. Updates run status to "processing" Args: session: Database session @@ -99,12 +73,13 @@ def start_stt_evaluation_batch( result_refs: List of result reference dicts with id, stt_sample_id, provider org_id: Organization ID project_id: Project ID + signed_url_expires_in: Signed URL expiry in seconds (default: 24 hours for batch) Returns: dict: Result with batch job information Raises: - Exception: If all batch submissions fail + Exception: If batch submission fails """ logger.info( f"[start_stt_evaluation_batch] Starting batch submission | " @@ -121,25 +96,16 @@ def start_stt_evaluation_batch( # Get cloud storage for S3 access storage = get_cloud_storage(session=session, project_id=project_id) - # Upload audio files to Google Files API - files_manager = GeminiFilesManager(gemini_client.client) - - sample_file_mapping: list[tuple[int, int | None, str]] = [] + # Generate signed URLs for audio files + sample_url_mapping: list[dict[str, Any]] = [] + signed_urls: list[str] = [] for sample in samples: try: # Get signed URL for S3 audio file + # Use longer expiry for batch processing (up to 24 hours) signed_url = storage.get_signed_url( - sample.object_store_url, expires_in=3600 - ) - - # Extract filename from URL - filename = sample.object_store_url.split("/")[-1] - - # Upload to Google Files API - google_file_uri = files_manager.upload_from_url( - signed_url=signed_url, - filename=filename, + sample.object_store_url, expires_in=signed_url_expires_in ) # Find the result record for this sample @@ -148,22 +114,23 @@ def start_stt_evaluation_batch( None, ) - sample_file_mapping.append( - ( - sample.id, - result_for_sample["id"] if result_for_sample else None, - google_file_uri, - ) + sample_url_mapping.append( + { + "sample_id": sample.id, + "result_id": result_for_sample["id"] if result_for_sample else None, + "signed_url": signed_url, + } ) + signed_urls.append(signed_url) logger.info( - f"[start_stt_evaluation_batch] Uploaded audio to Google | " - f"sample_id: {sample.id}, file_uri: {google_file_uri}" + f"[start_stt_evaluation_batch] Generated signed URL | " + f"sample_id: {sample.id}" ) except Exception as e: logger.error( - f"[start_stt_evaluation_batch] Failed to upload audio | " + f"[start_stt_evaluation_batch] Failed to generate signed URL | " f"sample_id: {sample.id}, error: {str(e)}" ) # Mark result as failed @@ -173,83 +140,84 @@ def start_stt_evaluation_batch( session=session, result_id=ref["id"], status=STTResultStatus.FAILED.value, - error_message=f"Failed to upload audio: {str(e)}", + error_message=f"Failed to generate signed URL: {str(e)}", ) - if not sample_file_mapping: - raise Exception("Failed to upload any audio files") + if not sample_url_mapping: + raise Exception("Failed to generate signed URLs for any audio files") - # Build batch requests from uploaded files - jsonl_data = _build_batch_requests(sample_file_mapping) + # Build batch requests using signed URLs directly (with mime type detection) + jsonl_data = create_stt_batch_requests( + signed_urls=signed_urls, + prompt=DEFAULT_TRANSCRIPTION_PROMPT, + ) - # Process each provider using GeminiBatchProvider + # Use first provider (STT evaluations use one provider per run) providers = run.providers or ["gemini-2.5-pro"] - batch_jobs: dict[str, str] = {} - - for provider in providers: - try: - model = 
_get_model_for_provider(provider) - batch_provider = GeminiBatchProvider( - client=gemini_client.client, model=model - ) - - batch_result = batch_provider.create_batch( - jsonl_data=jsonl_data, - config={ - "display_name": f"stt-eval-{run.id}-{provider}", - "model": model, - }, - ) - - batch_jobs[provider] = batch_result["provider_batch_id"] - - logger.info( - f"[start_stt_evaluation_batch] Batch job submitted | " - f"run_id: {run.id}, provider: {provider}, " - f"batch_id: {batch_result['provider_batch_id']}" - ) - - except Exception as e: - logger.error( - f"[start_stt_evaluation_batch] Failed to submit batch | " - f"provider: {provider}, error: {str(e)}" + provider = providers[0] + model = _get_model_for_provider(provider) + + # Create batch job using the standard batch operations + batch_provider = GeminiBatchProvider(client=gemini_client.client, model=model) + + try: + batch_job = start_batch_job( + session=session, + provider=batch_provider, + provider_name="gemini", + job_type="stt_evaluation", + organization_id=org_id, + project_id=project_id, + jsonl_data=jsonl_data, + config={ + "display_name": f"stt-eval-{run.id}-{provider}", + "model": model, + "stt_provider": provider, + }, + ) + + logger.info( + f"[start_stt_evaluation_batch] Batch job created | " + f"run_id: {run.id}, batch_job_id: {batch_job.id}, " + f"provider_batch_id: {batch_job.provider_batch_id}" + ) + + except Exception as e: + logger.error( + f"[start_stt_evaluation_batch] Failed to submit batch | " + f"provider: {provider}, error: {str(e)}" + ) + # Update all results as failed + for ref in result_refs: + update_stt_result( + session=session, + result_id=ref["id"], + status=STTResultStatus.FAILED.value, + error_message=f"Batch submission failed: {str(e)}", ) - # Update results for this provider as failed - for ref in result_refs: - if ref["provider"] == provider: - update_stt_result( - session=session, - result_id=ref["id"], - status=STTResultStatus.FAILED.value, - error_message=f"Batch submission failed: {str(e)}", - ) - - if not batch_jobs: - raise Exception("All batch submissions failed") + raise Exception(f"Batch submission failed: {str(e)}") - # Store batch job info in run score for polling + # Link batch job to the evaluation run and store sample mapping update_stt_run( session=session, run_id=run.id, status="processing", + batch_job_id=batch_job.id, score={ - "batch_jobs": batch_jobs, - "sample_file_mapping": [ - {"sample_id": s, "result_id": r, "file_uri": f} - for s, r, f in sample_file_mapping - ], + "sample_file_mapping": sample_url_mapping, }, ) logger.info( f"[start_stt_evaluation_batch] Batch submission complete | " - f"run_id: {run.id}, batch_jobs: {list(batch_jobs.keys())}, " - f"sample_count: {len(sample_file_mapping)}" + f"run_id: {run.id}, batch_job_id: {batch_job.id}, " + f"sample_count: {len(sample_url_mapping)}" ) return { "success": True, "run_id": run.id, - "batch_jobs": batch_jobs, - "sample_count": len(sample_file_mapping), + "batch_job_id": batch_job.id, + "provider_batch_id": batch_job.provider_batch_id, + "sample_count": len(sample_url_mapping), } diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py index c3911c6a6..a70f1c1e8 100644 --- a/backend/app/crud/stt_evaluations/cron.py +++ b/backend/app/crud/stt_evaluations/cron.py @@ -9,12 +9,13 @@ from sqlmodel import Session, select -from app.core.batch import BatchJobState, GeminiBatchProvider +from app.core.batch import BatchJobState, GeminiBatchProvider, poll_batch_status +from app.crud.job import 
get_batch_job from app.crud.stt_evaluations.result import count_results_by_status, update_stt_result from app.crud.stt_evaluations.run import get_pending_stt_runs, update_stt_run from app.models import EvaluationRun from app.models.stt_evaluation import STTResult, STTResultStatus -from app.services.stt_evaluations.gemini import GeminiClient, GeminiFilesManager +from app.services.stt_evaluations.gemini import GeminiClient logger = logging.getLogger(__name__) @@ -118,21 +119,22 @@ async def poll_stt_run( """ logger.info(f"[poll_stt_run] Polling run | run_id: {run.id}") - # Check if run has batch jobs stored - if not run.score or "batch_jobs" not in run.score: + # Check if run has batch_job_id + if not run.batch_job_id: + logger.warning(f"[poll_stt_run] Run has no batch_job_id | run_id: {run.id}") + return {"status": "failed", "error": "No batch job found"} + + # Get the batch job record + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if not batch_job: logger.warning( - f"[poll_stt_run] Run has no batch_jobs in score | run_id: {run.id}" + f"[poll_stt_run] BatchJob not found | run_id: {run.id}, " + f"batch_job_id: {run.batch_job_id}" ) - return {"status": "failed", "error": "No batch jobs found"} - - batch_jobs: dict[str, str] = run.score.get("batch_jobs", {}) - - if not batch_jobs: - logger.warning(f"[poll_stt_run] Empty batch_jobs | run_id: {run.id}") - return {"status": "failed", "error": "Empty batch jobs"} + return {"status": "failed", "error": "Batch job not found"} try: - # Initialize Gemini client + # Initialize Gemini client and poll batch status gemini_client = GeminiClient.from_credentials( session=session, org_id=org_id, @@ -140,46 +142,46 @@ async def poll_stt_run( ) batch_provider = GeminiBatchProvider(client=gemini_client.client) - all_complete = True - any_success = False - - for provider, batch_id in batch_jobs.items(): - status = batch_provider.get_batch_status(batch_id) - provider_status = status["provider_status"] - - logger.info( - f"[poll_stt_run] Batch status | " - f"run_id: {run.id}, provider: {provider}, " - f"batch_id: {batch_id}, state: {provider_status}" - ) + # Poll and update batch job status + poll_batch_status( + session=session, + provider=batch_provider, + batch_job=batch_job, + ) - is_terminal = provider_status in TERMINAL_STATES + # Refresh to get updated status + session.refresh(batch_job) + provider_status = batch_job.provider_status - if not is_terminal: - all_complete = False - elif provider_status == BatchJobState.SUCCEEDED.value: - any_success = True + logger.info( + f"[poll_stt_run] Batch status | " + f"run_id: {run.id}, batch_job_id: {batch_job.id}, " + f"state: {provider_status}" + ) - if not all_complete: + # Check if batch is complete + if provider_status not in TERMINAL_STATES: return {"status": "still_processing"} - # All batches complete - process results - if any_success: + # Batch is complete - check if succeeded + if provider_status == BatchJobState.SUCCEEDED.value: await process_completed_stt_batch( session=session, run=run, - batch_jobs=batch_jobs, + batch_job=batch_job, org_id=org_id, ) return {"status": "completed"} else: + # Batch failed + error_msg = batch_job.error_message or f"Batch {provider_status}" update_stt_run( session=session, run_id=run.id, status="failed", - error_message="All batch jobs failed", + error_message=error_msg, ) - return {"status": "failed", "error": "All batch jobs failed"} + return {"status": "failed", "error": error_msg} except Exception as e: logger.error( @@ -192,7 +194,7 @@ 
async def poll_stt_run( async def process_completed_stt_batch( session: Session, run: EvaluationRun, - batch_jobs: dict[str, str], + batch_job: Any, org_id: int, ) -> None: """Process completed Gemini batch - download results and update STT result records. @@ -200,123 +202,88 @@ async def process_completed_stt_batch( Args: session: Database session run: The evaluation run - batch_jobs: Dict of provider -> batch_id + batch_job: The BatchJob record org_id: Organization ID """ logger.info( - f"[process_completed_stt_batch] Processing batch results | run_id: {run.id}" + f"[process_completed_stt_batch] Processing batch results | " + f"run_id: {run.id}, batch_job_id: {batch_job.id}" ) - sample_file_mapping = run.score.get("sample_file_mapping", []) + sample_file_mapping = run.score.get("sample_file_mapping", []) if run.score else [] sample_ids = [item["sample_id"] for item in sample_file_mapping] - # Initialize Gemini client and providers + # Get the STT provider from batch job config + stt_provider = batch_job.config.get("stt_provider", "gemini-2.5-pro") + + # Initialize Gemini client gemini_client = GeminiClient.from_credentials( session=session, org_id=org_id, project_id=run.project_id, ) batch_provider = GeminiBatchProvider(client=gemini_client.client) - files_manager = GeminiFilesManager(gemini_client.client) processed_count = 0 failed_count = 0 - for provider, batch_id in batch_jobs.items(): - try: - # Check if batch succeeded before downloading - status = batch_provider.get_batch_status(batch_id) - if status["provider_status"] != BatchJobState.SUCCEEDED.value: - logger.warning( - f"[process_completed_stt_batch] Batch not succeeded | " - f"provider: {provider}, status: {status['provider_status']}" - ) - # Mark results for this provider as failed - for item in sample_file_mapping: - sample_id = item["sample_id"] - stmt = select(STTResult).where( - STTResult.evaluation_run_id == run.id, - STTResult.stt_sample_id == sample_id, - STTResult.provider == provider, - ) - result_record = session.exec(stmt).one_or_none() - if result_record: - update_stt_result( - session=session, - result_id=result_record.id, - status=STTResultStatus.FAILED.value, - error_message=f"Batch {status['provider_status']}", - ) - failed_count += 1 - continue + try: + # Download results using GeminiBatchProvider + # Use provider_batch_id to download results + results = batch_provider.download_batch_results(batch_job.provider_batch_id) - # Download results using GeminiBatchProvider - results = batch_provider.download_batch_results(batch_id) + logger.info( + f"[process_completed_stt_batch] Got batch results | " + f"batch_job_id: {batch_job.id}, result_count: {len(results)}" + ) - logger.info( - f"[process_completed_stt_batch] Got batch results | " - f"provider: {provider}, result_count: {len(results)}" - ) + # Match results to samples by index + for batch_result in results: + custom_id = batch_result["custom_id"] + # custom_id is the index as string + try: + index = int(custom_id) + except (ValueError, TypeError): + index = results.index(batch_result) - # Match results to samples by index - for batch_result in results: - custom_id = batch_result["custom_id"] - # custom_id is the index as string - try: - index = int(custom_id) - except (ValueError, TypeError): - index = results.index(batch_result) - - if index >= len(sample_ids): - continue - - sample_id = sample_ids[index] - - # Find result record for this sample and provider - stmt = select(STTResult).where( - STTResult.evaluation_run_id == run.id, - STTResult.stt_sample_id 
== sample_id, - STTResult.provider == provider, - ) - result_record = session.exec(stmt).one_or_none() - - if result_record: - if batch_result.get("response"): - text = batch_result["response"].get("text", "") - update_stt_result( - session=session, - result_id=result_record.id, - transcription=text, - status=STTResultStatus.COMPLETED.value, - ) - processed_count += 1 - else: - update_stt_result( - session=session, - result_id=result_record.id, - status=STTResultStatus.FAILED.value, - error_message=batch_result.get("error", "Unknown error"), - ) - failed_count += 1 + if index >= len(sample_ids): + continue - except Exception as e: - logger.error( - f"[process_completed_stt_batch] Failed to process provider results | " - f"provider: {provider}, error: {str(e)}" + sample_id = sample_ids[index] + + # Find result record for this sample and provider + stmt = select(STTResult).where( + STTResult.evaluation_run_id == run.id, + STTResult.stt_sample_id == sample_id, + STTResult.provider == stt_provider, ) - failed_count += len(sample_file_mapping) + result_record = session.exec(stmt).one_or_none() + + if result_record: + if batch_result.get("response"): + text = batch_result["response"].get("text", "") + update_stt_result( + session=session, + result_id=result_record.id, + transcription=text, + status=STTResultStatus.COMPLETED.value, + ) + processed_count += 1 + else: + update_stt_result( + session=session, + result_id=result_record.id, + status=STTResultStatus.FAILED.value, + error_message=batch_result.get("error", "Unknown error"), + ) + failed_count += 1 - # Clean up Google Files - for item in sample_file_mapping: - file_uri = item.get("file_uri") - if file_uri: - try: - files_manager.delete_file(file_uri) - except Exception as e: - logger.warning( - f"[process_completed_stt_batch] Failed to delete Google file | " - f"file_uri: {file_uri}, error: {str(e)}" - ) + except Exception as e: + logger.error( + f"[process_completed_stt_batch] Failed to process batch results | " + f"batch_job_id: {batch_job.id}, error: {str(e)}" + ) + failed_count += len(sample_file_mapping) # Update run status status_counts = count_results_by_status(session=session, run_id=run.id) diff --git a/backend/app/crud/stt_evaluations/run.py b/backend/app/crud/stt_evaluations/run.py index d67ba07f9..981c99a4f 100644 --- a/backend/app/crud/stt_evaluations/run.py +++ b/backend/app/crud/stt_evaluations/run.py @@ -192,6 +192,7 @@ def update_stt_run( score: dict[str, Any] | None = None, error_message: str | None = None, object_store_url: str | None = None, + batch_job_id: int | None = None, ) -> EvaluationRun | None: """Update an STT evaluation run. @@ -203,6 +204,7 @@ def update_stt_run( score: Score data error_message: Error message object_store_url: URL to stored results + batch_job_id: ID of the associated batch job Returns: EvaluationRun | None: Updated run @@ -228,6 +230,9 @@ def update_stt_run( if object_store_url is not None: run.object_store_url = object_store_url + if batch_job_id is not None: + run.batch_job_id = batch_job_id + run.updated_at = now() session.add(run) @@ -280,18 +285,22 @@ def get_pending_stt_runs( session: Session, org_id: int | None = None, ) -> list[EvaluationRun]: - """Get all pending STT evaluation runs. + """Get all pending STT evaluation runs that are ready for polling. + + Only returns runs with status "processing" that have a batch_job_id set, + meaning the batch has been submitted and is ready to be polled. 
Args: session: Database session org_id: Optional filter by organization Returns: - list[EvaluationRun]: Pending runs + list[EvaluationRun]: Pending runs ready for polling """ where_clauses = [ EvaluationRun.type == EvaluationType.STT.value, - EvaluationRun.status.in_(["pending", "processing"]), + EvaluationRun.status == "processing", + EvaluationRun.batch_job_id.is_not(None), ] if org_id is not None: diff --git a/backend/app/services/stt_evaluations/gemini/__init__.py b/backend/app/services/stt_evaluations/gemini/__init__.py index a0b35b1b6..07f7037cf 100644 --- a/backend/app/services/stt_evaluations/gemini/__init__.py +++ b/backend/app/services/stt_evaluations/gemini/__init__.py @@ -1,6 +1,5 @@ """Gemini integration for STT evaluation.""" from .client import GeminiClient -from .files import GeminiFilesManager -__all__ = ["GeminiClient", "GeminiFilesManager"] +__all__ = ["GeminiClient"] From db2512e2e619ce38b4f3fd8fd812b1aac9c59e01 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Mon, 2 Feb 2026 22:54:57 +0530 Subject: [PATCH 13/36] code refactoring, using batch requests and files similar to OpenAI --- backend/app/core/batch/gemini.py | 158 +++++++++++++--------- backend/app/crud/stt_evaluations/batch.py | 23 ++-- backend/app/crud/stt_evaluations/cron.py | 27 ++-- 3 files changed, 120 insertions(+), 88 deletions(-) diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 6c4d56b49..e974564e3 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -68,12 +68,12 @@ def __init__(self, client: genai.Client, model: str | None = None): def create_batch( self, jsonl_data: list[dict[str, Any]], config: dict[str, Any] ) -> dict[str, Any]: - """Upload JSONL data and create a batch job with Gemini. + """Upload JSONL file and create a batch job with Gemini. Args: - jsonl_data: List of dictionaries representing JSONL lines. - Each dict should be a valid GenerateContentRequest, e.g.: - {"contents": [{"parts": [{"text": "..."}]}]} + jsonl_data: List of dictionaries in Gemini JSONL format. 
+ Each dict should have the structure: + {"key": "request-1", "request": {"contents": [{"parts": [...]}]}} config: Provider-specific configuration with: - display_name: Optional batch display name - model: Optional model override @@ -81,7 +81,7 @@ def create_batch( Returns: Dictionary containing: - provider_batch_id: Gemini batch job name - - provider_file_id: Uploaded JSONL file name (or None for inline) + - provider_file_id: Uploaded JSONL file name - provider_status: Initial status from Gemini - total_items: Number of items in the batch """ @@ -94,10 +94,20 @@ def create_batch( ) try: - # Use inline requests for the batch + # Create JSONL content + jsonl_content = "\n".join(json.dumps(item) for item in jsonl_data) + + # Upload JSONL file to Gemini File API + uploaded_file = self.upload_file(jsonl_content, purpose="batch") + + logger.info( + f"[create_batch] Uploaded JSONL file | file_name={uploaded_file}" + ) + + # Create batch job using uploaded file batch_job = self._client.batches.create( model=model, - src=jsonl_data, + src=uploaded_file, config={"display_name": display_name}, ) @@ -105,14 +115,14 @@ def create_batch( result = { "provider_batch_id": batch_job.name, - "provider_file_id": None, + "provider_file_id": uploaded_file, "provider_status": initial_state, "total_items": len(jsonl_data), } logger.info( f"[create_batch] Created Gemini batch | batch_id={batch_job.name} | " - f"status={initial_state} | items={len(jsonl_data)}" + f"file_id={uploaded_file} | status={initial_state} | items={len(jsonl_data)}" ) return result @@ -164,19 +174,21 @@ def get_batch_status(self, batch_id: str) -> dict[str, Any]: ) raise - def download_batch_results(self, output_file_id: str) -> list[dict[str, Any]]: + def download_batch_results( + self, output_file_id: str, request_keys: list[str] | None = None + ) -> list[dict[str, Any]]: """Download and parse batch results from Gemini. - Gemini returns results either as inlined responses or as a - downloadable JSONL file. This method handles both formats and - normalizes the output to match the BatchProvider interface. + Gemini returns results as a downloadable JSONL file where each line + contains the key and response. Args: output_file_id: Gemini batch job name (used to fetch the batch) + request_keys: Deprecated, kept for interface compatibility. 
Returns: List of result dictionaries, each containing: - - custom_id: Item key from input (or index as string) + - custom_id: Item key from input - response: Dict with "text" key containing the generated text - error: Error info (if item failed), None otherwise """ @@ -194,29 +206,8 @@ def download_batch_results(self, output_file_id: str) -> list[dict[str, Any]]: results: list[dict[str, Any]] = [] - # Handle inline responses - if batch_job.dest and batch_job.dest.inlined_responses: - for i, response in enumerate(batch_job.dest.inlined_responses): - if response.response: - text = self._extract_text_from_response(response.response) - results.append( - { - "custom_id": str(i), - "response": {"text": text}, - "error": None, - } - ) - elif response.error: - results.append( - { - "custom_id": str(i), - "response": None, - "error": str(response.error), - } - ) - - # Handle file-based results - elif ( + # Handle file-based results (keys are included in each response line) + if ( batch_job.dest and hasattr(batch_job.dest, "file_name") and batch_job.dest.file_name @@ -226,15 +217,27 @@ def download_batch_results(self, output_file_id: str) -> list[dict[str, Any]]: for i, line in enumerate(lines): try: parsed = json.loads(line) - text = parsed.get("response", {}).get("text", "") custom_id = parsed.get("key", str(i)) - results.append( - { - "custom_id": custom_id, - "response": {"text": text}, - "error": None, - } - ) + + # Extract text from response + response_obj = parsed.get("response") + if response_obj: + text = self._extract_text_from_response_dict(response_obj) + results.append( + { + "custom_id": custom_id, + "response": {"text": text}, + "error": None, + } + ) + elif parsed.get("error"): + results.append( + { + "custom_id": custom_id, + "response": None, + "error": str(parsed["error"]), + } + ) except json.JSONDecodeError as e: logger.error( f"[download_batch_results] Failed to parse JSON | " @@ -256,6 +259,29 @@ def download_batch_results(self, output_file_id: str) -> list[dict[str, Any]]: ) raise + @staticmethod + def _extract_text_from_response_dict(response: dict[str, Any]) -> str: + """Extract text content from a Gemini response dictionary. + + Args: + response: Gemini response as a dictionary + + Returns: + str: Extracted text + """ + # Try direct text field first + if "text" in response: + return response["text"] + + # Extract from candidates structure + text = "" + for candidate in response.get("candidates", []): + content = candidate.get("content", {}) + for part in content.get("parts", []): + if "text" in part: + text += part["text"] + return text + def upload_file(self, content: str, purpose: str = "batch") -> str: """Upload a JSONL file to Gemini Files API. @@ -358,23 +384,23 @@ def create_stt_batch_requests( """ Create batch API requests for Gemini STT using signed URLs. - This function generates request payloads suitable for Gemini's batch API - using signed URLs directly (no file upload required). MIME types are - automatically detected from the URL path. + This function generates request payloads in Gemini's JSONL batch format + using signed URLs directly. MIME types are automatically detected from the URL path. Args: signed_urls: List of signed URLs pointing to audio files prompt: Transcription prompt/instructions for the model - keys: Optional list of custom IDs for each request. If not provided, + keys: Optional list of custom IDs for tracking results. If not provided, uses 0-indexed integers as strings. 
Returns: - List of batch request dictionaries ready for GeminiBatchProvider.create_batch() + List of batch request dicts in Gemini JSONL format: + {"key": "sample-1", "request": {"contents": [...]}} Example: >>> urls = ["https://bucket.s3.amazonaws.com/audio.mp3?..."] >>> prompt = "Transcribe this audio file." - >>> requests = create_stt_batch_requests(urls, prompt) + >>> requests = create_stt_batch_requests(urls, prompt, keys=["sample-1"]) >>> provider.create_batch(requests, {"display_name": "stt-batch"}) """ if keys is not None and len(keys) != len(signed_urls): @@ -392,22 +418,24 @@ def create_stt_batch_requests( ) mime_type = "audio/mpeg" + # Use provided key or generate from index + key = keys[i] if keys is not None else str(i) + + # Gemini JSONL format: {"key": "...", "request": {...}} request = { - "contents": [ - { - "parts": [ - {"text": prompt}, - {"file_data": {"mime_type": mime_type, "file_uri": url}}, - ], - "role": "user", - } - ] + "key": key, + "request": { + "contents": [ + { + "parts": [ + {"text": prompt}, + {"file_data": {"mime_type": mime_type, "file_uri": url}}, + ], + "role": "user", + } + ] + }, } - - # Add key if provided for tracking - if keys is not None: - request["key"] = keys[i] - requests.append(request) logger.info(f"[create_stt_batch_requests] Created {len(requests)} batch requests") diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py index e44009987..968e88ed0 100644 --- a/backend/app/crud/stt_evaluations/batch.py +++ b/backend/app/crud/stt_evaluations/batch.py @@ -60,11 +60,14 @@ def start_stt_evaluation_batch( This function runs synchronously during the API request: 1. Initializes GeminiClient 2. Generates signed URLs for audio files (valid for batch processing window) - 3. Builds batch requests using signed URLs directly - 4. Submits batch job via start_batch_job (creates BatchJob record) + 3. Builds batch requests using signed URLs with sample_id as key for tracking + 4. Submits batch job via start_batch_job (creates BatchJob record with request_keys) 5. Links batch_job_id to the evaluation run - 6. Stores sample_file_mapping in run.score - 7. Updates run status to "processing" + 6. Updates run status to "processing" + + Note: Sample IDs are passed as keys in the batch request and stored in + batch_job.config["request_keys"]. This allows direct mapping of results + without storing sample_file_mapping in run.score. 
Args: session: Database session @@ -146,10 +149,14 @@ def start_stt_evaluation_batch( if not sample_url_mapping: raise Exception("Failed to generate signed URLs for any audio files") - # Build batch requests using signed URLs directly (with mime type detection) + # Extract sample IDs as keys for batch request tracking + sample_keys = [str(item["sample_id"]) for item in sample_url_mapping] + + # Build batch requests in Gemini JSONL format (with keys embedded) jsonl_data = create_stt_batch_requests( signed_urls=signed_urls, prompt=DEFAULT_TRANSCRIPTION_PROMPT, + keys=sample_keys, ) # Use first provider (STT evaluations use one provider per run) @@ -197,15 +204,13 @@ def start_stt_evaluation_batch( ) raise Exception(f"Batch submission failed: {str(e)}") - # Link batch job to the evaluation run and store sample mapping + # Link batch job to the evaluation run + # Note: sample_file_mapping is no longer needed as we use batch request keys update_stt_run( session=session, run_id=run.id, status="processing", batch_job_id=batch_job.id, - score={ - "sample_file_mapping": sample_url_mapping, - }, ) logger.info( diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py index a70f1c1e8..226cf1af7 100644 --- a/backend/app/crud/stt_evaluations/cron.py +++ b/backend/app/crud/stt_evaluations/cron.py @@ -210,9 +210,6 @@ async def process_completed_stt_batch( f"run_id: {run.id}, batch_job_id: {batch_job.id}" ) - sample_file_mapping = run.score.get("sample_file_mapping", []) if run.score else [] - sample_ids = [item["sample_id"] for item in sample_file_mapping] - # Get the STT provider from batch job config stt_provider = batch_job.config.get("stt_provider", "gemini-2.5-pro") @@ -229,7 +226,7 @@ async def process_completed_stt_batch( try: # Download results using GeminiBatchProvider - # Use provider_batch_id to download results + # Keys are embedded in the JSONL response file, no separate mapping needed results = batch_provider.download_batch_results(batch_job.provider_batch_id) logger.info( @@ -237,20 +234,20 @@ async def process_completed_stt_batch( f"batch_job_id: {batch_job.id}, result_count: {len(results)}" ) - # Match results to samples by index + # Match results to samples using key (sample_id) from batch request for batch_result in results: custom_id = batch_result["custom_id"] - # custom_id is the index as string + # custom_id is the sample_id as string (set via key in batch request) try: - index = int(custom_id) + sample_id = int(custom_id) except (ValueError, TypeError): - index = results.index(batch_result) - - if index >= len(sample_ids): + logger.warning( + f"[process_completed_stt_batch] Invalid custom_id | " + f"batch_job_id: {batch_job.id}, custom_id: {custom_id}" + ) + failed_count += 1 continue - sample_id = sample_ids[index] - # Find result record for this sample and provider stmt = select(STTResult).where( STTResult.evaluation_run_id == run.id, @@ -281,9 +278,11 @@ async def process_completed_stt_batch( except Exception as e: logger.error( f"[process_completed_stt_batch] Failed to process batch results | " - f"batch_job_id: {batch_job.id}, error: {str(e)}" + f"batch_job_id: {batch_job.id}, error: {str(e)}", + exc_info=True, ) - failed_count += len(sample_file_mapping) + # Mark all pending results as failed since we couldn't process the batch + raise # Update run status status_counts = count_results_by_status(session=session, run_id=run.id) From ff29dddd84cabd679679a2b030b7834302c2044c Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Mon, 2 Feb 2026 
23:20:28 +0530 Subject: [PATCH 14/36] few cleanups --- .../versions/043_add_stt_evaluation_tables.py | 14 + backend/app/core/cloud/storage.py | 8 +- backend/app/tests/core/batch/test_gemini.py | 553 ++++++++++++++++++ 3 files changed, 574 insertions(+), 1 deletion(-) create mode 100644 backend/app/tests/core/batch/test_gemini.py diff --git a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py index 9263d1501..08170e76c 100644 --- a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py +++ b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py @@ -325,8 +325,22 @@ def upgrade(): unique=False, ) + # Add unique constraint on evaluation_dataset (name, organization_id, project_id) + # Prevents duplicate dataset names within the same organization/project scope + op.create_unique_constraint( + "uq_evaluation_dataset_name_org_project", + "evaluation_dataset", + ["name", "organization_id", "project_id"], + ) + def downgrade(): + # Drop unique constraint on evaluation_dataset + op.drop_constraint( + "uq_evaluation_dataset_name_org_project", + "evaluation_dataset", + type_="unique", + ) # Drop stt_result table op.drop_index("idx_stt_result_status", table_name="stt_result") op.drop_index("idx_stt_result_feedback", table_name="stt_result") diff --git a/backend/app/core/cloud/storage.py b/backend/app/core/cloud/storage.py index a3247b74b..727380726 100644 --- a/backend/app/core/cloud/storage.py +++ b/backend/app/core/cloud/storage.py @@ -213,13 +213,19 @@ def get_file_size_kb(self, url: str) -> float: ) raise CloudStorageError(f'AWS Error: "{err}" ({url})') from err + # Maximum allowed expiry for signed URLs (24 hours) + MAX_SIGNED_URL_EXPIRY = 86400 + def get_signed_url(self, url: str, expires_in: int = 3600) -> str: """ Generate a signed S3 URL for the given file. 
:param url: S3 url (e.g., s3://bucket/key) - :param expires_in: Expiry time in seconds (default: 1 hour) + :param expires_in: Expiry time in seconds (default: 1 hour, max: 24 hours) :return: Signed URL as string """ + # Cap expiry at maximum allowed value to prevent excessively long-lived URLs + expires_in = min(expires_in, self.MAX_SIGNED_URL_EXPIRY) + name = SimpleStorageName.from_url(url) try: signed_url = self.aws.client.generate_presigned_url( diff --git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py new file mode 100644 index 000000000..026232ab5 --- /dev/null +++ b/backend/app/tests/core/batch/test_gemini.py @@ -0,0 +1,553 @@ +"""Test cases for GeminiBatchProvider.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from app.core.batch.gemini import ( + BatchJobState, + GeminiBatchProvider, + create_stt_batch_requests, +) + + +class TestGeminiBatchProvider: + """Test cases for GeminiBatchProvider.""" + + @pytest.fixture + def mock_genai_client(self): + """Create a mock Gemini client.""" + return MagicMock() + + @pytest.fixture + def provider(self, mock_genai_client): + """Create a GeminiBatchProvider instance with mock client.""" + return GeminiBatchProvider(client=mock_genai_client) + + @pytest.fixture + def provider_with_model(self, mock_genai_client): + """Create a GeminiBatchProvider with custom model.""" + return GeminiBatchProvider( + client=mock_genai_client, model="models/gemini-2.5-flash" + ) + + def test_initialization_default_model(self, mock_genai_client): + """Test that provider initializes with default model.""" + provider = GeminiBatchProvider(client=mock_genai_client) + assert provider._client == mock_genai_client + assert provider._model == "models/gemini-2.5-pro" + + def test_initialization_custom_model(self, mock_genai_client): + """Test that provider initializes with custom model.""" + provider = GeminiBatchProvider( + client=mock_genai_client, model="models/gemini-2.5-flash" + ) + assert provider._model == "models/gemini-2.5-flash" + + def test_create_batch_success(self, provider, mock_genai_client): + """Test successful batch creation.""" + jsonl_data = [ + {"key": "req-1", "request": {"contents": [{"parts": [{"text": "test"}]}]}}, + {"key": "req-2", "request": {"contents": [{"parts": [{"text": "test2"}]}]}}, + ] + config = { + "display_name": "test-batch", + "model": "models/gemini-2.5-pro", + } + + # Mock file upload + mock_genai_client.files.upload.return_value = MagicMock( + name="files/uploaded-123" + ) + + # Mock batch creation + mock_batch_job = MagicMock() + mock_batch_job.name = "batches/batch-xyz789" + mock_batch_job.state = MagicMock(name="JOB_STATE_PENDING") + mock_genai_client.batches.create.return_value = mock_batch_job + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + result = provider.create_batch(jsonl_data, config) + + assert result["provider_batch_id"] == "batches/batch-xyz789" + assert result["provider_file_id"] == "files/uploaded-123" + assert result["provider_status"] == "JOB_STATE_PENDING" + assert result["total_items"] == 2 + + def test_create_batch_with_default_config(self, provider, mock_genai_client): + """Test batch creation with default configuration values.""" + jsonl_data = [{"key": "req-1", "request": {}}] + config = {} + + mock_genai_client.files.upload.return_value = MagicMock( + name="files/uploaded-456" + ) + + mock_batch_job = MagicMock() + mock_batch_job.name = "batches/batch-123" + mock_batch_job.state = 
MagicMock(name="JOB_STATE_PENDING") + mock_genai_client.batches.create.return_value = mock_batch_job + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + result = provider.create_batch(jsonl_data, config) + + assert result["total_items"] == 1 + mock_genai_client.batches.create.assert_called_once() + + def test_create_batch_file_upload_error(self, provider, mock_genai_client): + """Test handling of file upload error during batch creation.""" + jsonl_data = [{"key": "req-1", "request": {}}] + config = {"display_name": "test"} + + mock_genai_client.files.upload.side_effect = Exception("File upload failed") + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + with pytest.raises(Exception) as exc_info: + provider.create_batch(jsonl_data, config) + + assert "File upload failed" in str(exc_info.value) + + def test_create_batch_batch_creation_error(self, provider, mock_genai_client): + """Test handling of batch creation error.""" + jsonl_data = [{"key": "req-1", "request": {}}] + config = {"display_name": "test"} + + mock_genai_client.files.upload.return_value = MagicMock( + name="files/uploaded-123" + ) + mock_genai_client.batches.create.side_effect = Exception( + "Batch creation failed" + ) + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + with pytest.raises(Exception) as exc_info: + provider.create_batch(jsonl_data, config) + + assert "Batch creation failed" in str(exc_info.value) + + def test_get_batch_status_pending(self, provider, mock_genai_client): + """Test getting status of a pending batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_PENDING") + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + mock_genai_client.batches.get.assert_called_once_with(name=batch_id) + assert result["provider_status"] == "JOB_STATE_PENDING" + assert result["provider_output_file_id"] == batch_id + assert "error_message" not in result + + def test_get_batch_status_succeeded(self, provider, mock_genai_client): + """Test getting status of a succeeded batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_SUCCEEDED" + assert "error_message" not in result + + def test_get_batch_status_failed(self, provider, mock_genai_client): + """Test getting status of a failed batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_FAILED") + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_FAILED" + assert "error_message" in result + assert "Batch JOB_STATE_FAILED" in result["error_message"] + + def test_get_batch_status_cancelled(self, provider, mock_genai_client): + """Test getting status of a cancelled batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_CANCELLED") + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_CANCELLED" + assert "error_message" in result + + def test_get_batch_status_expired(self, 
provider, mock_genai_client): + """Test getting status of an expired batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_EXPIRED") + mock_genai_client.batches.get.return_value = mock_batch_job + + result = provider.get_batch_status(batch_id) + + assert result["provider_status"] == "JOB_STATE_EXPIRED" + assert "error_message" in result + + def test_get_batch_status_error(self, provider, mock_genai_client): + """Test handling of error when retrieving batch status.""" + batch_id = "batches/batch-xyz789" + + mock_genai_client.batches.get.side_effect = Exception("API connection failed") + + with pytest.raises(Exception) as exc_info: + provider.get_batch_status(batch_id) + + assert "API connection failed" in str(exc_info.value) + + def test_download_batch_results_success(self, provider, mock_genai_client): + """Test successful download of batch results.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = ( + '{"key":"req-1","response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}\n' + '{"key":"req-2","response":{"candidates":[{"content":{"parts":[{"text":"World"}]}}]}}' + ) + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + mock_genai_client.batches.get.assert_called_once_with(name=batch_id) + mock_genai_client.files.download.assert_called_once_with( + file="files/output-123" + ) + assert len(results) == 2 + assert results[0]["custom_id"] == "req-1" + assert results[0]["response"]["text"] == "Hello" + assert results[1]["custom_id"] == "req-2" + assert results[1]["response"]["text"] == "World" + + def test_download_batch_results_with_direct_text_response( + self, provider, mock_genai_client + ): + """Test downloading results with direct text in response.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = '{"key":"req-1","response":{"text":"Direct text"}}' + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + assert len(results) == 1 + assert results[0]["response"]["text"] == "Direct text" + + def test_download_batch_results_with_errors(self, provider, mock_genai_client): + """Test downloading batch results that contain errors.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = ( + '{"key":"req-1","response":{"text":"Success"}}\n' + '{"key":"req-2","error":{"message":"Invalid request"}}' + ) + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + assert len(results) == 2 + assert results[0]["custom_id"] == "req-1" + assert results[0]["response"] is not None + assert results[0]["error"] is None + assert 
results[1]["custom_id"] == "req-2" + assert results[1]["response"] is None + assert "Invalid request" in results[1]["error"] + + def test_download_batch_results_batch_not_complete( + self, provider, mock_genai_client + ): + """Test error when trying to download results from incomplete batch.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_RUNNING") + mock_genai_client.batches.get.return_value = mock_batch_job + + with pytest.raises(ValueError) as exc_info: + provider.download_batch_results(batch_id) + + assert "Batch job not complete" in str(exc_info.value) + + def test_download_batch_results_malformed_json(self, provider, mock_genai_client): + """Test handling of malformed JSON in batch results.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.dest = MagicMock() + mock_batch_job.dest.file_name = "files/output-123" + mock_genai_client.batches.get.return_value = mock_batch_job + + jsonl_content = ( + '{"key":"req-1","response":{"text":"Valid"}}\n' + "this is not valid json\n" + '{"key":"req-3","response":{"text":"Also valid"}}' + ) + mock_genai_client.files.download.return_value = jsonl_content.encode("utf-8") + + results = provider.download_batch_results(batch_id) + + # Should skip the malformed line and process the rest + assert len(results) == 2 + assert results[0]["custom_id"] == "req-1" + assert results[1]["custom_id"] == "req-3" + + def test_download_batch_results_no_dest_file(self, provider, mock_genai_client): + """Test handling when batch has no destination file.""" + batch_id = "batches/batch-xyz789" + + mock_batch_job = MagicMock() + mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.dest = None + mock_genai_client.batches.get.return_value = mock_batch_job + + results = provider.download_batch_results(batch_id) + + assert len(results) == 0 + + def test_upload_file_success(self, provider, mock_genai_client): + """Test successful file upload.""" + content = '{"key":"req-1","request":{}}\n{"key":"req-2","request":{}}' + + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-abc123" + mock_genai_client.files.upload.return_value = mock_uploaded_file + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_temp_file = MagicMock() + mock_temp_file.name = "/tmp/test.jsonl" + mock_temp.return_value.__enter__.return_value = mock_temp_file + + with patch("os.unlink"): + file_name = provider.upload_file(content, purpose="batch") + + assert file_name == "files/uploaded-abc123" + mock_genai_client.files.upload.assert_called_once() + + def test_upload_file_error(self, provider, mock_genai_client): + """Test handling of error during file upload.""" + content = '{"key":"req-1"}' + + mock_genai_client.files.upload.side_effect = Exception("Upload quota exceeded") + + with patch("tempfile.NamedTemporaryFile"): + with patch("os.unlink"): + with pytest.raises(Exception) as exc_info: + provider.upload_file(content) + + assert "Upload quota exceeded" in str(exc_info.value) + + def test_download_file_success(self, provider, mock_genai_client): + """Test successful file download.""" + file_id = "files/abc123" + expected_content = '{"key":"req-1","response":{"text":"test"}}' + + mock_genai_client.files.download.return_value = expected_content.encode("utf-8") + + content = provider.download_file(file_id) + + 
mock_genai_client.files.download.assert_called_once_with(file=file_id)
+        assert content == expected_content
+
+    def test_download_file_unicode_content(self, provider, mock_genai_client):
+        """Test downloading file with unicode content."""
+        file_id = "files/abc123"
+        expected_content = '{"text":"Hello 世界 🌍"}'
+
+        mock_genai_client.files.download.return_value = expected_content.encode("utf-8")
+
+        content = provider.download_file(file_id)
+
+        assert content == expected_content
+        assert "世界" in content
+        assert "🌍" in content
+
+    def test_download_file_error(self, provider, mock_genai_client):
+        """Test handling of error during file download."""
+        file_id = "files/abc123"
+
+        mock_genai_client.files.download.side_effect = Exception("File not found")
+
+        with pytest.raises(Exception) as exc_info:
+            provider.download_file(file_id)
+
+        assert "File not found" in str(exc_info.value)
+
+
+class TestBatchJobState:
+    """Test cases for BatchJobState enum."""
+
+    def test_batch_job_states(self):
+        """Test that all batch job states have correct values."""
+        assert BatchJobState.PENDING.value == "JOB_STATE_PENDING"
+        assert BatchJobState.RUNNING.value == "JOB_STATE_RUNNING"
+        assert BatchJobState.SUCCEEDED.value == "JOB_STATE_SUCCEEDED"
+        assert BatchJobState.FAILED.value == "JOB_STATE_FAILED"
+        assert BatchJobState.CANCELLED.value == "JOB_STATE_CANCELLED"
+        assert BatchJobState.EXPIRED.value == "JOB_STATE_EXPIRED"
+
+
+class TestCreateSTTBatchRequests:
+    """Test cases for create_stt_batch_requests function."""
+
+    def test_create_requests_with_keys(self):
+        """Test creating batch requests with custom keys."""
+        signed_urls = [
+            "https://bucket.s3.amazonaws.com/audio1.mp3?signature=abc",
+            "https://bucket.s3.amazonaws.com/audio2.wav?signature=def",
+        ]
+        prompt = "Transcribe this audio file."
+        keys = ["sample-1", "sample-2"]
+
+        requests = create_stt_batch_requests(signed_urls, prompt, keys=keys)
+
+        assert len(requests) == 2
+        assert requests[0]["key"] == "sample-1"
+        assert requests[1]["key"] == "sample-2"
+
+        # Verify structure
+        assert "request" in requests[0]
+        assert "contents" in requests[0]["request"]
+        contents = requests[0]["request"]["contents"]
+        assert len(contents) == 1
+        assert contents[0]["role"] == "user"
+        assert len(contents[0]["parts"]) == 2
+        assert contents[0]["parts"][0]["text"] == prompt
+        assert "file_data" in contents[0]["parts"][1]
+
+    def test_create_requests_without_keys(self):
+        """Test creating batch requests without keys (auto-generated)."""
+        signed_urls = [
+            "https://bucket.s3.amazonaws.com/audio.mp3?signature=xyz",
+        ]
+        prompt = "Transcribe."
+
+        requests = create_stt_batch_requests(signed_urls, prompt)
+
+        assert len(requests) == 1
+        assert requests[0]["key"] == "0"
+
+    def test_create_requests_mime_type_detection(self):
+        """Test that MIME types are correctly detected from URLs."""
+        signed_urls = [
+            "https://bucket.s3.amazonaws.com/audio.mp3?sig=1",
+            "https://bucket.s3.amazonaws.com/audio.wav?sig=2",
+            "https://bucket.s3.amazonaws.com/audio.m4a?sig=3",
+        ]
+        prompt = "Transcribe."
+ + requests = create_stt_batch_requests(signed_urls, prompt) + + assert ( + requests[0]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] + == "audio/mpeg" + ) + assert ( + requests[1]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] + == "audio/x-wav" + ) + assert ( + requests[2]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] + == "audio/mp4" + ) + + def test_create_requests_key_length_mismatch(self): + """Test that mismatched keys and URLs raise error.""" + signed_urls = [ + "https://example.com/audio1.mp3", + "https://example.com/audio2.mp3", + ] + keys = ["only-one-key"] + prompt = "Transcribe." + + with pytest.raises(ValueError) as exc_info: + create_stt_batch_requests(signed_urls, prompt, keys=keys) + + assert "Length of keys" in str(exc_info.value) + + def test_create_requests_file_uri_preserved(self): + """Test that signed URLs are preserved in file_uri.""" + signed_url = "https://bucket.s3.amazonaws.com/audio.mp3?X-Amz-Signature=abc123&X-Amz-Expires=3600" + prompt = "Transcribe." + + requests = create_stt_batch_requests([signed_url], prompt) + + file_uri = requests[0]["request"]["contents"][0]["parts"][1]["file_data"][ + "file_uri" + ] + assert file_uri == signed_url + assert "X-Amz-Signature" in file_uri + + +class TestExtractTextFromResponseDict: + """Test cases for _extract_text_from_response_dict static method.""" + + def test_extract_direct_text(self): + """Test extracting text from direct text field.""" + response = {"text": "Hello world"} + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "Hello world" + + def test_extract_from_candidates(self): + """Test extracting text from candidates structure.""" + response = { + "candidates": [ + { + "content": { + "parts": [ + {"text": "Part 1"}, + {"text": " Part 2"}, + ] + } + } + ] + } + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "Part 1 Part 2" + + def test_extract_empty_response(self): + """Test extracting text from empty response.""" + response = {} + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "" + + def test_extract_multiple_candidates(self): + """Test extracting text from multiple candidates.""" + response = { + "candidates": [ + {"content": {"parts": [{"text": "First"}]}}, + {"content": {"parts": [{"text": "Second"}]}}, + ] + } + text = GeminiBatchProvider._extract_text_from_response_dict(response) + assert text == "FirstSecond" From cd979fdb88cda2ee21a5631ed0b971f1a86dfcc9 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 15:53:07 +0530 Subject: [PATCH 15/36] updated migration --- .../versions/043_add_stt_evaluation_tables.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py index 08170e76c..9263d1501 100644 --- a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py +++ b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py @@ -325,22 +325,8 @@ def upgrade(): unique=False, ) - # Add unique constraint on evaluation_dataset (name, organization_id, project_id) - # Prevents duplicate dataset names within the same organization/project scope - op.create_unique_constraint( - "uq_evaluation_dataset_name_org_project", - "evaluation_dataset", - ["name", "organization_id", "project_id"], - ) - def downgrade(): - # Drop unique constraint on evaluation_dataset - op.drop_constraint( - 
"uq_evaluation_dataset_name_org_project", - "evaluation_dataset", - type_="unique", - ) # Drop stt_result table op.drop_index("idx_stt_result_status", table_name="stt_result") op.drop_index("idx_stt_result_feedback", table_name="stt_result") From b6c633af7657b615a456e968e6c882e39c3ca0ac Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 16:08:56 +0530 Subject: [PATCH 16/36] cleanup config for batch --- backend/app/crud/stt_evaluations/batch.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py index 968e88ed0..405e3aec5 100644 --- a/backend/app/crud/stt_evaluations/batch.py +++ b/backend/app/crud/stt_evaluations/batch.py @@ -176,11 +176,7 @@ def start_stt_evaluation_batch( organization_id=org_id, project_id=project_id, jsonl_data=jsonl_data, - config={ - "display_name": f"stt-eval-{run.id}-{provider}", - "model": model, - "stt_provider": provider, - }, + config={"model": provider}, ) logger.info( From b6e664914afc7593e6ce83177d583a7b33409ea5 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 16:22:11 +0530 Subject: [PATCH 17/36] moved documentation to separate folder --- .../api/docs/stt_evaluation/create_dataset.md | 5 +++++ .../app/api/docs/stt_evaluation/get_dataset.md | 1 + .../app/api/docs/stt_evaluation/get_result.md | 1 + backend/app/api/docs/stt_evaluation/get_run.md | 1 + .../api/docs/stt_evaluation/list_datasets.md | 1 + .../app/api/docs/stt_evaluation/list_runs.md | 1 + .../api/docs/stt_evaluation/start_evaluation.md | 8 ++++++++ .../api/docs/stt_evaluation/update_feedback.md | 5 +++++ .../app/api/docs/stt_evaluation/upload_audio.md | 7 +++++++ .../app/api/routes/stt_evaluations/dataset.py | 14 ++++---------- .../api/routes/stt_evaluations/evaluation.py | 17 ++++------------- backend/app/api/routes/stt_evaluations/files.py | 12 ++---------- .../app/api/routes/stt_evaluations/result.py | 12 +++--------- 13 files changed, 43 insertions(+), 42 deletions(-) create mode 100644 backend/app/api/docs/stt_evaluation/create_dataset.md create mode 100644 backend/app/api/docs/stt_evaluation/get_dataset.md create mode 100644 backend/app/api/docs/stt_evaluation/get_result.md create mode 100644 backend/app/api/docs/stt_evaluation/get_run.md create mode 100644 backend/app/api/docs/stt_evaluation/list_datasets.md create mode 100644 backend/app/api/docs/stt_evaluation/list_runs.md create mode 100644 backend/app/api/docs/stt_evaluation/start_evaluation.md create mode 100644 backend/app/api/docs/stt_evaluation/update_feedback.md create mode 100644 backend/app/api/docs/stt_evaluation/upload_audio.md diff --git a/backend/app/api/docs/stt_evaluation/create_dataset.md b/backend/app/api/docs/stt_evaluation/create_dataset.md new file mode 100644 index 000000000..c17b3a06c --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/create_dataset.md @@ -0,0 +1,5 @@ +Create a new STT evaluation dataset with audio samples. + +Each sample requires: +- **object_store_url**: S3 URL of the audio file (from /evaluations/stt/files/audio endpoint) +- **ground_truth**: Reference transcription (optional, for WER/CER metrics) diff --git a/backend/app/api/docs/stt_evaluation/get_dataset.md b/backend/app/api/docs/stt_evaluation/get_dataset.md new file mode 100644 index 000000000..e965504e7 --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/get_dataset.md @@ -0,0 +1 @@ +Get an STT dataset with its samples. 
diff --git a/backend/app/api/docs/stt_evaluation/get_result.md b/backend/app/api/docs/stt_evaluation/get_result.md new file mode 100644 index 000000000..728710ec5 --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/get_result.md @@ -0,0 +1 @@ +Get a single STT transcription result. diff --git a/backend/app/api/docs/stt_evaluation/get_run.md b/backend/app/api/docs/stt_evaluation/get_run.md new file mode 100644 index 000000000..ce88c37d4 --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/get_run.md @@ -0,0 +1 @@ +Get an STT evaluation run with its results. diff --git a/backend/app/api/docs/stt_evaluation/list_datasets.md b/backend/app/api/docs/stt_evaluation/list_datasets.md new file mode 100644 index 000000000..32a3671ad --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/list_datasets.md @@ -0,0 +1 @@ +List all STT evaluation datasets for the current project. diff --git a/backend/app/api/docs/stt_evaluation/list_runs.md b/backend/app/api/docs/stt_evaluation/list_runs.md new file mode 100644 index 000000000..40ffa1ff2 --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/list_runs.md @@ -0,0 +1 @@ +List all STT evaluation runs for the current project. diff --git a/backend/app/api/docs/stt_evaluation/start_evaluation.md b/backend/app/api/docs/stt_evaluation/start_evaluation.md new file mode 100644 index 000000000..cfc2e03ed --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/start_evaluation.md @@ -0,0 +1,8 @@ +Start an STT evaluation run on a dataset. + +The evaluation will: +1. Process each audio sample through the specified providers +2. Generate transcriptions using Gemini Batch API +3. Store results for human review + +**Supported providers:** gemini-2.5-pro, gemini-2.5-flash, gemini-2.0-flash diff --git a/backend/app/api/docs/stt_evaluation/update_feedback.md b/backend/app/api/docs/stt_evaluation/update_feedback.md new file mode 100644 index 000000000..71a44ce81 --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/update_feedback.md @@ -0,0 +1,5 @@ +Update human feedback on an STT transcription result. + +**Fields:** +- **is_correct**: Boolean indicating if the transcription is correct +- **comment**: Optional feedback comment explaining issues or observations diff --git a/backend/app/api/docs/stt_evaluation/upload_audio.md b/backend/app/api/docs/stt_evaluation/upload_audio.md new file mode 100644 index 000000000..e4fb4429b --- /dev/null +++ b/backend/app/api/docs/stt_evaluation/upload_audio.md @@ -0,0 +1,7 @@ +Upload a single audio file to S3 for STT evaluation. + +**Supported formats:** mp3, wav, flac, m4a, ogg, webm + +**Maximum file size:** 200 MB + +Returns the S3 URL which can be used when creating an STT dataset. diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index f3d048b93..a1c82efb1 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -19,7 +19,7 @@ STTSamplePublic, ) from app.services.stt_evaluations.dataset import upload_stt_dataset -from app.utils import APIResponse +from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -31,13 +31,7 @@ response_model=APIResponse[STTDatasetPublic], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="Create STT dataset", - description=""" -Create a new STT evaluation dataset with audio samples. 
- -Each sample requires: -- **object_store_url**: S3 URL of the audio file (from /evaluations/stt/files/audio endpoint) -- **ground_truth**: Reference transcription (optional, for WER/CER metrics) -""", + description=load_description("stt_evaluation/create_dataset.md"), ) def create_dataset( _session: SessionDep, @@ -78,7 +72,7 @@ def create_dataset( response_model=APIResponse[list[STTDatasetPublic]], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="List STT datasets", - description="List all STT evaluation datasets for the current project.", + description=load_description("stt_evaluation/list_datasets.md"), ) def list_datasets( _session: SessionDep, @@ -106,7 +100,7 @@ def list_datasets( response_model=APIResponse[STTDatasetWithSamples], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="Get STT dataset", - description="Get an STT dataset with its samples.", + description=load_description("stt_evaluation/get_dataset.md"), ) def get_dataset( _session: SessionDep, diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py index 2dcff4ca2..ec9120305 100644 --- a/backend/app/api/routes/stt_evaluations/evaluation.py +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -23,7 +23,7 @@ STTEvaluationRunPublic, STTEvaluationRunWithResults, ) -from app.utils import APIResponse +from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -35,16 +35,7 @@ response_model=APIResponse[STTEvaluationRunPublic], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="Start STT evaluation", - description=""" -Start an STT evaluation run on a dataset. - -The evaluation will: -1. Process each audio sample through the specified providers -2. Generate transcriptions using Gemini Batch API -3. 
Store results for human review - -**Supported providers:** gemini-2.5-pro, gemini-2.5-flash, gemini-2.0-flash -""", + description=load_description("stt_evaluation/start_evaluation.md"), ) def start_stt_evaluation( _session: SessionDep, @@ -177,7 +168,7 @@ def start_stt_evaluation( response_model=APIResponse[list[STTEvaluationRunPublic]], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="List STT evaluation runs", - description="List all STT evaluation runs for the current project.", + description=load_description("stt_evaluation/list_runs.md"), ) def list_stt_evaluation_runs( _session: SessionDep, @@ -209,7 +200,7 @@ def list_stt_evaluation_runs( response_model=APIResponse[STTEvaluationRunWithResults], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="Get STT evaluation run", - description="Get an STT evaluation run with its results.", + description=load_description("stt_evaluation/get_run.md"), ) def get_stt_evaluation_run( _session: SessionDep, diff --git a/backend/app/api/routes/stt_evaluations/files.py b/backend/app/api/routes/stt_evaluations/files.py index 08decb22f..fe37bc397 100644 --- a/backend/app/api/routes/stt_evaluations/files.py +++ b/backend/app/api/routes/stt_evaluations/files.py @@ -8,7 +8,7 @@ from app.api.permissions import Permission, require_permission from app.models.stt_evaluation import AudioUploadResponse from app.services.stt_evaluations.audio import upload_audio_file -from app.utils import APIResponse +from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -20,15 +20,7 @@ response_model=APIResponse[AudioUploadResponse], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="Upload audio file", - description=""" -Upload a single audio file to S3 for STT evaluation. - -**Supported formats:** mp3, wav, flac, m4a, ogg, webm - -**Maximum file size:** 200 MB - -Returns the S3 URL which can be used when creating an STT dataset. -""", + description=load_description("stt_evaluation/upload_audio.md"), ) def upload_audio( _session: SessionDep, diff --git a/backend/app/api/routes/stt_evaluations/result.py b/backend/app/api/routes/stt_evaluations/result.py index 3fd994e25..5ba1d0597 100644 --- a/backend/app/api/routes/stt_evaluations/result.py +++ b/backend/app/api/routes/stt_evaluations/result.py @@ -14,7 +14,7 @@ STTFeedbackUpdate, STTResultPublic, ) -from app.utils import APIResponse +from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -26,13 +26,7 @@ response_model=APIResponse[STTResultPublic], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="Update human feedback", - description=""" -Update human feedback on an STT transcription result. 
- -**Fields:** -- **is_correct**: Boolean indicating if the transcription is correct -- **comment**: Optional feedback comment explaining issues or observations -""", + description=load_description("stt_evaluation/update_feedback.md"), ) def update_result_feedback( _session: SessionDep, @@ -94,7 +88,7 @@ def update_result_feedback( response_model=APIResponse[STTResultPublic], dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], summary="Get STT result", - description="Get a single STT transcription result.", + description=load_description("stt_evaluation/get_result.md"), ) def get_result( _session: SessionDep, From 719584d4730fcfb76530e7d0a6160e0fa5da0104 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 19:28:06 +0530 Subject: [PATCH 18/36] updated score format in stt result --- .../versions/043_add_stt_evaluation_tables.py | 12 ++---- backend/app/models/stt_evaluation.py | 41 ++++--------------- 2 files changed, 11 insertions(+), 42 deletions(-) diff --git a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py index 9263d1501..c547302eb 100644 --- a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py +++ b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py @@ -204,16 +204,10 @@ def upgrade(): comment="Result status: pending, completed, failed", ), sa.Column( - "wer", - sa.Float(), - nullable=True, - comment="Word Error Rate (null for Phase 1)", - ), - sa.Column( - "cer", - sa.Float(), + "score", + postgresql.JSONB(astext_type=sa.Text()), nullable=True, - comment="Character Error Rate (null for Phase 1)", + comment="Evaluation metrics (e.g., wer, cer, mer, wil) - extensible for future metrics", ), sa.Column( "is_correct", diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index a5749f5d3..033f38813 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -34,9 +34,6 @@ class STTResultStatus(str, Enum): FAILED = "failed" -# Database Models - - class STTSample(SQLModel, table=True): """Database table for STT audio samples within a dataset.""" @@ -48,13 +45,11 @@ class STTSample(SQLModel, table=True): sa_column_kwargs={"comment": "Unique identifier for the STT sample"}, ) - # Audio file reference object_store_url: str = SQLField( description="S3 URL of the audio file", sa_column_kwargs={"comment": "S3 URL of the audio file"}, ) - # Language (can be different per sample within a dataset) language: str | None = SQLField( default=None, max_length=10, @@ -62,7 +57,6 @@ class STTSample(SQLModel, table=True): sa_column_kwargs={"comment": "ISO 639-1 language code for this sample"}, ) - # Ground truth transcription (optional, for evaluation) ground_truth: str | None = SQLField( default=None, sa_column=Column( @@ -73,7 +67,6 @@ class STTSample(SQLModel, table=True): description="Reference transcription for comparison", ) - # Audio metadata sample_metadata: dict[str, Any] | None = SQLField( default_factory=dict, sa_column=Column( @@ -84,7 +77,6 @@ class STTSample(SQLModel, table=True): description="Additional metadata about the audio sample", ) - # Foreign keys dataset_id: int = SQLField( foreign_key="evaluation_dataset.id", nullable=False, @@ -104,7 +96,6 @@ class STTSample(SQLModel, table=True): sa_column_kwargs={"comment": "Reference to the project"}, ) - # Timestamps inserted_at: datetime = SQLField( default_factory=now, nullable=False, @@ -116,7 +107,6 @@ class STTSample(SQLModel, table=True): 
sa_column_kwargs={"comment": "Timestamp when the sample was last updated"}, ) - # Relationships dataset: "EvaluationDataset" = Relationship() organization: "Organization" = Relationship() project: "Project" = Relationship() @@ -134,7 +124,6 @@ class STTResult(SQLModel, table=True): sa_column_kwargs={"comment": "Unique identifier for the STT result"}, ) - # Transcription output transcription: str | None = SQLField( default=None, sa_column=Column( @@ -145,14 +134,12 @@ class STTResult(SQLModel, table=True): description="Generated transcription from STT provider", ) - # Provider info provider: str = SQLField( max_length=50, description="STT provider used (e.g., gemini-2.5-pro)", sa_column_kwargs={"comment": "STT provider used (e.g., gemini-2.5-pro)"}, ) - # Status status: str = SQLField( default=STTResultStatus.PENDING.value, max_length=20, @@ -160,19 +147,16 @@ class STTResult(SQLModel, table=True): sa_column_kwargs={"comment": "Result status: pending, completed, failed"}, ) - # Metrics (null for Phase 1) - wer: float | None = SQLField( - default=None, - description="Word Error Rate (null for Phase 1)", - sa_column_kwargs={"comment": "Word Error Rate (null for Phase 1)"}, - ) - cer: float | None = SQLField( + score: dict[str, Any] | None = SQLField( default=None, - description="Character Error Rate (null for Phase 1)", - sa_column_kwargs={"comment": "Character Error Rate (null for Phase 1)"}, + sa_column=Column( + JSONB, + nullable=True, + comment="Evaluation metrics (e.g., wer, cer, mer, wil) - extensible for future metrics", + ), + description="Evaluation metrics such as WER, CER, etc.", ) - # Human feedback is_correct: bool | None = SQLField( default=None, description="Human feedback: transcription correctness", @@ -190,7 +174,6 @@ class STTResult(SQLModel, table=True): description="Human feedback comment", ) - # Provider response metadata provider_metadata: dict[str, Any] | None = SQLField( default_factory=dict, sa_column=Column( @@ -201,7 +184,6 @@ class STTResult(SQLModel, table=True): description="Provider-specific response metadata", ) - # Error message if failed error_message: str | None = SQLField( default=None, sa_column=Column( @@ -212,7 +194,6 @@ class STTResult(SQLModel, table=True): description="Error message if transcription failed", ) - # Foreign keys stt_sample_id: int = SQLField( foreign_key="stt_sample.id", nullable=False, @@ -238,7 +219,6 @@ class STTResult(SQLModel, table=True): sa_column_kwargs={"comment": "Reference to the project"}, ) - # Timestamps inserted_at: datetime = SQLField( default_factory=now, nullable=False, @@ -250,16 +230,12 @@ class STTResult(SQLModel, table=True): sa_column_kwargs={"comment": "Timestamp when the result was last updated"}, ) - # Relationships sample: "STTSample" = Relationship(back_populates="results") evaluation_run: "EvaluationRun" = Relationship() organization: "Organization" = Relationship() project: "Project" = Relationship() -# Pydantic Models for API - - class STTSampleCreate(BaseModel): """Request model for creating an STT sample.""" @@ -291,8 +267,7 @@ class STTResultPublic(BaseModel): transcription: str | None provider: str status: str - wer: float | None - cer: float | None + score: dict[str, Any] | None is_correct: bool | None comment: str | None provider_metadata: dict[str, Any] | None From bf0b4c2f8568f423ffee10dd69829ebcff401cae Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 19:42:18 +0530 Subject: [PATCH 19/36] cleaner dataset sample count --- .../app/api/routes/stt_evaluations/dataset.py | 5 +- 
.../api/routes/stt_evaluations/evaluation.py | 6 +- backend/app/crud/stt_evaluations/__init__.py | 2 - backend/app/crud/stt_evaluations/dataset.py | 113 ++++++------------ .../app/services/stt_evaluations/dataset.py | 5 +- 5 files changed, 41 insertions(+), 90 deletions(-) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index a1c82efb1..6af50d2ea 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -10,7 +10,6 @@ get_stt_dataset_by_id, list_stt_datasets, get_samples_by_dataset_id, - get_sample_count_for_dataset, ) from app.models.stt_evaluation import ( STTDatasetCreate, @@ -150,9 +149,7 @@ def get_dataset( for s in sample_records ] else: - samples_total = get_sample_count_for_dataset( - session=_session, dataset_id=dataset_id - ) + samples_total = (dataset.dataset_metadata or {}).get("sample_count", 0) return APIResponse.success_response( data=STTDatasetWithSamples( diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py index ec9120305..5b54d0082 100644 --- a/backend/app/api/routes/stt_evaluations/evaluation.py +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -13,7 +13,6 @@ get_stt_dataset_by_id, get_stt_run_by_id, list_stt_runs, - get_sample_count_for_dataset, start_stt_evaluation_batch, update_stt_run, ) @@ -60,10 +59,7 @@ def start_stt_evaluation( if not dataset: raise HTTPException(status_code=404, detail="Dataset not found") - # Get sample count for total items - sample_count = get_sample_count_for_dataset( - session=_session, dataset_id=run_create.dataset_id - ) + sample_count = (dataset.dataset_metadata or {}).get("sample_count", 0) if sample_count == 0: raise HTTPException(status_code=400, detail="Dataset has no samples") diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py index 196b8aec2..2202ca02a 100644 --- a/backend/app/crud/stt_evaluations/__init__.py +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -8,7 +8,6 @@ get_stt_dataset_by_id, list_stt_datasets, get_samples_by_dataset_id, - get_sample_count_for_dataset, ) from .run import ( create_stt_run, @@ -36,7 +35,6 @@ "get_stt_dataset_by_id", "list_stt_datasets", "get_samples_by_dataset_id", - "get_sample_count_for_dataset", # Run "create_stt_run", "get_stt_run_by_id", diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index 3014dcc0b..bf8997d4f 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -123,59 +123,47 @@ def _extract_metadata_from_url(url: str) -> dict[str, Any]: def create_stt_samples( *, session: Session, - dataset_id: int, - org_id: int, - project_id: int, + dataset: EvaluationDataset, samples: list[STTSampleCreate], - language: str | None = None, ) -> list[STTSample]: """Create STT samples for a dataset. 
Args: session: Database session - dataset_id: Parent dataset ID - org_id: Organization ID - project_id: Project ID + dataset: Parent dataset (must have sample_count in dataset_metadata) samples: List of sample data - language: Language code from parent dataset Returns: list[STTSample]: Created samples """ logger.info( f"[create_stt_samples] Creating STT samples | " - f"dataset_id: {dataset_id}, sample_count: {len(samples)}, language: {language}" + f"dataset_id: {dataset.id}, sample_count: {len(samples)}" ) - created_samples = [] - - for sample_data in samples: - # Extract metadata from URL - sample_metadata = _extract_metadata_from_url(sample_data.object_store_url) - - sample = STTSample( + timestamp = now() + created_samples = [ + STTSample( object_store_url=sample_data.object_store_url, ground_truth=sample_data.ground_truth, - language=language, - sample_metadata=sample_metadata, - dataset_id=dataset_id, - organization_id=org_id, - project_id=project_id, - inserted_at=now(), - updated_at=now(), + language=dataset.language, + sample_metadata=_extract_metadata_from_url(sample_data.object_store_url), + dataset_id=dataset.id, + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=timestamp, + updated_at=timestamp, ) - session.add(sample) - created_samples.append(sample) + for sample_data in samples + ] + session.add_all(created_samples) + session.flush() session.commit() - # Refresh all samples to get IDs - for sample in created_samples: - session.refresh(sample) - logger.info( f"[create_stt_samples] STT samples created | " - f"dataset_id: {dataset_id}, created_count: {len(created_samples)}" + f"dataset_id: {dataset.id}, created_count: {len(created_samples)}" ) return created_samples @@ -229,22 +217,18 @@ def list_stt_datasets( Returns: tuple[list[STTDatasetPublic], int]: Datasets and total count """ - # Get total count - count_stmt = select(func.count(EvaluationDataset.id)).where( + base_filter = ( EvaluationDataset.organization_id == org_id, EvaluationDataset.project_id == project_id, EvaluationDataset.type == EvaluationType.STT.value, ) + + count_stmt = select(func.count(EvaluationDataset.id)).where(*base_filter) total = session.exec(count_stmt).one() - # Get datasets statement = ( select(EvaluationDataset) - .where( - EvaluationDataset.organization_id == org_id, - EvaluationDataset.project_id == project_id, - EvaluationDataset.type == EvaluationType.STT.value, - ) + .where(*base_filter) .order_by(EvaluationDataset.inserted_at.desc()) .offset(offset) .limit(limit) @@ -252,48 +236,27 @@ def list_stt_datasets( datasets = session.exec(statement).all() - # Convert to public models with sample counts - result = [] - for dataset in datasets: - sample_count = get_sample_count_for_dataset( - session=session, dataset_id=dataset.id - ) - result.append( - STTDatasetPublic( - id=dataset.id, - name=dataset.name, - description=dataset.description, - type=dataset.type, - language=dataset.language, - object_store_url=dataset.object_store_url, - dataset_metadata=dataset.dataset_metadata, - sample_count=sample_count, - organization_id=dataset.organization_id, - project_id=dataset.project_id, - inserted_at=dataset.inserted_at, - updated_at=dataset.updated_at, - ) + result = [ + STTDatasetPublic( + id=dataset.id, + name=dataset.name, + description=dataset.description, + type=dataset.type, + language=dataset.language, + object_store_url=dataset.object_store_url, + dataset_metadata=dataset.dataset_metadata, + sample_count=(dataset.dataset_metadata or {}).get("sample_count", 
0), + organization_id=dataset.organization_id, + project_id=dataset.project_id, + inserted_at=dataset.inserted_at, + updated_at=dataset.updated_at, ) + for dataset in datasets + ] return result, total -def get_sample_count_for_dataset(*, session: Session, dataset_id: int) -> int: - """Get the number of samples in a dataset. - - Args: - session: Database session - dataset_id: Dataset ID - - Returns: - int: Sample count - """ - statement = select(func.count(STTSample.id)).where( - STTSample.dataset_id == dataset_id - ) - return session.exec(statement).one() - - def get_samples_by_dataset_id( *, session: Session, diff --git a/backend/app/services/stt_evaluations/dataset.py b/backend/app/services/stt_evaluations/dataset.py index 5fda1ec63..e060f1222 100644 --- a/backend/app/services/stt_evaluations/dataset.py +++ b/backend/app/services/stt_evaluations/dataset.py @@ -91,11 +91,8 @@ def upload_stt_dataset( # Step 4: Create sample records created_samples = create_stt_samples( session=session, - dataset_id=dataset.id, - org_id=organization_id, - project_id=project_id, + dataset=dataset, samples=samples, - language=language, ) logger.info( From 68e682140bf5f45c7818fa540483e6410b44cc89 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 19:46:39 +0530 Subject: [PATCH 20/36] got rid of redundant sample count --- backend/app/api/routes/stt_evaluations/dataset.py | 5 +---- backend/app/crud/stt_evaluations/dataset.py | 1 - backend/app/models/stt_evaluation.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index 6af50d2ea..791022a63 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -121,7 +121,7 @@ def get_dataset( raise HTTPException(status_code=404, detail="Dataset not found") samples = [] - samples_total = 0 + samples_total = (dataset.dataset_metadata or {}).get("sample_count", 0) if include_samples: sample_records, samples_total = get_samples_by_dataset_id( @@ -148,8 +148,6 @@ def get_dataset( ) for s in sample_records ] - else: - samples_total = (dataset.dataset_metadata or {}).get("sample_count", 0) return APIResponse.success_response( data=STTDatasetWithSamples( @@ -160,7 +158,6 @@ def get_dataset( language=dataset.language, object_store_url=dataset.object_store_url, dataset_metadata=dataset.dataset_metadata, - sample_count=samples_total, organization_id=dataset.organization_id, project_id=dataset.project_id, inserted_at=dataset.inserted_at, diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index bf8997d4f..871f4d204 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -245,7 +245,6 @@ def list_stt_datasets( language=dataset.language, object_store_url=dataset.object_store_url, dataset_metadata=dataset.dataset_metadata, - sample_count=(dataset.dataset_metadata or {}).get("sample_count", 0), organization_id=dataset.organization_id, project_id=dataset.project_id, inserted_at=dataset.inserted_at, diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 033f38813..c412849bc 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -314,7 +314,6 @@ class STTDatasetPublic(BaseModel): language: str | None object_store_url: str | None dataset_metadata: dict[str, Any] - sample_count: int = Field(0, description="Number of 
samples in the dataset") organization_id: int project_id: int inserted_at: datetime From 2247faac63d7a5f4ab1341911251eef99902d89c Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 19:49:36 +0530 Subject: [PATCH 21/36] removed deadcode --- backend/app/crud/stt_evaluations/dataset.py | 35 --------------------- 1 file changed, 35 deletions(-) diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index 871f4d204..fec80b198 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ b/backend/app/crud/stt_evaluations/dataset.py @@ -302,38 +302,3 @@ def get_samples_by_dataset_id( samples = session.exec(statement).all() return list(samples), total - - -def update_dataset_metadata( - *, - session: Session, - dataset_id: int, - metadata: dict[str, Any], -) -> EvaluationDataset | None: - """Update dataset metadata. - - Args: - session: Database session - dataset_id: Dataset ID - metadata: Metadata to merge - - Returns: - EvaluationDataset | None: Updated dataset - """ - statement = select(EvaluationDataset).where(EvaluationDataset.id == dataset_id) - dataset = session.exec(statement).one_or_none() - - if not dataset: - return None - - # Merge metadata - current_metadata = dataset.dataset_metadata or {} - current_metadata.update(metadata) - dataset.dataset_metadata = current_metadata - dataset.updated_at = now() - - session.add(dataset) - session.commit() - session.refresh(dataset) - - return dataset From 056612c02f4071e192c3a0a66625d5cafa9a61e6 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 19:53:47 +0530 Subject: [PATCH 22/36] removing more redundant code --- backend/app/crud/stt_evaluations/result.py | 96 +++++----------------- 1 file changed, 22 insertions(+), 74 deletions(-) diff --git a/backend/app/crud/stt_evaluations/result.py b/backend/app/crud/stt_evaluations/result.py index 4c3fd8a9a..f888740c3 100644 --- a/backend/app/crud/stt_evaluations/result.py +++ b/backend/app/crud/stt_evaluations/result.py @@ -10,7 +10,6 @@ from app.models.stt_evaluation import ( STTResult, STTResultStatus, - STTResultPublic, STTSample, STTSamplePublic, STTResultWithSample, @@ -19,48 +18,6 @@ logger = logging.getLogger(__name__) -def create_stt_result( - *, - session: Session, - stt_sample_id: int, - evaluation_run_id: int, - org_id: int, - project_id: int, - provider: str, - status: str = STTResultStatus.PENDING.value, -) -> STTResult: - """Create a single STT result record. 
- - Args: - session: Database session - stt_sample_id: Sample ID - evaluation_run_id: Run ID - org_id: Organization ID - project_id: Project ID - provider: Provider name - status: Initial status - - Returns: - STTResult: Created result - """ - result = STTResult( - stt_sample_id=stt_sample_id, - evaluation_run_id=evaluation_run_id, - organization_id=org_id, - project_id=project_id, - provider=provider, - status=status, - inserted_at=now(), - updated_at=now(), - ) - - session.add(result) - session.commit() - session.refresh(result) - - return result - - def create_stt_results( *, session: Session, @@ -91,29 +48,26 @@ def create_stt_results( f"provider_count: {len(providers)}" ) - results = [] - - for sample in samples: - for provider in providers: - result = STTResult( - stt_sample_id=sample.id, - evaluation_run_id=evaluation_run_id, - organization_id=org_id, - project_id=project_id, - provider=provider, - status=STTResultStatus.PENDING.value, - inserted_at=now(), - updated_at=now(), - ) - session.add(result) - results.append(result) + timestamp = now() + results = [ + STTResult( + stt_sample_id=sample.id, + evaluation_run_id=evaluation_run_id, + organization_id=org_id, + project_id=project_id, + provider=provider, + status=STTResultStatus.PENDING.value, + inserted_at=timestamp, + updated_at=timestamp, + ) + for sample in samples + for provider in providers + ] + session.add_all(results) + session.flush() session.commit() - # Refresh to get IDs - for result in results: - session.refresh(result) - logger.info( f"[create_stt_results] STT results created | " f"run_id: {evaluation_run_id}, result_count: {len(results)}" @@ -225,8 +179,7 @@ def get_results_by_run_id( transcription=result.transcription, provider=result.provider, status=result.status, - wer=result.wer, - cer=result.cer, + score=result.score, is_correct=result.is_correct, comment=result.comment, provider_metadata=result.provider_metadata, @@ -250,8 +203,7 @@ def update_stt_result( result_id: int, transcription: str | None = None, status: str | None = None, - wer: float | None = None, - cer: float | None = None, + score: dict[str, Any] | None = None, provider_metadata: dict[str, Any] | None = None, error_message: str | None = None, ) -> STTResult | None: @@ -262,8 +214,7 @@ def update_stt_result( result_id: Result ID transcription: Generated transcription status: New status - wer: Word Error Rate - cer: Character Error Rate + score: Evaluation metrics (e.g., wer, cer) provider_metadata: Provider response metadata error_message: Error message if failed @@ -282,11 +233,8 @@ def update_stt_result( if status is not None: result.status = status - if wer is not None: - result.wer = wer - - if cer is not None: - result.cer = cer + if score is not None: + result.score = score if provider_metadata is not None: result.provider_metadata = provider_metadata From 13bb9ccb7189f823499677cd2eac09e91f4640e9 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 20:00:15 +0530 Subject: [PATCH 23/36] clean few more cruds --- backend/app/crud/stt_evaluations/result.py | 28 ++++++++---------- backend/app/crud/stt_evaluations/run.py | 33 ++++++++-------------- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/backend/app/crud/stt_evaluations/result.py b/backend/app/crud/stt_evaluations/result.py index f888740c3..0c8d7cae9 100644 --- a/backend/app/crud/stt_evaluations/result.py +++ b/backend/app/crud/stt_evaluations/result.py @@ -129,7 +129,6 @@ def get_results_by_run_id( Returns: tuple[list[STTResultWithSample], int]: Results with 
samples and total count """ - # Build where clause where_clauses = [ STTResult.evaluation_run_id == run_id, STTResult.organization_id == org_id, @@ -142,11 +141,9 @@ def get_results_by_run_id( if status is not None: where_clauses.append(STTResult.status == status) - # Get total count count_stmt = select(func.count(STTResult.id)).where(*where_clauses) total = session.exec(count_stmt).one() - # Get results with samples statement = ( select(STTResult, STTSample) .join(STTSample, STTResult.stt_sample_id == STTSample.id) @@ -227,20 +224,17 @@ def update_stt_result( if not result: return None - if transcription is not None: - result.transcription = transcription - - if status is not None: - result.status = status - - if score is not None: - result.score = score - - if provider_metadata is not None: - result.provider_metadata = provider_metadata - - if error_message is not None: - result.error_message = error_message + updates = { + "transcription": transcription, + "status": status, + "score": score, + "provider_metadata": provider_metadata, + "error_message": error_message, + } + + for field, value in updates.items(): + if value is not None: + setattr(result, field, value) result.updated_at = now() diff --git a/backend/app/crud/stt_evaluations/run.py b/backend/app/crud/stt_evaluations/run.py index 981c99a4f..9d6c6154f 100644 --- a/backend/app/crud/stt_evaluations/run.py +++ b/backend/app/crud/stt_evaluations/run.py @@ -129,7 +129,6 @@ def list_stt_runs( Returns: tuple[list[STTEvaluationRunPublic], int]: Runs and total count """ - # Build base where clause where_clauses = [ EvaluationRun.organization_id == org_id, EvaluationRun.project_id == project_id, @@ -142,11 +141,9 @@ def list_stt_runs( if status is not None: where_clauses.append(EvaluationRun.status == status) - # Get total count count_stmt = select(func.count(EvaluationRun.id)).where(*where_clauses) total = session.exec(count_stmt).one() - # Get runs statement = ( select(EvaluationRun) .where(*where_clauses) @@ -157,7 +154,6 @@ def list_stt_runs( runs = session.exec(statement).all() - # Convert to public models result = [ STTEvaluationRunPublic( id=run.id, @@ -215,23 +211,18 @@ def update_stt_run( if not run: return None - if status is not None: - run.status = status - - if processed_samples is not None: - run.processed_samples = processed_samples - - if score is not None: - run.score = score - - if error_message is not None: - run.error_message = error_message - - if object_store_url is not None: - run.object_store_url = object_store_url - - if batch_job_id is not None: - run.batch_job_id = batch_job_id + updates = { + "status": status, + "processed_samples": processed_samples, + "score": score, + "error_message": error_message, + "object_store_url": object_store_url, + "batch_job_id": batch_job_id, + } + + for field, value in updates.items(): + if value is not None: + setattr(run, field, value) run.updated_at = now() From 7bbf811de0cb92f4c1934681d5593b6ac7ac7cd3 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 20:02:12 +0530 Subject: [PATCH 24/36] more free from dead code --- backend/app/crud/stt_evaluations/run.py | 32 ------------------------- 1 file changed, 32 deletions(-) diff --git a/backend/app/crud/stt_evaluations/run.py b/backend/app/crud/stt_evaluations/run.py index 9d6c6154f..d80f6f20c 100644 --- a/backend/app/crud/stt_evaluations/run.py +++ b/backend/app/crud/stt_evaluations/run.py @@ -239,38 +239,6 @@ def update_stt_run( return run -def increment_processed_samples( - *, - session: Session, - run_id: int, - 
increment: int = 1, -) -> EvaluationRun | None: - """Increment the processed_samples counter for a run. - - Args: - session: Database session - run_id: Run ID - increment: Amount to increment by - - Returns: - EvaluationRun | None: Updated run - """ - statement = select(EvaluationRun).where(EvaluationRun.id == run_id) - run = session.exec(statement).one_or_none() - - if not run: - return None - - run.processed_samples = (run.processed_samples or 0) + increment - run.updated_at = now() - - session.add(run) - session.commit() - session.refresh(run) - - return run - - def get_pending_stt_runs( *, session: Session, From 04e419caa3d89bded791aaac674abb3b018bf8a5 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 20:14:57 +0530 Subject: [PATCH 25/36] cleanup batch request code --- backend/app/crud/stt_evaluations/batch.py | 34 ++++------------------- 1 file changed, 5 insertions(+), 29 deletions(-) diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py index 405e3aec5..b84a78066 100644 --- a/backend/app/crud/stt_evaluations/batch.py +++ b/backend/app/crud/stt_evaluations/batch.py @@ -100,36 +100,16 @@ def start_stt_evaluation_batch( storage = get_cloud_storage(session=session, project_id=project_id) # Generate signed URLs for audio files - sample_url_mapping: list[dict[str, Any]] = [] signed_urls: list[str] = [] + sample_keys: list[str] = [] for sample in samples: try: - # Get signed URL for S3 audio file - # Use longer expiry for batch processing (up to 24 hours) signed_url = storage.get_signed_url( sample.object_store_url, expires_in=signed_url_expires_in ) - - # Find the result record for this sample - result_for_sample = next( - (r for r in result_refs if r["stt_sample_id"] == sample.id), - None, - ) - - sample_url_mapping.append( - { - "sample_id": sample.id, - "result_id": result_for_sample["id"] if result_for_sample else None, - "signed_url": signed_url, - } - ) signed_urls.append(signed_url) - - logger.info( - f"[start_stt_evaluation_batch] Generated signed URL | " - f"sample_id: {sample.id}" - ) + sample_keys.append(str(sample.id)) except Exception as e: logger.error( @@ -146,12 +126,9 @@ def start_stt_evaluation_batch( error_message=f"Failed to generate signed URL: {str(e)}", ) - if not sample_url_mapping: + if not signed_urls: raise Exception("Failed to generate signed URLs for any audio files") - # Extract sample IDs as keys for batch request tracking - sample_keys = [str(item["sample_id"]) for item in sample_url_mapping] - # Build batch requests in Gemini JSONL format (with keys embedded) jsonl_data = create_stt_batch_requests( signed_urls=signed_urls, @@ -201,7 +178,6 @@ def start_stt_evaluation_batch( raise Exception(f"Batch submission failed: {str(e)}") # Link batch job to the evaluation run - # Note: sample_file_mapping is no longer needed as we use batch request keys update_stt_run( session=session, run_id=run.id, @@ -212,7 +188,7 @@ def start_stt_evaluation_batch( logger.info( f"[start_stt_evaluation_batch] Batch submission complete | " f"run_id: {run.id}, batch_job_id: {batch_job.id}, " - f"sample_count: {len(sample_url_mapping)}" + f"sample_count: {len(signed_urls)}" ) return { @@ -220,5 +196,5 @@ def start_stt_evaluation_batch( "run_id": run.id, "batch_job_id": batch_job.id, "provider_batch_id": batch_job.provider_batch_id, - "sample_count": len(sample_url_mapping), + "sample_count": len(signed_urls), } From 09deab2b0041ec41501e42b33bfbbf8e3d93a19e Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 
20:29:33 +0530 Subject: [PATCH 26/36] cleanup batch --- backend/app/crud/stt_evaluations/batch.py | 40 +++++------------------ 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py index b84a78066..b434cb1c4 100644 --- a/backend/app/crud/stt_evaluations/batch.py +++ b/backend/app/crud/stt_evaluations/batch.py @@ -19,30 +19,12 @@ logger = logging.getLogger(__name__) -# Default transcription prompt DEFAULT_TRANSCRIPTION_PROMPT = ( "Generate a verbatim transcript of the speech in this audio file. " "Return only the transcription text without any formatting, timestamps, or metadata." ) -# Provider name to Gemini model mapping -PROVIDER_MODEL_MAPPING: dict[str, str] = { - "gemini-2.5-pro": "models/gemini-2.5-pro", - "gemini-2.5-flash": "models/gemini-2.5-flash", - "gemini-2.0-flash": "models/gemini-2.0-flash", -} - - -def _get_model_for_provider(provider: str) -> str: - """Map provider name to Gemini model. - - Args: - provider: Provider name - - Returns: - str: Gemini model name - """ - return PROVIDER_MODEL_MAPPING.get(provider, f"models/{provider}") +DEFAULT_MODEL = "gemini-2.5-pro" def start_stt_evaluation_batch( @@ -66,8 +48,7 @@ def start_stt_evaluation_batch( 6. Updates run status to "processing" Note: Sample IDs are passed as keys in the batch request and stored in - batch_job.config["request_keys"]. This allows direct mapping of results - without storing sample_file_mapping in run.score. + batch_job.config["request_keys"]. Args: session: Database session @@ -129,20 +110,16 @@ def start_stt_evaluation_batch( if not signed_urls: raise Exception("Failed to generate signed URLs for any audio files") - # Build batch requests in Gemini JSONL format (with keys embedded) jsonl_data = create_stt_batch_requests( signed_urls=signed_urls, prompt=DEFAULT_TRANSCRIPTION_PROMPT, keys=sample_keys, ) - # Use first provider (STT evaluations use one provider per run) - providers = run.providers or ["gemini-2.5-pro"] - provider = providers[0] - model = _get_model_for_provider(provider) + model = (run.providers or [DEFAULT_MODEL])[0] + model_path = f"models/{model}" - # Create batch job using the standard batch operations - batch_provider = GeminiBatchProvider(client=gemini_client.client, model=model) + batch_provider = GeminiBatchProvider(client=gemini_client.client, model=model_path) try: batch_job = start_batch_job( @@ -153,19 +130,18 @@ def start_stt_evaluation_batch( organization_id=org_id, project_id=project_id, jsonl_data=jsonl_data, - config={"model": provider}, + config={"model": model}, ) logger.info( f"[start_stt_evaluation_batch] Batch job created | " - f"run_id: {run.id}, batch_job_id: {batch_job.id}, " - f"provider_batch_id: {batch_job.provider_batch_id}" + f"run_id: {run.id}, batch_job_id: {batch_job.id}" ) except Exception as e: logger.error( f"[start_stt_evaluation_batch] Failed to submit batch | " - f"provider: {provider}, error: {str(e)}" + f"model: {model}, error: {str(e)}" ) # Update all results as failed for ref in result_refs: From f6bf0c213040d86edcd171bdc2b53efe9b8310eb Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 21:47:45 +0530 Subject: [PATCH 27/36] got rid of processed_samples as well --- .../versions/043_add_stt_evaluation_tables.py | 14 +------------- .../app/api/routes/stt_evaluations/evaluation.py | 2 -- backend/app/crud/stt_evaluations/cron.py | 1 - backend/app/crud/stt_evaluations/run.py | 9 +-------- backend/app/models/evaluation.py | 5 ----- 
backend/app/models/stt_evaluation.py | 1 - 6 files changed, 2 insertions(+), 30 deletions(-) diff --git a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py index c547302eb..9cff5e759 100644 --- a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py +++ b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py @@ -40,7 +40,7 @@ def upgrade(): ), ) - # Add type, language, providers, and processed_samples columns to evaluation_run table + # Add type, language, and providers columns to evaluation_run table op.add_column( "evaluation_run", sa.Column( @@ -69,17 +69,6 @@ def upgrade(): comment="List of STT/TTS providers used (e.g., ['gemini-2.5-pro'])", ), ) - op.add_column( - "evaluation_run", - sa.Column( - "processed_samples", - sa.Integer(), - nullable=False, - server_default=sa.text("0"), - comment="Number of samples processed so far", - ), - ) - # Create stt_sample table op.create_table( "stt_sample", @@ -334,7 +323,6 @@ def downgrade(): op.drop_table("stt_sample") # Remove columns from evaluation_run table - op.drop_column("evaluation_run", "processed_samples") op.drop_column("evaluation_run", "providers") op.drop_column("evaluation_run", "language") op.drop_column("evaluation_run", "type") diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py index 5b54d0082..f04b0dfa1 100644 --- a/backend/app/api/routes/stt_evaluations/evaluation.py +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -148,7 +148,6 @@ def start_stt_evaluation( dataset_id=run.dataset_id, status=run.status, total_items=run.total_items, - processed_samples=run.processed_samples, score=run.score, error_message=run.error_message, organization_id=run.organization_id, @@ -245,7 +244,6 @@ def get_stt_evaluation_run( dataset_id=run.dataset_id, status=run.status, total_items=run.total_items, - processed_samples=run.processed_samples, score=run.score, error_message=run.error_message, organization_id=run.organization_id, diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py index 226cf1af7..923fcb322 100644 --- a/backend/app/crud/stt_evaluations/cron.py +++ b/backend/app/crud/stt_evaluations/cron.py @@ -300,7 +300,6 @@ async def process_completed_stt_batch( session=session, run_id=run.id, status=final_status, - processed_samples=completed + failed, error_message=error_message, ) diff --git a/backend/app/crud/stt_evaluations/run.py b/backend/app/crud/stt_evaluations/run.py index d80f6f20c..0d3ef4fbb 100644 --- a/backend/app/crud/stt_evaluations/run.py +++ b/backend/app/crud/stt_evaluations/run.py @@ -58,7 +58,6 @@ def create_stt_run( providers=providers, status="pending", total_items=total_items, - processed_samples=0, organization_id=org_id, project_id=project_id, inserted_at=now(), @@ -165,7 +164,6 @@ def list_stt_runs( dataset_id=run.dataset_id, status=run.status, total_items=run.total_items, - processed_samples=run.processed_samples, score=run.score, error_message=run.error_message, organization_id=run.organization_id, @@ -184,7 +182,6 @@ def update_stt_run( session: Session, run_id: int, status: str | None = None, - processed_samples: int | None = None, score: dict[str, Any] | None = None, error_message: str | None = None, object_store_url: str | None = None, @@ -196,7 +193,6 @@ def update_stt_run( session: Database session run_id: Run ID status: New status - processed_samples: Number of processed samples score: Score data 
error_message: Error message object_store_url: URL to stored results @@ -213,7 +209,6 @@ def update_stt_run( updates = { "status": status, - "processed_samples": processed_samples, "score": score, "error_message": error_message, "object_store_url": object_store_url, @@ -231,9 +226,7 @@ def update_stt_run( session.refresh(run) logger.info( - f"[update_stt_run] STT run updated | " - f"run_id: {run_id}, status: {run.status}, " - f"processed_samples: {run.processed_samples}" + f"[update_stt_run] STT run updated | run_id: {run_id}, status: {run.status}" ) return run diff --git a/backend/app/models/evaluation.py b/backend/app/models/evaluation.py index 8b24fe9e4..1397fa435 100644 --- a/backend/app/models/evaluation.py +++ b/backend/app/models/evaluation.py @@ -226,11 +226,6 @@ class EvaluationRun(SQLModel, table=True): ), description="List of STT/TTS providers used", ) - processed_samples: int = SQLField( - default=0, - description="Number of samples processed so far", - sa_column_kwargs={"comment": "Number of samples processed so far"}, - ) config_id: UUID = SQLField( foreign_key="config.id", diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index c412849bc..a46c2c9a5 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -350,7 +350,6 @@ class STTEvaluationRunPublic(BaseModel): dataset_id: int status: str total_items: int - processed_samples: int score: dict[str, Any] | None error_message: str | None organization_id: int From d20084b5921273ea381ad8819ceafba35bc65a32 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Tue, 3 Feb 2026 21:58:47 +0530 Subject: [PATCH 28/36] cleanup provider_metadata from results --- .../alembic/versions/043_add_stt_evaluation_tables.py | 7 ------- backend/app/api/routes/stt_evaluations/result.py | 8 ++------ backend/app/crud/stt_evaluations/result.py | 4 ---- backend/app/models/stt_evaluation.py | 11 ----------- 4 files changed, 2 insertions(+), 28 deletions(-) diff --git a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py index 9cff5e759..feb96080b 100644 --- a/backend/app/alembic/versions/043_add_stt_evaluation_tables.py +++ b/backend/app/alembic/versions/043_add_stt_evaluation_tables.py @@ -210,13 +210,6 @@ def upgrade(): nullable=True, comment="Human feedback comment", ), - sa.Column( - "provider_metadata", - postgresql.JSONB(astext_type=sa.Text()), - nullable=True, - server_default=sa.text("'{}'::jsonb"), - comment="Provider-specific response metadata (tokens, latency, etc.)", - ), sa.Column( "error_message", sa.Text(), diff --git a/backend/app/api/routes/stt_evaluations/result.py b/backend/app/api/routes/stt_evaluations/result.py index 5ba1d0597..b2bc48bb0 100644 --- a/backend/app/api/routes/stt_evaluations/result.py +++ b/backend/app/api/routes/stt_evaluations/result.py @@ -67,11 +67,9 @@ def update_result_feedback( transcription=result.transcription, provider=result.provider, status=result.status, - wer=result.wer, - cer=result.cer, + score=result.score, is_correct=result.is_correct, comment=result.comment, - provider_metadata=result.provider_metadata, error_message=result.error_message, stt_sample_id=result.stt_sample_id, evaluation_run_id=result.evaluation_run_id, @@ -112,11 +110,9 @@ def get_result( transcription=result.transcription, provider=result.provider, status=result.status, - wer=result.wer, - cer=result.cer, + score=result.score, is_correct=result.is_correct, comment=result.comment, - 
provider_metadata=result.provider_metadata, error_message=result.error_message, stt_sample_id=result.stt_sample_id, evaluation_run_id=result.evaluation_run_id, diff --git a/backend/app/crud/stt_evaluations/result.py b/backend/app/crud/stt_evaluations/result.py index 0c8d7cae9..3d6e862ae 100644 --- a/backend/app/crud/stt_evaluations/result.py +++ b/backend/app/crud/stt_evaluations/result.py @@ -179,7 +179,6 @@ def get_results_by_run_id( score=result.score, is_correct=result.is_correct, comment=result.comment, - provider_metadata=result.provider_metadata, error_message=result.error_message, stt_sample_id=result.stt_sample_id, evaluation_run_id=result.evaluation_run_id, @@ -201,7 +200,6 @@ def update_stt_result( transcription: str | None = None, status: str | None = None, score: dict[str, Any] | None = None, - provider_metadata: dict[str, Any] | None = None, error_message: str | None = None, ) -> STTResult | None: """Update an STT result with transcription data. @@ -212,7 +210,6 @@ def update_stt_result( transcription: Generated transcription status: New status score: Evaluation metrics (e.g., wer, cer) - provider_metadata: Provider response metadata error_message: Error message if failed Returns: @@ -228,7 +225,6 @@ def update_stt_result( "transcription": transcription, "status": status, "score": score, - "provider_metadata": provider_metadata, "error_message": error_message, } diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index a46c2c9a5..aed5f45d2 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -174,16 +174,6 @@ class STTResult(SQLModel, table=True): description="Human feedback comment", ) - provider_metadata: dict[str, Any] | None = SQLField( - default_factory=dict, - sa_column=Column( - JSONB, - nullable=True, - comment="Provider-specific response metadata (tokens, latency, etc.)", - ), - description="Provider-specific response metadata", - ) - error_message: str | None = SQLField( default=None, sa_column=Column( @@ -270,7 +260,6 @@ class STTResultPublic(BaseModel): score: dict[str, Any] | None is_correct: bool | None comment: str | None - provider_metadata: dict[str, Any] | None error_message: str | None stt_sample_id: int evaluation_run_id: int From 4afdd2daabb2118d1f5d8c71310fa544186730b6 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 10:15:37 +0530 Subject: [PATCH 29/36] cleanup optimize results --- .../api/routes/stt_evaluations/evaluation.py | 11 ++---- backend/app/crud/evaluations/cron.py | 2 +- backend/app/crud/stt_evaluations/__init__.py | 4 -- backend/app/crud/stt_evaluations/batch.py | 38 ++++++------------- 4 files changed, 16 insertions(+), 39 deletions(-) diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py index f04b0dfa1..da7349867 100644 --- a/backend/app/api/routes/stt_evaluations/evaluation.py +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -9,6 +9,7 @@ from app.crud.stt_evaluations import ( create_stt_run, create_stt_results, + get_results_by_run_id, get_samples_by_dataset_id, get_stt_dataset_by_id, get_stt_run_by_id, @@ -16,7 +17,6 @@ start_stt_evaluation_batch, update_stt_run, ) -from app.crud.stt_evaluations.result import get_results_by_run_id from app.models.stt_evaluation import ( STTEvaluationRunCreate, STTEvaluationRunPublic, @@ -96,19 +96,14 @@ def start_stt_evaluation( providers=run_create.providers, ) - # Extract result data for batch processing - result_refs = [ - 
{"id": r.id, "stt_sample_id": r.stt_sample_id, "provider": r.provider} - for r in results - ] + sample_to_result = {r.stt_sample_id: r.id for r in results} - # Submit batch synchronously try: batch_result = start_stt_evaluation_batch( session=_session, run=run, samples=samples, - result_refs=result_refs, + sample_to_result=sample_to_result, org_id=auth_context.organization_.id, project_id=auth_context.project_.id, ) diff --git a/backend/app/crud/evaluations/cron.py b/backend/app/crud/evaluations/cron.py index 72b22ec69..fb9f7a66b 100644 --- a/backend/app/crud/evaluations/cron.py +++ b/backend/app/crud/evaluations/cron.py @@ -12,7 +12,7 @@ from sqlmodel import Session, select from app.crud.evaluations.processing import poll_all_pending_evaluations -from app.crud.stt_evaluations.cron import poll_all_pending_stt_evaluations +from app.crud.stt_evaluations import poll_all_pending_stt_evaluations from app.models import Organization logger = logging.getLogger(__name__) diff --git a/backend/app/crud/stt_evaluations/__init__.py b/backend/app/crud/stt_evaluations/__init__.py index 2202ca02a..71f0b3f3f 100644 --- a/backend/app/crud/stt_evaluations/__init__.py +++ b/backend/app/crud/stt_evaluations/__init__.py @@ -19,9 +19,7 @@ create_stt_results, get_stt_result_by_id, get_results_by_run_id, - update_stt_result, update_human_feedback, - count_results_by_status, ) __all__ = [ @@ -44,7 +42,5 @@ "create_stt_results", "get_stt_result_by_id", "get_results_by_run_id", - "update_stt_result", "update_human_feedback", - "count_results_by_status", ] diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py index b434cb1c4..78e3d6af1 100644 --- a/backend/app/crud/stt_evaluations/batch.py +++ b/backend/app/crud/stt_evaluations/batch.py @@ -32,32 +32,21 @@ def start_stt_evaluation_batch( session: Session, run: EvaluationRun, samples: list[STTSample], - result_refs: list[dict[str, Any]], + sample_to_result: dict[int, int], org_id: int, project_id: int, signed_url_expires_in: int = 86400, ) -> dict[str, Any]: """Generate signed URLs and submit Gemini batch job for STT evaluation. - This function runs synchronously during the API request: - 1. Initializes GeminiClient - 2. Generates signed URLs for audio files (valid for batch processing window) - 3. Builds batch requests using signed URLs with sample_id as key for tracking - 4. Submits batch job via start_batch_job (creates BatchJob record with request_keys) - 5. Links batch_job_id to the evaluation run - 6. Updates run status to "processing" - - Note: Sample IDs are passed as keys in the batch request and stored in - batch_job.config["request_keys"]. 
- Args: session: Database session run: The evaluation run record samples: List of STT samples to process - result_refs: List of result reference dicts with id, stt_sample_id, provider + sample_to_result: Mapping of sample_id -> result_id for error handling org_id: Organization ID project_id: Project ID - signed_url_expires_in: Signed URL expiry in seconds (default: 24 hours for batch) + signed_url_expires_in: Signed URL expiry in seconds (default: 24 hours) Returns: dict: Result with batch job information @@ -97,15 +86,13 @@ def start_stt_evaluation_batch( f"[start_stt_evaluation_batch] Failed to generate signed URL | " f"sample_id: {sample.id}, error: {str(e)}" ) - # Mark result as failed - for ref in result_refs: - if ref["stt_sample_id"] == sample.id: - update_stt_result( - session=session, - result_id=ref["id"], - status=STTResultStatus.FAILED.value, - error_message=f"Failed to generate signed URL: {str(e)}", - ) + if sample.id in sample_to_result: + update_stt_result( + session=session, + result_id=sample_to_result[sample.id], + status=STTResultStatus.FAILED.value, + error_message=f"Failed to generate signed URL: {str(e)}", + ) if not signed_urls: raise Exception("Failed to generate signed URLs for any audio files") @@ -143,11 +130,10 @@ def start_stt_evaluation_batch( f"[start_stt_evaluation_batch] Failed to submit batch | " f"model: {model}, error: {str(e)}" ) - # Update all results as failed - for ref in result_refs: + for result_id in sample_to_result.values(): update_stt_result( session=session, - result_id=ref["id"], + result_id=result_id, status=STTResultStatus.FAILED.value, error_message=f"Batch submission failed: {str(e)}", ) From 3e62a987a40620e16d73cf1835f42640e1a17083 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 10:27:07 +0530 Subject: [PATCH 30/36] cleanup queries --- .../app/api/routes/stt_evaluations/dataset.py | 2 +- .../api/routes/stt_evaluations/evaluation.py | 4 ++-- backend/app/crud/stt_evaluations/dataset.py | 17 +++-------------- backend/app/crud/stt_evaluations/run.py | 14 +++++--------- 4 files changed, 11 insertions(+), 26 deletions(-) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index 791022a63..cb6baff0f 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -124,7 +124,7 @@ def get_dataset( samples_total = (dataset.dataset_metadata or {}).get("sample_count", 0) if include_samples: - sample_records, samples_total = get_samples_by_dataset_id( + sample_records = get_samples_by_dataset_id( session=_session, dataset_id=dataset_id, org_id=auth_context.organization_.id, diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py index da7349867..2ecdcacee 100644 --- a/backend/app/api/routes/stt_evaluations/evaluation.py +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -78,12 +78,12 @@ def start_stt_evaluation( ) # Get samples for the dataset - samples, _ = get_samples_by_dataset_id( + samples = get_samples_by_dataset_id( session=_session, dataset_id=run_create.dataset_id, org_id=auth_context.organization_.id, project_id=auth_context.project_.id, - limit=10000, # Get all samples + limit=10000, ) # Create result records for each sample and provider diff --git a/backend/app/crud/stt_evaluations/dataset.py b/backend/app/crud/stt_evaluations/dataset.py index fec80b198..d5e8e32dc 100644 --- a/backend/app/crud/stt_evaluations/dataset.py +++ 
b/backend/app/crud/stt_evaluations/dataset.py @@ -264,7 +264,7 @@ def get_samples_by_dataset_id( project_id: int, limit: int = 100, offset: int = 0, -) -> tuple[list[STTSample], int]: +) -> list[STTSample]: """Get samples for a dataset. Args: @@ -276,17 +276,8 @@ def get_samples_by_dataset_id( offset: Number of results to skip Returns: - tuple[list[STTSample], int]: Samples and total count + list[STTSample]: Samples """ - # Get total count - count_stmt = select(func.count(STTSample.id)).where( - STTSample.dataset_id == dataset_id, - STTSample.organization_id == org_id, - STTSample.project_id == project_id, - ) - total = session.exec(count_stmt).one() - - # Get samples statement = ( select(STTSample) .where( @@ -299,6 +290,4 @@ def get_samples_by_dataset_id( .limit(limit) ) - samples = session.exec(statement).all() - - return list(samples), total + return list(session.exec(statement).all()) diff --git a/backend/app/crud/stt_evaluations/run.py b/backend/app/crud/stt_evaluations/run.py index 0d3ef4fbb..76e70537c 100644 --- a/backend/app/crud/stt_evaluations/run.py +++ b/backend/app/crud/stt_evaluations/run.py @@ -235,7 +235,7 @@ def update_stt_run( def get_pending_stt_runs( *, session: Session, - org_id: int | None = None, + org_id: int, ) -> list[EvaluationRun]: """Get all pending STT evaluation runs that are ready for polling. @@ -244,20 +244,16 @@ def get_pending_stt_runs( Args: session: Database session - org_id: Optional filter by organization + org_id: Organization ID Returns: list[EvaluationRun]: Pending runs ready for polling """ - where_clauses = [ + statement = select(EvaluationRun).where( EvaluationRun.type == EvaluationType.STT.value, EvaluationRun.status == "processing", EvaluationRun.batch_job_id.is_not(None), - ] - - if org_id is not None: - where_clauses.append(EvaluationRun.organization_id == org_id) - - statement = select(EvaluationRun).where(*where_clauses) + EvaluationRun.organization_id == org_id, + ) return list(session.exec(statement).all()) From 63de270ef11cd933225e68bd4f7a73267b50dd9c Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 10:33:18 +0530 Subject: [PATCH 31/36] cleanup leftovers --- backend/app/models/stt_evaluation.py | 19 +- .../services/stt_evaluations/gemini/files.py | 271 ------------------ 2 files changed, 2 insertions(+), 288 deletions(-) delete mode 100644 backend/app/services/stt_evaluations/gemini/files.py diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index aed5f45d2..e6837c1f9 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -2,21 +2,16 @@ from datetime import datetime from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import Any from pydantic import BaseModel, Field from sqlalchemy import Column, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField -from sqlmodel import Relationship, SQLModel +from sqlmodel import SQLModel from app.core.util import now -if TYPE_CHECKING: - from .evaluation import EvaluationDataset, EvaluationRun - from .organization import Organization - from .project import Project - class EvaluationType(str, Enum): """Type of evaluation dataset/run.""" @@ -107,11 +102,6 @@ class STTSample(SQLModel, table=True): sa_column_kwargs={"comment": "Timestamp when the sample was last updated"}, ) - dataset: "EvaluationDataset" = Relationship() - organization: "Organization" = Relationship() - project: "Project" = Relationship() - results: list["STTResult"] = 
Relationship(back_populates="sample") - class STTResult(SQLModel, table=True): """Database table for STT transcription results.""" @@ -220,11 +210,6 @@ class STTResult(SQLModel, table=True): sa_column_kwargs={"comment": "Timestamp when the result was last updated"}, ) - sample: "STTSample" = Relationship(back_populates="results") - evaluation_run: "EvaluationRun" = Relationship() - organization: "Organization" = Relationship() - project: "Project" = Relationship() - class STTSampleCreate(BaseModel): """Request model for creating an STT sample.""" diff --git a/backend/app/services/stt_evaluations/gemini/files.py b/backend/app/services/stt_evaluations/gemini/files.py deleted file mode 100644 index dadbef5ab..000000000 --- a/backend/app/services/stt_evaluations/gemini/files.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Google Files API operations for STT evaluation.""" - -import logging -import os -import tempfile -from typing import BinaryIO - -import requests -from google import genai -from google.genai import types - -from app.services.stt_evaluations.constants import EXTENSION_TO_MIME - -logger = logging.getLogger(__name__) - - -class GeminiFilesError(Exception): - """Exception raised for Gemini Files API errors.""" - - pass - - -def get_mime_type(file_path: str) -> str: - """Get MIME type based on file extension. - - Args: - file_path: Path or URL of the audio file - - Returns: - str: MIME type string - """ - extension = file_path.lower().split(".")[-1] - return EXTENSION_TO_MIME.get(extension, "audio/mpeg") - - -def get_extension(file_path: str) -> str: - """Get file extension from path or URL. - - Args: - file_path: Path or URL of the audio file - - Returns: - str: File extension (lowercase, without dot) - """ - return file_path.lower().split(".")[-1] - - -class GeminiFilesManager: - """Manage file uploads to Google Files API.""" - - # Base URL for Gemini Files API - FILES_API_BASE = "https://generativelanguage.googleapis.com/v1beta" - - def __init__(self, client: genai.Client): - """Initialize files manager. - - Args: - client: Gemini client instance - """ - self._client = client - - def upload_from_bytes( - self, - content: bytes, - filename: str, - mime_type: str | None = None, - ) -> str: - """Upload audio content directly to Google Files API. 
- - Args: - content: Audio file content as bytes - filename: Display name for the file - mime_type: MIME type (auto-detected if not provided) - - Returns: - str: Full Google Files API URI (HTTPS format for batch API) - - Raises: - GeminiFilesError: If upload fails - """ - if mime_type is None: - mime_type = get_mime_type(filename) - - logger.info( - f"[upload_from_bytes] Uploading file to Google Files API | " - f"filename: {filename}, mime_type: {mime_type}, size_bytes: {len(content)}" - ) - - try: - # Write content to a temporary file - extension = get_extension(filename) - with tempfile.NamedTemporaryFile( - suffix=f".{extension}", delete=False - ) as tmp_file: - tmp_file.write(content) - tmp_path = tmp_file.name - - try: - # Upload to Google Files API - uploaded_file = self._client.files.upload( - file=tmp_path, - config=types.UploadFileConfig( - display_name=filename, - mime_type=mime_type, - ), - ) - - # CRITICAL: Return the full HTTPS URI for batch API compatibility - # The batch API requires full URLs, not short form like "files/xxx" - file_uri = f"{self.FILES_API_BASE}/{uploaded_file.name}" - - logger.info( - f"[upload_from_bytes] File uploaded successfully | " - f"filename: {filename}, file_uri: {file_uri}" - ) - - return file_uri - - finally: - # Clean up temporary file - os.unlink(tmp_path) - - except Exception as e: - logger.error( - f"[upload_from_bytes] Failed to upload file | " - f"filename: {filename}, error: {str(e)}" - ) - raise GeminiFilesError(f"Failed to upload file {filename}: {str(e)}") from e - - def upload_from_url( - self, - signed_url: str, - filename: str, - mime_type: str | None = None, - ) -> str: - """Upload audio from a signed URL to Google Files API. - - Args: - signed_url: Signed URL to download the audio file - filename: Display name for the file - mime_type: MIME type (auto-detected if not provided) - - Returns: - str: Full Google Files API URI (HTTPS format for batch API) - - Raises: - GeminiFilesError: If download or upload fails - """ - logger.info( - f"[upload_from_url] Downloading file from URL | filename: {filename}" - ) - - try: - # Download the file from the signed URL - response = requests.get(signed_url, timeout=300) # 5 minute timeout - response.raise_for_status() - - # Upload to Google Files API - return self.upload_from_bytes( - content=response.content, - filename=filename, - mime_type=mime_type, - ) - - except requests.RequestException as e: - logger.error( - f"[upload_from_url] Failed to download file | " - f"filename: {filename}, error: {str(e)}" - ) - raise GeminiFilesError( - f"Failed to download file {filename}: {str(e)}" - ) from e - - def upload_from_stream( - self, - stream: BinaryIO, - filename: str, - mime_type: str | None = None, - ) -> str: - """Upload audio from a stream to Google Files API. - - Args: - stream: Binary stream (e.g., from S3 StreamingBody) - filename: Display name for the file - mime_type: MIME type (auto-detected if not provided) - - Returns: - str: Full Google Files API URI (HTTPS format for batch API) - - Raises: - GeminiFilesError: If upload fails - """ - try: - content = stream.read() - return self.upload_from_bytes( - content=content, - filename=filename, - mime_type=mime_type, - ) - except Exception as e: - logger.error( - f"[upload_from_stream] Failed to read stream | " - f"filename: {filename}, error: {str(e)}" - ) - raise GeminiFilesError( - f"Failed to read stream for {filename}: {str(e)}" - ) from e - - def delete_file(self, file_uri: str) -> bool: - """Delete a file from Google Files API. 
- - Args: - file_uri: Full Google Files API URI or short form (files/xxx) - - Returns: - bool: True if deletion was successful - """ - # Extract the file name from the URI - if file_uri.startswith(self.FILES_API_BASE): - file_name = file_uri.replace(f"{self.FILES_API_BASE}/", "") - else: - file_name = file_uri - - logger.info(f"[delete_file] Deleting file | file_name: {file_name}") - - try: - self._client.files.delete(name=file_name) - logger.info( - f"[delete_file] File deleted successfully | file_name: {file_name}" - ) - return True - except Exception as e: - logger.warning( - f"[delete_file] Failed to delete file | " - f"file_name: {file_name}, error: {str(e)}" - ) - return False - - def get_file_status(self, file_uri: str) -> dict: - """Get the status of a file in Google Files API. - - Args: - file_uri: Full Google Files API URI or short form (files/xxx) - - Returns: - dict: File metadata including state - """ - # Extract the file name from the URI - if file_uri.startswith(self.FILES_API_BASE): - file_name = file_uri.replace(f"{self.FILES_API_BASE}/", "") - else: - file_name = file_uri - - try: - file_info = self._client.files.get(name=file_name) - return { - "name": file_info.name, - "display_name": file_info.display_name, - "mime_type": file_info.mime_type, - "size_bytes": file_info.size_bytes, - "state": file_info.state.name if file_info.state else None, - "uri": file_info.uri, - } - except Exception as e: - logger.error( - f"[get_file_status] Failed to get file status | " - f"file_name: {file_name}, error: {str(e)}" - ) - raise GeminiFilesError( - f"Failed to get file status for {file_name}: {str(e)}" - ) from e From c95c0444c7a18de6baf83ed8386f5f6882a0d34c Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 11:19:45 +0530 Subject: [PATCH 32/36] added validation for provider --- backend/app/models/stt_evaluation.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index e6837c1f9..d2ded0ee3 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from sqlalchemy import Column, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField @@ -12,6 +12,9 @@ from app.core.util import now +# Supported STT providers for evaluation +SUPPORTED_STT_PROVIDERS = ["gemini-2.5-pro"] + class EvaluationType(str, Enum): """Type of evaluation dataset/run.""" @@ -311,6 +314,20 @@ class STTEvaluationRunCreate(BaseModel): ) language: str | None = Field(None, description="Override language for all samples") + @field_validator("providers") + @classmethod + def validate_providers(cls, v: list[str]) -> list[str]: + """Validate that all providers are supported.""" + if not v: + raise ValueError("At least one provider must be specified") + unsupported = [p for p in v if p not in SUPPORTED_STT_PROVIDERS] + if unsupported: + raise ValueError( + f"Unsupported provider(s): {', '.join(unsupported)}. 
" + f"Supported providers are: {', '.join(SUPPORTED_STT_PROVIDERS)}" + ) + return v + class STTEvaluationRunPublic(BaseModel): """Public model for STT evaluation runs.""" From 9aa6858e2cdfb6e6b555c2ab4eef178e5933d4d8 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 12:11:24 +0530 Subject: [PATCH 33/36] updated test suite --- .../app/api/routes/stt_evaluations/dataset.py | 1 + backend/app/models/stt_evaluation.py | 1 + .../tests/api/routes/test_stt_evaluation.py | 5 ++ backend/app/tests/core/batch/test_gemini.py | 53 ++++++++++--------- backend/app/tests/crud/test_credentials.py | 2 +- 5 files changed, 35 insertions(+), 27 deletions(-) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index cb6baff0f..6f0b47210 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -158,6 +158,7 @@ def get_dataset( language=dataset.language, object_store_url=dataset.object_store_url, dataset_metadata=dataset.dataset_metadata, + sample_count=samples_total, organization_id=dataset.organization_id, project_id=dataset.project_id, inserted_at=dataset.inserted_at, diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index d2ded0ee3..863d8cf3f 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -291,6 +291,7 @@ class STTDatasetPublic(BaseModel): language: str | None object_store_url: str | None dataset_metadata: dict[str, Any] + sample_count: int = 0 organization_id: int project_id: int inserted_at: datetime diff --git a/backend/app/tests/api/routes/test_stt_evaluation.py b/backend/app/tests/api/routes/test_stt_evaluation.py index 9352e0726..3251c9aa1 100644 --- a/backend/app/tests/api/routes/test_stt_evaluation.py +++ b/backend/app/tests/api/routes/test_stt_evaluation.py @@ -379,6 +379,7 @@ def test_get_stt_dataset_without_samples( user_api_key: TestAuthContext, ) -> None: """Test getting an STT dataset without including samples.""" + # Create dataset with sample_count in metadata set correctly dataset = create_test_stt_dataset( db=db, organization_id=user_api_key.organization_id, @@ -391,6 +392,10 @@ def test_get_stt_dataset_without_samples( organization_id=user_api_key.organization_id, project_id=user_api_key.project_id, ) + # Update dataset metadata to reflect the sample count + dataset.dataset_metadata = {"sample_count": 1, "has_ground_truth_count": 0} + db.add(dataset) + db.commit() response = client.get( f"/api/v1/evaluations/stt/datasets/{dataset.id}", diff --git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py index 026232ab5..b00442f43 100644 --- a/backend/app/tests/core/batch/test_gemini.py +++ b/backend/app/tests/core/batch/test_gemini.py @@ -57,14 +57,14 @@ def test_create_batch_success(self, provider, mock_genai_client): } # Mock file upload - mock_genai_client.files.upload.return_value = MagicMock( - name="files/uploaded-123" - ) + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-123" + mock_genai_client.files.upload.return_value = mock_uploaded_file # Mock batch creation mock_batch_job = MagicMock() mock_batch_job.name = "batches/batch-xyz789" - mock_batch_job.state = MagicMock(name="JOB_STATE_PENDING") + mock_batch_job.state.name = "JOB_STATE_PENDING" mock_genai_client.batches.create.return_value = mock_batch_job with patch("tempfile.NamedTemporaryFile"): @@ -81,13 +81,13 @@ def 
test_create_batch_with_default_config(self, provider, mock_genai_client): jsonl_data = [{"key": "req-1", "request": {}}] config = {} - mock_genai_client.files.upload.return_value = MagicMock( - name="files/uploaded-456" - ) + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-456" + mock_genai_client.files.upload.return_value = mock_uploaded_file mock_batch_job = MagicMock() mock_batch_job.name = "batches/batch-123" - mock_batch_job.state = MagicMock(name="JOB_STATE_PENDING") + mock_batch_job.state.name = "JOB_STATE_PENDING" mock_genai_client.batches.create.return_value = mock_batch_job with patch("tempfile.NamedTemporaryFile"): @@ -116,9 +116,9 @@ def test_create_batch_batch_creation_error(self, provider, mock_genai_client): jsonl_data = [{"key": "req-1", "request": {}}] config = {"display_name": "test"} - mock_genai_client.files.upload.return_value = MagicMock( - name="files/uploaded-123" - ) + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/uploaded-123" + mock_genai_client.files.upload.return_value = mock_uploaded_file mock_genai_client.batches.create.side_effect = Exception( "Batch creation failed" ) @@ -135,7 +135,7 @@ def test_get_batch_status_pending(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_PENDING") + mock_batch_job.state.name = "JOB_STATE_PENDING" mock_genai_client.batches.get.return_value = mock_batch_job result = provider.get_batch_status(batch_id) @@ -150,7 +150,7 @@ def test_get_batch_status_succeeded(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" mock_genai_client.batches.get.return_value = mock_batch_job result = provider.get_batch_status(batch_id) @@ -163,7 +163,7 @@ def test_get_batch_status_failed(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_FAILED") + mock_batch_job.state.name = "JOB_STATE_FAILED" mock_genai_client.batches.get.return_value = mock_batch_job result = provider.get_batch_status(batch_id) @@ -177,7 +177,7 @@ def test_get_batch_status_cancelled(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_CANCELLED") + mock_batch_job.state.name = "JOB_STATE_CANCELLED" mock_genai_client.batches.get.return_value = mock_batch_job result = provider.get_batch_status(batch_id) @@ -190,7 +190,7 @@ def test_get_batch_status_expired(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_EXPIRED") + mock_batch_job.state.name = "JOB_STATE_EXPIRED" mock_genai_client.batches.get.return_value = mock_batch_job result = provider.get_batch_status(batch_id) @@ -214,7 +214,7 @@ def test_download_batch_results_success(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" mock_batch_job.dest = MagicMock() mock_batch_job.dest.file_name = "files/output-123" mock_genai_client.batches.get.return_value = mock_batch_job @@ -244,7 +244,7 @@ def test_download_batch_results_with_direct_text_response( batch_id = "batches/batch-xyz789" 
mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" mock_batch_job.dest = MagicMock() mock_batch_job.dest.file_name = "files/output-123" mock_genai_client.batches.get.return_value = mock_batch_job @@ -262,7 +262,7 @@ def test_download_batch_results_with_errors(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" mock_batch_job.dest = MagicMock() mock_batch_job.dest.file_name = "files/output-123" mock_genai_client.batches.get.return_value = mock_batch_job @@ -290,7 +290,7 @@ def test_download_batch_results_batch_not_complete( batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_RUNNING") + mock_batch_job.state.name = "JOB_STATE_RUNNING" mock_genai_client.batches.get.return_value = mock_batch_job with pytest.raises(ValueError) as exc_info: @@ -303,7 +303,7 @@ def test_download_batch_results_malformed_json(self, provider, mock_genai_client batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" mock_batch_job.dest = MagicMock() mock_batch_job.dest.file_name = "files/output-123" mock_genai_client.batches.get.return_value = mock_batch_job @@ -327,7 +327,7 @@ def test_download_batch_results_no_dest_file(self, provider, mock_genai_client): batch_id = "batches/batch-xyz789" mock_batch_job = MagicMock() - mock_batch_job.state = MagicMock(name="JOB_STATE_SUCCEEDED") + mock_batch_job.state.name = "JOB_STATE_SUCCEEDED" mock_batch_job.dest = None mock_genai_client.batches.get.return_value = mock_batch_job @@ -476,10 +476,11 @@ def test_create_requests_mime_type_detection(self): requests[1]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] == "audio/x-wav" ) - assert ( - requests[2]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] - == "audio/mp4" - ) + # .m4a can return different MIME types depending on the system + m4a_mime = requests[2]["request"]["contents"][0]["parts"][1]["file_data"][ + "mime_type" + ] + assert m4a_mime in ("audio/mp4", "audio/mp4a-latm", "audio/x-m4a") def test_create_requests_key_length_mismatch(self): """Test that mismatched keys and URLs raise error.""" diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py index 9e1bec372..78bed64a7 100644 --- a/backend/app/tests/crud/test_credentials.py +++ b/backend/app/tests/crud/test_credentials.py @@ -208,7 +208,7 @@ def test_invalid_provider(db: Session) -> None: """Test handling of invalid provider names.""" project = create_test_project(db) - credentials_data = {"gemini": {"api_key": "test-key"}} + credentials_data = {"invalid_provider": {"api_key": "test-key"}} credentials_create = CredsCreate( is_active=True, credential=credentials_data, From 4a92416d1ec6d5374c841ac5c91ae3c181b00bb4 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 12:28:04 +0530 Subject: [PATCH 34/36] coderabbit suggestions --- backend/app/api/routes/stt_evaluations/dataset.py | 1 - backend/app/api/routes/stt_evaluations/evaluation.py | 1 - backend/app/core/batch/gemini.py | 2 +- backend/app/models/stt_evaluation.py | 4 ++-- backend/app/services/stt_evaluations/gemini/client.py | 2 +- backend/app/tests/api/routes/test_stt_evaluation.py | 6 +++--- 6 
files changed, 7 insertions(+), 9 deletions(-) diff --git a/backend/app/api/routes/stt_evaluations/dataset.py b/backend/app/api/routes/stt_evaluations/dataset.py index 6f0b47210..cb6baff0f 100644 --- a/backend/app/api/routes/stt_evaluations/dataset.py +++ b/backend/app/api/routes/stt_evaluations/dataset.py @@ -158,7 +158,6 @@ def get_dataset( language=dataset.language, object_store_url=dataset.object_store_url, dataset_metadata=dataset.dataset_metadata, - sample_count=samples_total, organization_id=dataset.organization_id, project_id=dataset.project_id, inserted_at=dataset.inserted_at, diff --git a/backend/app/api/routes/stt_evaluations/evaluation.py b/backend/app/api/routes/stt_evaluations/evaluation.py index 2ecdcacee..0fdb59340 100644 --- a/backend/app/api/routes/stt_evaluations/evaluation.py +++ b/backend/app/api/routes/stt_evaluations/evaluation.py @@ -83,7 +83,6 @@ def start_stt_evaluation( dataset_id=run_create.dataset_id, org_id=auth_context.organization_.id, project_id=auth_context.project_.id, - limit=10000, ) # Create result records for each sample and provider diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index e974564e3..1dd28a952 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -55,7 +55,7 @@ class GeminiBatchProvider(BatchProvider): DEFAULT_MODEL = "models/gemini-2.5-pro" - def __init__(self, client: genai.Client, model: str | None = None): + def __init__(self, client: genai.Client, model: str | None = None) -> None: """Initialize the Gemini batch provider. Args: diff --git a/backend/app/models/stt_evaluation.py b/backend/app/models/stt_evaluation.py index 863d8cf3f..9f09a5f2f 100644 --- a/backend/app/models/stt_evaluation.py +++ b/backend/app/models/stt_evaluation.py @@ -291,7 +291,6 @@ class STTDatasetPublic(BaseModel): language: str | None object_store_url: str | None dataset_metadata: dict[str, Any] - sample_count: int = 0 organization_id: int project_id: int inserted_at: datetime @@ -310,8 +309,9 @@ class STTEvaluationRunCreate(BaseModel): run_name: str = Field(..., description="Name for this evaluation run", min_length=1) dataset_id: int = Field(..., description="ID of the STT dataset to evaluate") providers: list[str] = Field( - default=["gemini-2.5-pro"], + default_factory=lambda: ["gemini-2.5-pro"], description="List of STT providers to use", + min_length=1, ) language: str | None = Field(None, description="Override language for all samples") diff --git a/backend/app/services/stt_evaluations/gemini/client.py b/backend/app/services/stt_evaluations/gemini/client.py index 5be4b583a..1ee52d66b 100644 --- a/backend/app/services/stt_evaluations/gemini/client.py +++ b/backend/app/services/stt_evaluations/gemini/client.py @@ -21,7 +21,7 @@ class GeminiClientError(Exception): class GeminiClient: """Wrapper for Google GenAI client with credential management.""" - def __init__(self, api_key: str): + def __init__(self, api_key: str) -> None: """Initialize Gemini client with API key. 
Args: diff --git a/backend/app/tests/api/routes/test_stt_evaluation.py b/backend/app/tests/api/routes/test_stt_evaluation.py index 3251c9aa1..99b136820 100644 --- a/backend/app/tests/api/routes/test_stt_evaluation.py +++ b/backend/app/tests/api/routes/test_stt_evaluation.py @@ -97,7 +97,7 @@ def test_create_stt_dataset_success( assert data["description"] == "Test STT dataset" assert data["type"] == "stt" assert data["language"] == "en" - assert data["sample_count"] == 2 + assert data["dataset_metadata"]["sample_count"] == 2 assert data["dataset_metadata"]["has_ground_truth_count"] == 1 def test_create_stt_dataset_minimal( @@ -125,7 +125,7 @@ def test_create_stt_dataset_minimal( assert data["name"] == "minimal_stt_dataset" assert data["description"] is None assert data["language"] is None - assert data["sample_count"] == 1 + assert data["dataset_metadata"]["sample_count"] == 1 def test_create_stt_dataset_empty_samples( self, @@ -409,7 +409,7 @@ def test_get_stt_dataset_without_samples( assert data["id"] == dataset.id assert data["samples"] == [] - assert data["sample_count"] == 1 # Count should still be correct + assert data["dataset_metadata"]["sample_count"] == 1 class TestSTTEvaluationRun: From e20441629199dba9886989a8f2465bf77406970c Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 13:06:43 +0530 Subject: [PATCH 35/36] added few more testcases --- backend/app/tests/core/batch/test_gemini.py | 62 +++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py index b00442f43..fc62f7b4c 100644 --- a/backend/app/tests/core/batch/test_gemini.py +++ b/backend/app/tests/core/batch/test_gemini.py @@ -552,3 +552,65 @@ def test_extract_multiple_candidates(self): } text = GeminiBatchProvider._extract_text_from_response_dict(response) assert text == "FirstSecond" + + +class TestExtractTextFromResponse: + """Test cases for _extract_text_from_response static method (object version).""" + + def test_extract_direct_text_attribute(self): + """Test extracting text when response has .text attribute.""" + response = MagicMock() + response.text = "Hello from text attribute" + text = GeminiBatchProvider._extract_text_from_response(response) + assert text == "Hello from text attribute" + + def test_extract_from_candidates_structure(self): + """Test extracting text from candidates when no .text attribute.""" + # Create mock without .text attribute + response = MagicMock(spec=[]) + del response.text # Ensure no text attribute + + # Create candidates structure + part1 = MagicMock() + part1.text = "Part 1" + part2 = MagicMock() + part2.text = " Part 2" + + content = MagicMock() + content.parts = [part1, part2] + + candidate = MagicMock() + candidate.content = content + + response.candidates = [candidate] + + text = GeminiBatchProvider._extract_text_from_response(response) + assert text == "Part 1 Part 2" + + def test_extract_empty_response_no_text_no_candidates(self): + """Test extracting text from response with no text and no candidates.""" + response = MagicMock(spec=[]) + del response.text + del response.candidates + + text = GeminiBatchProvider._extract_text_from_response(response) + assert text == "" + + +class TestCreateSttBatchRequestsMimeTypeFallback: + """Test cases for create_stt_batch_requests MIME type fallback.""" + + def test_unknown_mime_type_defaults_to_audio_mpeg(self): + """Test that unknown file extensions default to audio/mpeg.""" + # URL with no recognizable audio extension + 
signed_urls = ["https://bucket.s3.amazonaws.com/audio.unknown?signature=xyz"] + prompt = "Transcribe this audio." + + with patch("app.core.batch.gemini.get_mime_from_url", return_value=None): + requests = create_stt_batch_requests(signed_urls, prompt) + + assert len(requests) == 1 + # Check that the request was created with default mime type + # parts[0] is text prompt, parts[1] is file_data + file_data = requests[0]["request"]["contents"][0]["parts"][1]["file_data"] + assert file_data["mime_type"] == "audio/mpeg" From 0210dabd0a9ff1b8b59b034f6caa500012e71a12 Mon Sep 17 00:00:00 2001 From: AkhileshNegi Date: Wed, 4 Feb 2026 13:25:49 +0530 Subject: [PATCH 36/36] added more testcases for coverage --- backend/app/tests/core/test_storage_utils.py | 293 ++++++++++++++ .../services/stt_evaluations/__init__.py | 1 + .../services/stt_evaluations/test_audio.py | 378 ++++++++++++++++++ .../services/stt_evaluations/test_dataset.py | 344 ++++++++++++++++ .../stt_evaluations/test_gemini_client.py | 182 +++++++++ 5 files changed, 1198 insertions(+) create mode 100644 backend/app/tests/core/test_storage_utils.py create mode 100644 backend/app/tests/services/stt_evaluations/__init__.py create mode 100644 backend/app/tests/services/stt_evaluations/test_audio.py create mode 100644 backend/app/tests/services/stt_evaluations/test_dataset.py create mode 100644 backend/app/tests/services/stt_evaluations/test_gemini_client.py diff --git a/backend/app/tests/core/test_storage_utils.py b/backend/app/tests/core/test_storage_utils.py new file mode 100644 index 000000000..e1b702003 --- /dev/null +++ b/backend/app/tests/core/test_storage_utils.py @@ -0,0 +1,293 @@ +"""Test cases for storage utilities.""" + +import json +from datetime import datetime +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest + +from app.core.storage_utils import ( + generate_timestamped_filename, + get_mime_from_url, + upload_csv_to_object_store, + upload_jsonl_to_object_store, + upload_to_object_store, +) + + +class TestGetMimeFromUrl: + """Test cases for get_mime_from_url function.""" + + def test_mp3_url(self): + """Test MIME detection for MP3 files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.mp3" + assert get_mime_from_url(url) == "audio/mpeg" + + def test_wav_url(self): + """Test MIME detection for WAV files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.wav" + assert get_mime_from_url(url) == "audio/x-wav" + + def test_flac_url(self): + """Test MIME detection for FLAC files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.flac" + mime = get_mime_from_url(url) + # FLAC can be detected as audio/flac or audio/x-flac depending on system + assert mime in ("audio/flac", "audio/x-flac") + + def test_ogg_url(self): + """Test MIME detection for OGG files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.ogg" + assert get_mime_from_url(url) == "audio/ogg" + + def test_webm_url(self): + """Test MIME detection for WEBM files.""" + url = "https://bucket.s3.amazonaws.com/audio/test.webm" + mime = get_mime_from_url(url) + # webm can be detected as audio or video depending on system + assert mime in ("audio/webm", "video/webm") + + def test_signed_url_with_query_params(self): + """Test MIME detection for signed URLs with query parameters.""" + url = ( + "https://bucket.s3.amazonaws.com/audio/test.mp3" + "?X-Amz-Signature=abc123&X-Amz-Expires=3600" + ) + assert get_mime_from_url(url) == "audio/mpeg" + + def test_url_encoded_path(self): + """Test MIME detection for URL-encoded paths.""" + 
url = "https://bucket.s3.amazonaws.com/audio/test%20file.mp3" + assert get_mime_from_url(url) == "audio/mpeg" + + def test_unknown_extension(self): + """Test MIME detection returns None for unknown extensions.""" + url = "https://bucket.s3.amazonaws.com/file.unknown" + assert get_mime_from_url(url) is None + + def test_no_extension(self): + """Test MIME detection returns None for URLs without extension.""" + url = "https://bucket.s3.amazonaws.com/file" + assert get_mime_from_url(url) is None + + def test_csv_url(self): + """Test MIME detection for CSV files.""" + url = "https://bucket.s3.amazonaws.com/data/test.csv" + assert get_mime_from_url(url) == "text/csv" + + def test_json_url(self): + """Test MIME detection for JSON files.""" + url = "https://bucket.s3.amazonaws.com/data/test.json" + assert get_mime_from_url(url) == "application/json" + + +class TestUploadToObjectStore: + """Test cases for upload_to_object_store function.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock CloudStorage instance.""" + storage = MagicMock() + storage.put.return_value = "s3://bucket/test/file.txt" + return storage + + def test_successful_upload(self, mock_storage): + """Test successful file upload.""" + content = b"test content" + result = upload_to_object_store( + storage=mock_storage, + content=content, + filename="test.txt", + subdirectory="uploads", + content_type="text/plain", + ) + + assert result == "s3://bucket/test/file.txt" + mock_storage.put.assert_called_once() + + # Verify the UploadFile was created correctly + call_args = mock_storage.put.call_args + upload_file = call_args.kwargs["source"] + assert upload_file.filename == "test.txt" + + def test_upload_with_default_content_type(self, mock_storage): + """Test upload uses default content type.""" + content = b"binary data" + result = upload_to_object_store( + storage=mock_storage, + content=content, + filename="data.bin", + subdirectory="files", + ) + + assert result == "s3://bucket/test/file.txt" + mock_storage.put.assert_called_once() + + def test_upload_returns_none_on_cloud_storage_error(self, mock_storage): + """Test that CloudStorageError returns None gracefully.""" + from app.core.cloud.storage import CloudStorageError + + mock_storage.put.side_effect = CloudStorageError("Connection failed") + + result = upload_to_object_store( + storage=mock_storage, + content=b"data", + filename="test.txt", + subdirectory="uploads", + ) + + assert result is None + + def test_upload_returns_none_on_generic_error(self, mock_storage): + """Test that generic exceptions return None gracefully.""" + mock_storage.put.side_effect = Exception("Unexpected error") + + result = upload_to_object_store( + storage=mock_storage, + content=b"data", + filename="test.txt", + subdirectory="uploads", + ) + + assert result is None + + +class TestUploadCsvToObjectStore: + """Test cases for upload_csv_to_object_store function.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock CloudStorage instance.""" + storage = MagicMock() + storage.put.return_value = "s3://bucket/datasets/data.csv" + return storage + + def test_successful_csv_upload(self, mock_storage): + """Test successful CSV upload.""" + csv_content = b"col1,col2\nval1,val2" + result = upload_csv_to_object_store( + storage=mock_storage, + csv_content=csv_content, + filename="data.csv", + ) + + assert result == "s3://bucket/datasets/data.csv" + mock_storage.put.assert_called_once() + + def test_csv_upload_with_custom_subdirectory(self, mock_storage): + """Test CSV upload with custom 
subdirectory.""" + csv_content = b"col1,col2\nval1,val2" + result = upload_csv_to_object_store( + storage=mock_storage, + csv_content=csv_content, + filename="data.csv", + subdirectory="stt_datasets", + ) + + assert result is not None + call_args = mock_storage.put.call_args + file_path = call_args.kwargs["file_path"] + assert "stt_datasets" in str(file_path) + + +class TestUploadJsonlToObjectStore: + """Test cases for upload_jsonl_to_object_store function.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock CloudStorage instance.""" + storage = MagicMock() + storage.put.return_value = "s3://bucket/results/output.jsonl" + return storage + + def test_successful_jsonl_upload(self, mock_storage): + """Test successful JSONL upload.""" + results = [ + {"id": 1, "text": "result 1"}, + {"id": 2, "text": "result 2"}, + ] + result = upload_jsonl_to_object_store( + storage=mock_storage, + results=results, + filename="output.jsonl", + subdirectory="results", + ) + + assert result == "s3://bucket/results/output.jsonl" + mock_storage.put.assert_called_once() + + def test_empty_results_list(self, mock_storage): + """Test upload with empty results list.""" + results = [] + result = upload_jsonl_to_object_store( + storage=mock_storage, + results=results, + filename="empty.jsonl", + subdirectory="results", + ) + + assert result is not None + + def test_results_with_unicode(self, mock_storage): + """Test upload with unicode content.""" + results = [ + {"text": "Hello δΈ–η•Œ"}, + {"text": "Emoji 🎡"}, + ] + result = upload_jsonl_to_object_store( + storage=mock_storage, + results=results, + filename="unicode.jsonl", + subdirectory="results", + ) + + assert result is not None + + +class TestGenerateTimestampedFilename: + """Test cases for generate_timestamped_filename function.""" + + def test_default_csv_extension(self): + """Test that default extension is CSV.""" + filename = generate_timestamped_filename("dataset") + assert filename.endswith(".csv") + assert filename.startswith("dataset_") + + def test_custom_extension(self): + """Test custom file extension.""" + filename = generate_timestamped_filename("results", extension="jsonl") + assert filename.endswith(".jsonl") + assert filename.startswith("results_") + + def test_timestamp_format(self): + """Test that timestamp is in expected format.""" + filename = generate_timestamped_filename("test") + # Expected format: test_YYYYMMDD_HHMMSS.csv + parts = filename.split("_") + assert len(parts) == 3 + # Date part should be 8 digits + assert len(parts[1]) == 8 + assert parts[1].isdigit() + # Time part should be 6 digits + extension + time_part = parts[2].split(".")[0] + assert len(time_part) == 6 + assert time_part.isdigit() + + def test_with_special_characters_in_base_name(self): + """Test with special characters in base name.""" + filename = generate_timestamped_filename("my-dataset-v1") + assert filename.startswith("my-dataset-v1_") + assert filename.endswith(".csv") + + def test_unique_filenames(self): + """Test that consecutive calls produce different filenames.""" + import time + + filename1 = generate_timestamped_filename("test") + time.sleep(0.01) # Small delay to ensure different timestamp + filename2 = generate_timestamped_filename("test") + # They may be the same if called in the same second + # but the format should be correct + assert filename1.startswith("test_") + assert filename2.startswith("test_") diff --git a/backend/app/tests/services/stt_evaluations/__init__.py b/backend/app/tests/services/stt_evaluations/__init__.py new file 
mode 100644 index 000000000..a206798fe --- /dev/null +++ b/backend/app/tests/services/stt_evaluations/__init__.py @@ -0,0 +1 @@ +"""STT evaluation service tests.""" diff --git a/backend/app/tests/services/stt_evaluations/test_audio.py b/backend/app/tests/services/stt_evaluations/test_audio.py new file mode 100644 index 000000000..900f8caee --- /dev/null +++ b/backend/app/tests/services/stt_evaluations/test_audio.py @@ -0,0 +1,378 @@ +"""Test cases for STT audio validation and upload service.""" + +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import UploadFile + +from app.services.stt_evaluations.audio import ( + AudioValidationError, + get_extension_from_content_type, + get_extension_from_filename, + upload_audio_file, + validate_audio_file, + validate_s3_audio_url, +) +from app.services.stt_evaluations.constants import MAX_FILE_SIZE_BYTES + + +class TestGetExtensionFromFilename: + """Test cases for get_extension_from_filename function.""" + + def test_mp3_extension(self): + """Test extracting MP3 extension.""" + assert get_extension_from_filename("audio.mp3") == "mp3" + + def test_wav_extension(self): + """Test extracting WAV extension.""" + assert get_extension_from_filename("audio.wav") == "wav" + + def test_flac_extension(self): + """Test extracting FLAC extension.""" + assert get_extension_from_filename("audio.flac") == "flac" + + def test_m4a_extension(self): + """Test extracting M4A extension.""" + assert get_extension_from_filename("audio.m4a") == "m4a" + + def test_ogg_extension(self): + """Test extracting OGG extension.""" + assert get_extension_from_filename("audio.ogg") == "ogg" + + def test_webm_extension(self): + """Test extracting WEBM extension.""" + assert get_extension_from_filename("audio.webm") == "webm" + + def test_uppercase_extension(self): + """Test that uppercase extensions are normalized to lowercase.""" + assert get_extension_from_filename("audio.MP3") == "mp3" + assert get_extension_from_filename("audio.WAV") == "wav" + + def test_mixed_case_extension(self): + """Test mixed case extensions.""" + assert get_extension_from_filename("audio.Mp3") == "mp3" + + def test_empty_filename(self): + """Test empty filename returns None.""" + assert get_extension_from_filename("") is None + + def test_none_filename(self): + """Test None filename returns None.""" + assert get_extension_from_filename(None) is None + + def test_no_extension(self): + """Test filename without extension returns None.""" + assert get_extension_from_filename("audiofile") is None + + def test_multiple_dots(self): + """Test filename with multiple dots.""" + assert get_extension_from_filename("audio.backup.mp3") == "mp3" + + def test_hidden_file_with_extension(self): + """Test hidden file with extension.""" + assert get_extension_from_filename(".audio.mp3") == "mp3" + + def test_path_with_filename(self): + """Test full path with filename.""" + assert get_extension_from_filename("/path/to/audio.mp3") == "mp3" + + +class TestGetExtensionFromContentType: + """Test cases for get_extension_from_content_type function.""" + + def test_audio_mpeg(self): + """Test audio/mpeg content type.""" + assert get_extension_from_content_type("audio/mpeg") == "mp3" + + def test_audio_mp3(self): + """Test audio/mp3 content type.""" + assert get_extension_from_content_type("audio/mp3") == "mp3" + + def test_audio_wav(self): + """Test audio/wav content type.""" + assert get_extension_from_content_type("audio/wav") == "wav" + + def test_audio_x_wav(self): + """Test audio/x-wav content type.""" + 
assert get_extension_from_content_type("audio/x-wav") == "wav" + + def test_audio_wave(self): + """Test audio/wave content type.""" + assert get_extension_from_content_type("audio/wave") == "wav" + + def test_audio_flac(self): + """Test audio/flac content type.""" + assert get_extension_from_content_type("audio/flac") == "flac" + + def test_audio_mp4(self): + """Test audio/mp4 content type (m4a).""" + assert get_extension_from_content_type("audio/mp4") == "m4a" + + def test_audio_ogg(self): + """Test audio/ogg content type.""" + assert get_extension_from_content_type("audio/ogg") == "ogg" + + def test_audio_webm(self): + """Test audio/webm content type.""" + assert get_extension_from_content_type("audio/webm") == "webm" + + def test_uppercase_content_type(self): + """Test uppercase content type is normalized.""" + assert get_extension_from_content_type("AUDIO/MPEG") == "mp3" + + def test_empty_content_type(self): + """Test empty content type returns None.""" + assert get_extension_from_content_type("") is None + + def test_none_content_type(self): + """Test None content type returns None.""" + assert get_extension_from_content_type(None) is None + + def test_unknown_content_type(self): + """Test unknown content type returns None.""" + assert get_extension_from_content_type("application/octet-stream") is None + + +class TestValidateAudioFile: + """Test cases for validate_audio_file function.""" + + def _create_upload_file( + self, + filename: str | None = "test.mp3", + content_type: str | None = "audio/mpeg", + size: int | None = 1024, + ) -> UploadFile: + """Create a mock UploadFile for testing.""" + mock_file = MagicMock(spec=UploadFile) + mock_file.filename = filename + mock_file.content_type = content_type + mock_file.size = size + return mock_file + + def test_valid_mp3_file(self): + """Test validation of valid MP3 file.""" + file = self._create_upload_file(filename="test.mp3") + result = validate_audio_file(file) + assert result == "mp3" + + def test_valid_wav_file(self): + """Test validation of valid WAV file.""" + file = self._create_upload_file(filename="test.wav") + result = validate_audio_file(file) + assert result == "wav" + + def test_valid_flac_file(self): + """Test validation of valid FLAC file.""" + file = self._create_upload_file(filename="test.flac") + result = validate_audio_file(file) + assert result == "flac" + + def test_valid_m4a_file(self): + """Test validation of valid M4A file.""" + file = self._create_upload_file(filename="test.m4a") + result = validate_audio_file(file) + assert result == "m4a" + + def test_valid_ogg_file(self): + """Test validation of valid OGG file.""" + file = self._create_upload_file(filename="test.ogg") + result = validate_audio_file(file) + assert result == "ogg" + + def test_valid_webm_file(self): + """Test validation of valid WEBM file.""" + file = self._create_upload_file(filename="test.webm") + result = validate_audio_file(file) + assert result == "webm" + + def test_missing_filename(self): + """Test validation fails when filename is missing.""" + file = self._create_upload_file(filename=None) + with pytest.raises(AudioValidationError) as exc_info: + validate_audio_file(file) + assert "Filename is required" in str(exc_info.value) + + def test_empty_filename(self): + """Test validation fails when filename is empty.""" + file = self._create_upload_file(filename="") + with pytest.raises(AudioValidationError) as exc_info: + validate_audio_file(file) + assert "Filename is required" in str(exc_info.value) + + def test_unsupported_format(self): 
+ """Test validation fails for unsupported format.""" + file = self._create_upload_file(filename="test.txt") + with pytest.raises(AudioValidationError) as exc_info: + validate_audio_file(file) + assert "Unsupported audio format" in str(exc_info.value) + + def test_extension_from_content_type_fallback(self): + """Test fallback to content type when filename has no extension.""" + file = self._create_upload_file(filename="audiofile", content_type="audio/mpeg") + result = validate_audio_file(file) + assert result == "mp3" + + def test_file_too_large(self): + """Test validation fails when file is too large.""" + file = self._create_upload_file( + filename="test.mp3", + size=MAX_FILE_SIZE_BYTES + 1, + ) + with pytest.raises(AudioValidationError) as exc_info: + validate_audio_file(file) + assert "File too large" in str(exc_info.value) + + def test_file_at_max_size(self): + """Test validation passes when file is exactly at max size.""" + file = self._create_upload_file( + filename="test.mp3", + size=MAX_FILE_SIZE_BYTES, + ) + result = validate_audio_file(file) + assert result == "mp3" + + def test_file_with_no_size(self): + """Test validation passes when file size is not available.""" + file = self._create_upload_file(filename="test.mp3", size=None) + result = validate_audio_file(file) + assert result == "mp3" + + +class TestValidateS3AudioUrl: + """Test cases for validate_s3_audio_url function.""" + + def test_valid_s3_mp3_url(self): + """Test valid S3 MP3 URL.""" + url = "s3://bucket/audio/test.mp3" + assert validate_s3_audio_url(url) is True + + def test_valid_s3_wav_url(self): + """Test valid S3 WAV URL.""" + url = "s3://bucket/audio/test.wav" + assert validate_s3_audio_url(url) is True + + def test_valid_s3_flac_url(self): + """Test valid S3 FLAC URL.""" + url = "s3://bucket/audio/test.flac" + assert validate_s3_audio_url(url) is True + + def test_empty_url(self): + """Test empty URL returns False.""" + assert validate_s3_audio_url("") is False + + def test_none_url(self): + """Test None URL returns False.""" + assert validate_s3_audio_url(None) is False + + def test_https_url(self): + """Test HTTPS URL returns False.""" + url = "https://bucket.s3.amazonaws.com/audio/test.mp3" + assert validate_s3_audio_url(url) is False + + def test_s3_url_with_invalid_extension(self): + """Test S3 URL with invalid extension returns False.""" + url = "s3://bucket/file.txt" + assert validate_s3_audio_url(url) is False + + def test_s3_url_with_no_extension(self): + """Test S3 URL with no extension returns False.""" + url = "s3://bucket/audiofile" + assert validate_s3_audio_url(url) is False + + +class TestUploadAudioFile: + """Test cases for upload_audio_file function.""" + + def _create_upload_file( + self, + filename: str = "test.mp3", + content_type: str = "audio/mpeg", + size: int = 1024, + ) -> UploadFile: + """Create a mock UploadFile for testing.""" + mock_file = MagicMock(spec=UploadFile) + mock_file.filename = filename + mock_file.content_type = content_type + mock_file.size = size + return mock_file + + @patch("app.services.stt_evaluations.audio.get_cloud_storage") + def test_successful_upload(self, mock_get_storage): + """Test successful audio file upload.""" + mock_storage = MagicMock() + mock_storage.put.return_value = "s3://bucket/stt/audio/test.mp3" + mock_storage.get_file_size_kb.return_value = 1.0 + mock_get_storage.return_value = mock_storage + + mock_session = MagicMock() + file = self._create_upload_file() + + result = upload_audio_file( + session=mock_session, + file=file, + project_id=1, 
+ ) + + assert result.s3_url == "s3://bucket/stt/audio/test.mp3" + assert result.filename == "test.mp3" + assert result.size_bytes == 1024 + assert result.content_type == "audio/mpeg" + + @patch("app.services.stt_evaluations.audio.get_cloud_storage") + def test_upload_validation_error(self, mock_get_storage): + """Test upload fails on validation error.""" + from app.core.exception_handlers import HTTPException + + mock_session = MagicMock() + file = self._create_upload_file(filename="test.txt") + + with pytest.raises(HTTPException) as exc_info: + upload_audio_file( + session=mock_session, + file=file, + project_id=1, + ) + + assert exc_info.value.status_code == 400 + assert "Unsupported audio format" in str(exc_info.value.detail) + + @patch("app.services.stt_evaluations.audio.get_cloud_storage") + def test_upload_storage_error(self, mock_get_storage): + """Test upload handles storage errors.""" + from app.core.exception_handlers import HTTPException + + mock_storage = MagicMock() + mock_storage.put.side_effect = Exception("S3 connection failed") + mock_get_storage.return_value = mock_storage + + mock_session = MagicMock() + file = self._create_upload_file() + + with pytest.raises(HTTPException) as exc_info: + upload_audio_file( + session=mock_session, + file=file, + project_id=1, + ) + + assert exc_info.value.status_code == 500 + assert "Failed to upload audio file" in str(exc_info.value.detail) + + @patch("app.services.stt_evaluations.audio.get_cloud_storage") + def test_upload_uses_file_size_on_s3_error(self, mock_get_storage): + """Test upload uses file.size when S3 size retrieval fails.""" + mock_storage = MagicMock() + mock_storage.put.return_value = "s3://bucket/stt/audio/test.mp3" + mock_storage.get_file_size_kb.side_effect = Exception("Failed to get size") + mock_get_storage.return_value = mock_storage + + mock_session = MagicMock() + file = self._create_upload_file(size=2048) + + result = upload_audio_file( + session=mock_session, + file=file, + project_id=1, + ) + + assert result.size_bytes == 2048 diff --git a/backend/app/tests/services/stt_evaluations/test_dataset.py b/backend/app/tests/services/stt_evaluations/test_dataset.py new file mode 100644 index 000000000..87dcda18c --- /dev/null +++ b/backend/app/tests/services/stt_evaluations/test_dataset.py @@ -0,0 +1,344 @@ +"""Test cases for STT dataset management service.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from app.models.stt_evaluation import STTSampleCreate +from app.services.stt_evaluations.dataset import ( + _samples_to_csv, + _upload_samples_to_object_store, + upload_stt_dataset, +) + + +class TestSamplesToCSV: + """Test cases for _samples_to_csv function.""" + + def test_single_sample_without_ground_truth(self): + """Test CSV conversion with single sample without ground truth.""" + samples = [ + STTSampleCreate(object_store_url="s3://bucket/audio1.mp3"), + ] + result = _samples_to_csv(samples) + + # Decode and verify - handle both \n and \r\n line endings + csv_str = result.decode("utf-8").replace("\r\n", "\n").replace("\r", "\n") + lines = csv_str.strip().split("\n") + + assert len(lines) == 2 # Header + 1 sample + assert lines[0] == "object_store_url,ground_truth" + assert lines[1] == "s3://bucket/audio1.mp3," + + def test_single_sample_with_ground_truth(self): + """Test CSV conversion with single sample with ground truth.""" + samples = [ + STTSampleCreate( + object_store_url="s3://bucket/audio1.mp3", + ground_truth="Hello world", + ), + ] + result = _samples_to_csv(samples) + + csv_str = 
result.decode("utf-8") + lines = csv_str.strip().split("\n") + + assert len(lines) == 2 + assert "Hello world" in lines[1] + + def test_multiple_samples(self): + """Test CSV conversion with multiple samples.""" + samples = [ + STTSampleCreate( + object_store_url="s3://bucket/audio1.mp3", + ground_truth="First transcription", + ), + STTSampleCreate( + object_store_url="s3://bucket/audio2.mp3", + ground_truth="Second transcription", + ), + STTSampleCreate( + object_store_url="s3://bucket/audio3.mp3", + ), + ] + result = _samples_to_csv(samples) + + csv_str = result.decode("utf-8") + lines = csv_str.strip().split("\n") + + assert len(lines) == 4 # Header + 3 samples + + def test_empty_samples_list(self): + """Test CSV conversion with empty samples list.""" + samples = [] + result = _samples_to_csv(samples) + + csv_str = result.decode("utf-8") + lines = csv_str.strip().split("\n") + + # Should only have header + assert len(lines) == 1 + assert lines[0] == "object_store_url,ground_truth" + + def test_sample_with_unicode(self): + """Test CSV conversion with unicode characters.""" + samples = [ + STTSampleCreate( + object_store_url="s3://bucket/audio.mp3", + ground_truth="Hello δΈ–η•Œ 🌍", + ), + ] + result = _samples_to_csv(samples) + + csv_str = result.decode("utf-8") + assert "Hello δΈ–η•Œ 🌍" in csv_str + + def test_sample_with_comma_in_ground_truth(self): + """Test CSV conversion handles commas in ground truth.""" + samples = [ + STTSampleCreate( + object_store_url="s3://bucket/audio.mp3", + ground_truth="Hello, world", + ), + ] + result = _samples_to_csv(samples) + + csv_str = result.decode("utf-8") + # CSV should properly quote fields with commas + assert '"Hello, world"' in csv_str or "Hello, world" in csv_str + + def test_sample_with_quotes_in_ground_truth(self): + """Test CSV conversion handles quotes in ground truth.""" + samples = [ + STTSampleCreate( + object_store_url="s3://bucket/audio.mp3", + ground_truth='He said "hello"', + ), + ] + result = _samples_to_csv(samples) + + # Should produce valid CSV bytes + assert isinstance(result, bytes) + + +class TestUploadSamplesToObjectStore: + """Test cases for _upload_samples_to_object_store function.""" + + @patch("app.services.stt_evaluations.dataset.get_cloud_storage") + @patch("app.services.stt_evaluations.dataset.upload_csv_to_object_store") + def test_successful_upload(self, mock_upload_csv, mock_get_storage): + """Test successful upload to object store.""" + mock_storage = MagicMock() + mock_get_storage.return_value = mock_storage + mock_upload_csv.return_value = "s3://bucket/stt_datasets/dataset.csv" + + mock_session = MagicMock() + samples = [ + STTSampleCreate(object_store_url="s3://bucket/audio.mp3"), + ] + + result = _upload_samples_to_object_store( + session=mock_session, + project_id=1, + dataset_name="test_dataset", + samples=samples, + ) + + assert result == "s3://bucket/stt_datasets/dataset.csv" + mock_upload_csv.assert_called_once() + + @patch("app.services.stt_evaluations.dataset.get_cloud_storage") + @patch("app.services.stt_evaluations.dataset.upload_csv_to_object_store") + def test_upload_returns_none_on_failure(self, mock_upload_csv, mock_get_storage): + """Test upload returns None on failure.""" + mock_storage = MagicMock() + mock_get_storage.return_value = mock_storage + mock_upload_csv.return_value = None + + mock_session = MagicMock() + samples = [ + STTSampleCreate(object_store_url="s3://bucket/audio.mp3"), + ] + + result = _upload_samples_to_object_store( + session=mock_session, + project_id=1, + 
dataset_name="test_dataset", + samples=samples, + ) + + assert result is None + + @patch("app.services.stt_evaluations.dataset.get_cloud_storage") + def test_upload_handles_exception(self, mock_get_storage): + """Test upload handles exceptions gracefully.""" + mock_get_storage.side_effect = Exception("Storage connection failed") + + mock_session = MagicMock() + samples = [ + STTSampleCreate(object_store_url="s3://bucket/audio.mp3"), + ] + + result = _upload_samples_to_object_store( + session=mock_session, + project_id=1, + dataset_name="test_dataset", + samples=samples, + ) + + assert result is None + + +class TestUploadSTTDataset: + """Test cases for upload_stt_dataset function.""" + + @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store") + @patch("app.services.stt_evaluations.dataset.create_stt_samples") + @patch("app.services.stt_evaluations.dataset.create_stt_dataset") + def test_successful_upload( + self, mock_create_dataset, mock_create_samples, mock_upload_samples + ): + """Test successful dataset upload orchestration.""" + # Setup mocks + mock_upload_samples.return_value = "s3://bucket/stt_datasets/test.csv" + + mock_dataset = MagicMock() + mock_dataset.id = 1 + mock_create_dataset.return_value = mock_dataset + + mock_sample = MagicMock() + mock_create_samples.return_value = [mock_sample] + + mock_session = MagicMock() + samples = [ + STTSampleCreate( + object_store_url="s3://bucket/audio.mp3", + ground_truth="Test transcription", + ), + ] + + dataset, created_samples = upload_stt_dataset( + session=mock_session, + name="test_dataset", + samples=samples, + organization_id=1, + project_id=1, + ) + + assert dataset == mock_dataset + assert len(created_samples) == 1 + + # Verify create_stt_dataset was called with correct metadata + mock_create_dataset.assert_called_once() + call_kwargs = mock_create_dataset.call_args.kwargs + assert call_kwargs["name"] == "test_dataset" + assert call_kwargs["dataset_metadata"]["sample_count"] == 1 + assert call_kwargs["dataset_metadata"]["has_ground_truth_count"] == 1 + + @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store") + @patch("app.services.stt_evaluations.dataset.create_stt_samples") + @patch("app.services.stt_evaluations.dataset.create_stt_dataset") + def test_upload_with_description_and_language( + self, mock_create_dataset, mock_create_samples, mock_upload_samples + ): + """Test upload with optional description and language.""" + mock_upload_samples.return_value = None + mock_dataset = MagicMock() + mock_dataset.id = 1 + mock_create_dataset.return_value = mock_dataset + mock_create_samples.return_value = [] + + mock_session = MagicMock() + samples = [ + STTSampleCreate(object_store_url="s3://bucket/audio.mp3"), + ] + + upload_stt_dataset( + session=mock_session, + name="test_dataset", + samples=samples, + organization_id=1, + project_id=1, + description="Test description", + language="en", + ) + + call_kwargs = mock_create_dataset.call_args.kwargs + assert call_kwargs["description"] == "Test description" + assert call_kwargs["language"] == "en" + + @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store") + @patch("app.services.stt_evaluations.dataset.create_stt_samples") + @patch("app.services.stt_evaluations.dataset.create_stt_dataset") + def test_upload_counts_ground_truth_correctly( + self, mock_create_dataset, mock_create_samples, mock_upload_samples + ): + """Test that ground truth count is calculated correctly.""" + mock_upload_samples.return_value = None + 
mock_dataset = MagicMock() + mock_dataset.id = 1 + mock_create_dataset.return_value = mock_dataset + mock_create_samples.return_value = [] + + mock_session = MagicMock() + samples = [ + STTSampleCreate( + object_store_url="s3://bucket/audio1.mp3", + ground_truth="Has ground truth", + ), + STTSampleCreate( + object_store_url="s3://bucket/audio2.mp3", + ), + STTSampleCreate( + object_store_url="s3://bucket/audio3.mp3", + ground_truth="Also has ground truth", + ), + STTSampleCreate( + object_store_url="s3://bucket/audio4.mp3", + ground_truth="", # Empty string should not count + ), + ] + + upload_stt_dataset( + session=mock_session, + name="test_dataset", + samples=samples, + organization_id=1, + project_id=1, + ) + + call_kwargs = mock_create_dataset.call_args.kwargs + assert call_kwargs["dataset_metadata"]["sample_count"] == 4 + assert call_kwargs["dataset_metadata"]["has_ground_truth_count"] == 2 + + @patch("app.services.stt_evaluations.dataset._upload_samples_to_object_store") + @patch("app.services.stt_evaluations.dataset.create_stt_samples") + @patch("app.services.stt_evaluations.dataset.create_stt_dataset") + def test_upload_continues_without_object_store_url( + self, mock_create_dataset, mock_create_samples, mock_upload_samples + ): + """Test that upload continues even when object store upload fails.""" + mock_upload_samples.return_value = None # Simulates failed upload + + mock_dataset = MagicMock() + mock_dataset.id = 1 + mock_create_dataset.return_value = mock_dataset + mock_create_samples.return_value = [] + + mock_session = MagicMock() + samples = [ + STTSampleCreate(object_store_url="s3://bucket/audio.mp3"), + ] + + dataset, created_samples = upload_stt_dataset( + session=mock_session, + name="test_dataset", + samples=samples, + organization_id=1, + project_id=1, + ) + + # Should still create the dataset + assert dataset is not None + call_kwargs = mock_create_dataset.call_args.kwargs + assert call_kwargs["object_store_url"] is None diff --git a/backend/app/tests/services/stt_evaluations/test_gemini_client.py b/backend/app/tests/services/stt_evaluations/test_gemini_client.py new file mode 100644 index 000000000..36ef02fbb --- /dev/null +++ b/backend/app/tests/services/stt_evaluations/test_gemini_client.py @@ -0,0 +1,182 @@ +"""Test cases for Gemini client wrapper.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.stt_evaluations.gemini.client import ( + GeminiClient, + GeminiClientError, +) + + +class TestGeminiClientInit: + """Test cases for GeminiClient initialization.""" + + @patch("app.services.stt_evaluations.gemini.client.genai.Client") + def test_initialization_with_api_key(self, mock_genai_client): + """Test client initialization with API key.""" + mock_client_instance = MagicMock() + mock_genai_client.return_value = mock_client_instance + + client = GeminiClient(api_key="test-api-key") + + mock_genai_client.assert_called_once_with(api_key="test-api-key") + assert client._api_key == "test-api-key" + + @patch("app.services.stt_evaluations.gemini.client.genai.Client") + def test_client_property(self, mock_genai_client): + """Test client property returns underlying client.""" + mock_client_instance = MagicMock() + mock_genai_client.return_value = mock_client_instance + + client = GeminiClient(api_key="test-api-key") + + assert client.client == mock_client_instance + + +class TestGeminiClientFromCredentials: + """Test cases for GeminiClient.from_credentials class method.""" + + 
@patch("app.services.stt_evaluations.gemini.client.genai.Client") + @patch("app.services.stt_evaluations.gemini.client.get_provider_credential") + def test_successful_creation(self, mock_get_creds, mock_genai_client): + """Test successful client creation from credentials.""" + mock_get_creds.return_value = {"api_key": "stored-api-key"} + mock_client_instance = MagicMock() + mock_genai_client.return_value = mock_client_instance + + mock_session = MagicMock() + + client = GeminiClient.from_credentials( + session=mock_session, + org_id=1, + project_id=2, + ) + + mock_get_creds.assert_called_once_with( + session=mock_session, + org_id=1, + project_id=2, + provider="gemini", + ) + assert client._api_key == "stored-api-key" + + @patch("app.services.stt_evaluations.gemini.client.get_provider_credential") + def test_credentials_not_found(self, mock_get_creds): + """Test error when credentials are not found.""" + from app.core.exception_handlers import HTTPException + + mock_get_creds.return_value = None + mock_session = MagicMock() + + with pytest.raises(HTTPException) as exc_info: + GeminiClient.from_credentials( + session=mock_session, + org_id=1, + project_id=2, + ) + + assert exc_info.value.status_code == 404 + assert "credentials not configured" in str(exc_info.value.detail) + + @patch("app.services.stt_evaluations.gemini.client.get_provider_credential") + def test_missing_api_key(self, mock_get_creds): + """Test error when credentials exist but api_key is missing.""" + mock_get_creds.return_value = {"other_field": "value"} # No api_key + mock_session = MagicMock() + + with pytest.raises(GeminiClientError) as exc_info: + GeminiClient.from_credentials( + session=mock_session, + org_id=1, + project_id=2, + ) + + assert "missing api_key" in str(exc_info.value) + + @patch("app.services.stt_evaluations.gemini.client.get_provider_credential") + def test_empty_api_key(self, mock_get_creds): + """Test error when api_key is empty.""" + mock_get_creds.return_value = {"api_key": ""} # Empty api_key + mock_session = MagicMock() + + with pytest.raises(GeminiClientError) as exc_info: + GeminiClient.from_credentials( + session=mock_session, + org_id=1, + project_id=2, + ) + + assert "missing api_key" in str(exc_info.value) + + +class TestGeminiClientValidateConnection: + """Test cases for GeminiClient.validate_connection method.""" + + @patch("app.services.stt_evaluations.gemini.client.genai.Client") + def test_successful_validation(self, mock_genai_client): + """Test successful connection validation.""" + mock_client_instance = MagicMock() + mock_client_instance.models.list.return_value = [ + MagicMock(), + MagicMock(), + ] + mock_genai_client.return_value = mock_client_instance + + client = GeminiClient(api_key="valid-api-key") + result = client.validate_connection() + + assert result is True + mock_client_instance.models.list.assert_called_once() + + @patch("app.services.stt_evaluations.gemini.client.genai.Client") + def test_validation_with_api_error(self, mock_genai_client): + """Test validation returns False on API error.""" + mock_client_instance = MagicMock() + mock_client_instance.models.list.side_effect = Exception("API error") + mock_genai_client.return_value = mock_client_instance + + client = GeminiClient(api_key="invalid-api-key") + result = client.validate_connection() + + assert result is False + + @patch("app.services.stt_evaluations.gemini.client.genai.Client") + def test_validation_with_auth_error(self, mock_genai_client): + """Test validation returns False on authentication error.""" + 
mock_client_instance = MagicMock() + mock_client_instance.models.list.side_effect = Exception("Invalid API key") + mock_genai_client.return_value = mock_client_instance + + client = GeminiClient(api_key="wrong-api-key") + result = client.validate_connection() + + assert result is False + + @patch("app.services.stt_evaluations.gemini.client.genai.Client") + def test_validation_with_empty_models_list(self, mock_genai_client): + """Test validation succeeds with empty models list.""" + mock_client_instance = MagicMock() + mock_client_instance.models.list.return_value = [] + mock_genai_client.return_value = mock_client_instance + + client = GeminiClient(api_key="valid-api-key") + result = client.validate_connection() + + # Empty list is still a valid response + assert result is True + + +class TestGeminiClientError: + """Test cases for GeminiClientError exception.""" + + def test_error_message(self): + """Test error message is preserved.""" + error = GeminiClientError("Test error message") + assert str(error) == "Test error message" + + def test_error_inheritance(self): + """Test error inherits from Exception.""" + error = GeminiClientError("Test") + assert isinstance(error, Exception)