Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: tests

on:
push:
branches: ["main", "feat/**", "fix/**", "chore/**"]
pull_request:
branches: ["main"]

jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: install test dependencies
run: pip install -r requirements-dev.txt

- name: run tests
run: python -m pytest tests/ -v --tb=short
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest>=7.4.0
Empty file added src/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions src/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Evaluation module for the CC suggestion pipeline (Goal 5)."""

from src.eval.evaluator import EvalReport, GroundTruthEvent, DetectedEvent, evaluate

__all__ = ["evaluate", "EvalReport", "GroundTruthEvent", "DetectedEvent"]
233 changes: 233 additions & 0 deletions src/eval/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""Evaluation framework for CC suggestion pipelines (Goal 5).

Computes precision, recall, F1, and overcaption rate by comparing
a pipeline's detected events against a hand-annotated ground truth.
Works with any pipeline — only requires label and timestamp fields.

No ML dependencies. All matching is purely temporal (IoU-based).
"""

from __future__ import annotations

import dataclasses
import json
from pathlib import Path


# ---------------------------------------------------------------------------
# Data types
# ---------------------------------------------------------------------------

@dataclasses.dataclass
class DetectedEvent:
"""A single event emitted by a detection pipeline."""

label: str
start_s: float
end_s: float
confidence: float = 1.0


@dataclasses.dataclass
class GroundTruthEvent:
"""A hand-annotated non-speech audio event."""

label: str
start_s: float
end_s: float


@dataclasses.dataclass
class LabelMetrics:
"""Per-label precision / recall / F1 breakdown."""

true_positives: int
false_positives: int
false_negatives: int
precision: float
recall: float
f1: float


@dataclasses.dataclass
class EvalReport:
"""Aggregate evaluation results for one pipeline run."""

true_positives: int
false_positives: int
false_negatives: int
precision: float
recall: float
f1: float
# FP / total_detected: how often the pipeline fires without a real event
overcaption_rate: float
total_detected: int
total_ground_truth: int
per_label: dict[str, LabelMetrics]

def to_dict(self) -> dict:
return {
"precision": round(self.precision, 4),
"recall": round(self.recall, 4),
"f1": round(self.f1, 4),
"overcaption_rate": round(self.overcaption_rate, 4),
"true_positives": self.true_positives,
"false_positives": self.false_positives,
"false_negatives": self.false_negatives,
"total_detected": self.total_detected,
"total_ground_truth": self.total_ground_truth,
"per_label": {
label: {
"precision": round(m.precision, 4),
"recall": round(m.recall, 4),
"f1": round(m.f1, 4),
"true_positives": m.true_positives,
"false_positives": m.false_positives,
"false_negatives": m.false_negatives,
}
for label, m in self.per_label.items()
},
}


# ---------------------------------------------------------------------------
# Public helpers
# ---------------------------------------------------------------------------

def load_ground_truth(path: str) -> list[GroundTruthEvent]:
"""Load ground truth annotations from a JSON file.

Expected format::

[
{"label": "[alarm]", "start_s": 4.32, "end_s": 6.72},
{"label": "[gunshot]", "start_s": 12.00, "end_s": 12.96}
]

Parameters
----------
path:
Path to the annotation JSON file.
"""
raw = json.loads(Path(path).read_text(encoding="utf-8"))
return [
GroundTruthEvent(label=r["label"], start_s=float(r["start_s"]), end_s=float(r["end_s"]))
for r in raw
]


def evaluate(
detected: list[DetectedEvent],
ground_truth: list[GroundTruthEvent],
iou_threshold: float = 0.5,
) -> EvalReport:
"""Compare *detected* events against *ground_truth* and return metrics.

Matching is label-aware and greedy: for each ground truth event, the
highest-IoU detection of the same label (above *iou_threshold*) is
consumed as a true positive. Unmatched detections are false positives;
unmatched ground truth events are false negatives.

Parameters
----------
detected:
Events produced by the pipeline under evaluation.
ground_truth:
Hand-annotated reference events for the same clip.
iou_threshold:
Minimum temporal IoU required to count as a match. Default 0.5.

Returns
-------
EvalReport
Aggregate and per-label metrics.
"""
all_labels = {ev.label for ev in detected} | {ev.label for ev in ground_truth}
per_label: dict[str, LabelMetrics] = {}

total_tp = total_fp = total_fn = 0

for label in sorted(all_labels):
det = [e for e in detected if e.label == label]
gt = [e for e in ground_truth if e.label == label]

tp, matched_det = _match(det, gt, iou_threshold)
fp = len(det) - tp
fn = len(gt) - tp

per_label[label] = LabelMetrics(
true_positives=tp,
false_positives=fp,
false_negatives=fn,
precision=_safe_div(tp, tp + fp),
recall=_safe_div(tp, tp + fn),
f1=_f1(_safe_div(tp, tp + fp), _safe_div(tp, tp + fn)),
)
total_tp += tp
total_fp += fp
total_fn += fn

total_det = len(detected)
precision = _safe_div(total_tp, total_tp + total_fp)
recall = _safe_div(total_tp, total_tp + total_fn)

return EvalReport(
true_positives=total_tp,
false_positives=total_fp,
false_negatives=total_fn,
precision=precision,
recall=recall,
f1=_f1(precision, recall),
overcaption_rate=_safe_div(total_fp, total_det),
total_detected=total_det,
total_ground_truth=len(ground_truth),
per_label=per_label,
)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _iou(a: DetectedEvent | GroundTruthEvent, b: DetectedEvent | GroundTruthEvent) -> float:
"""Temporal intersection-over-union for two events."""
inter = max(0.0, min(a.end_s, b.end_s) - max(a.start_s, b.start_s))
union = max(a.end_s, b.end_s) - min(a.start_s, b.start_s)
return inter / union if union > 0 else 0.0


def _match(
det: list[DetectedEvent],
gt: list[GroundTruthEvent],
threshold: float,
) -> tuple[int, set[int]]:
"""Greedy IoU matching. Returns (true_positive_count, matched_det_indices)."""
used_det: set[int] = set()
used_gt: set[int] = set()
tp = 0

# Build all (iou, det_idx, gt_idx) pairs above threshold, sort descending
pairs: list[tuple[float, int, int]] = []
for di, d in enumerate(det):
for gi, g in enumerate(gt):
score = _iou(d, g)
if score >= threshold:
pairs.append((score, di, gi))
pairs.sort(key=lambda x: x[0], reverse=True)

for score, di, gi in pairs:
if di in used_det or gi in used_gt:
continue
used_det.add(di)
used_gt.add(gi)
tp += 1

return tp, used_det


def _safe_div(num: float, denom: float) -> float:
return num / denom if denom > 0 else 0.0


def _f1(precision: float, recall: float) -> float:
return _safe_div(2 * precision * recall, precision + recall)
Empty file added tests/__init__.py
Empty file.
Loading