diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..8d7a250 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,27 @@ +name: tests + +on: + push: + branches: ["main", "feat/**", "fix/**", "chore/**"] + pull_request: + branches: ["main"] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: install test dependencies + run: pip install -r requirements-dev.txt + + - name: run tests + run: python -m pytest tests/ -v --tb=short diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..5b1f97b --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1 @@ +pytest>=7.4.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/eval/__init__.py b/src/eval/__init__.py new file mode 100644 index 0000000..5c67fd9 --- /dev/null +++ b/src/eval/__init__.py @@ -0,0 +1,5 @@ +"""Evaluation module for the CC suggestion pipeline (Goal 5).""" + +from src.eval.evaluator import EvalReport, GroundTruthEvent, DetectedEvent, evaluate + +__all__ = ["evaluate", "EvalReport", "GroundTruthEvent", "DetectedEvent"] diff --git a/src/eval/evaluator.py b/src/eval/evaluator.py new file mode 100644 index 0000000..e7da977 --- /dev/null +++ b/src/eval/evaluator.py @@ -0,0 +1,233 @@ +"""Evaluation framework for CC suggestion pipelines (Goal 5). + +Computes precision, recall, F1, and overcaption rate by comparing +a pipeline's detected events against a hand-annotated ground truth. +Works with any pipeline — only requires label and timestamp fields. + +No ML dependencies. All matching is purely temporal (IoU-based). +""" + +from __future__ import annotations + +import dataclasses +import json +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + +@dataclasses.dataclass +class DetectedEvent: + """A single event emitted by a detection pipeline.""" + + label: str + start_s: float + end_s: float + confidence: float = 1.0 + + +@dataclasses.dataclass +class GroundTruthEvent: + """A hand-annotated non-speech audio event.""" + + label: str + start_s: float + end_s: float + + +@dataclasses.dataclass +class LabelMetrics: + """Per-label precision / recall / F1 breakdown.""" + + true_positives: int + false_positives: int + false_negatives: int + precision: float + recall: float + f1: float + + +@dataclasses.dataclass +class EvalReport: + """Aggregate evaluation results for one pipeline run.""" + + true_positives: int + false_positives: int + false_negatives: int + precision: float + recall: float + f1: float + # FP / total_detected: how often the pipeline fires without a real event + overcaption_rate: float + total_detected: int + total_ground_truth: int + per_label: dict[str, LabelMetrics] + + def to_dict(self) -> dict: + return { + "precision": round(self.precision, 4), + "recall": round(self.recall, 4), + "f1": round(self.f1, 4), + "overcaption_rate": round(self.overcaption_rate, 4), + "true_positives": self.true_positives, + "false_positives": self.false_positives, + "false_negatives": self.false_negatives, + "total_detected": self.total_detected, + "total_ground_truth": self.total_ground_truth, + "per_label": { + label: { + "precision": round(m.precision, 4), + "recall": round(m.recall, 4), + "f1": round(m.f1, 4), + "true_positives": m.true_positives, + "false_positives": m.false_positives, + "false_negatives": m.false_negatives, + } + for label, m in self.per_label.items() + }, + } + + +# --------------------------------------------------------------------------- +# Public helpers +# --------------------------------------------------------------------------- + +def load_ground_truth(path: str) -> list[GroundTruthEvent]: + """Load ground truth annotations from a JSON file. + + Expected format:: + + [ + {"label": "[alarm]", "start_s": 4.32, "end_s": 6.72}, + {"label": "[gunshot]", "start_s": 12.00, "end_s": 12.96} + ] + + Parameters + ---------- + path: + Path to the annotation JSON file. + """ + raw = json.loads(Path(path).read_text(encoding="utf-8")) + return [ + GroundTruthEvent(label=r["label"], start_s=float(r["start_s"]), end_s=float(r["end_s"])) + for r in raw + ] + + +def evaluate( + detected: list[DetectedEvent], + ground_truth: list[GroundTruthEvent], + iou_threshold: float = 0.5, +) -> EvalReport: + """Compare *detected* events against *ground_truth* and return metrics. + + Matching is label-aware and greedy: for each ground truth event, the + highest-IoU detection of the same label (above *iou_threshold*) is + consumed as a true positive. Unmatched detections are false positives; + unmatched ground truth events are false negatives. + + Parameters + ---------- + detected: + Events produced by the pipeline under evaluation. + ground_truth: + Hand-annotated reference events for the same clip. + iou_threshold: + Minimum temporal IoU required to count as a match. Default 0.5. + + Returns + ------- + EvalReport + Aggregate and per-label metrics. + """ + all_labels = {ev.label for ev in detected} | {ev.label for ev in ground_truth} + per_label: dict[str, LabelMetrics] = {} + + total_tp = total_fp = total_fn = 0 + + for label in sorted(all_labels): + det = [e for e in detected if e.label == label] + gt = [e for e in ground_truth if e.label == label] + + tp, matched_det = _match(det, gt, iou_threshold) + fp = len(det) - tp + fn = len(gt) - tp + + per_label[label] = LabelMetrics( + true_positives=tp, + false_positives=fp, + false_negatives=fn, + precision=_safe_div(tp, tp + fp), + recall=_safe_div(tp, tp + fn), + f1=_f1(_safe_div(tp, tp + fp), _safe_div(tp, tp + fn)), + ) + total_tp += tp + total_fp += fp + total_fn += fn + + total_det = len(detected) + precision = _safe_div(total_tp, total_tp + total_fp) + recall = _safe_div(total_tp, total_tp + total_fn) + + return EvalReport( + true_positives=total_tp, + false_positives=total_fp, + false_negatives=total_fn, + precision=precision, + recall=recall, + f1=_f1(precision, recall), + overcaption_rate=_safe_div(total_fp, total_det), + total_detected=total_det, + total_ground_truth=len(ground_truth), + per_label=per_label, + ) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _iou(a: DetectedEvent | GroundTruthEvent, b: DetectedEvent | GroundTruthEvent) -> float: + """Temporal intersection-over-union for two events.""" + inter = max(0.0, min(a.end_s, b.end_s) - max(a.start_s, b.start_s)) + union = max(a.end_s, b.end_s) - min(a.start_s, b.start_s) + return inter / union if union > 0 else 0.0 + + +def _match( + det: list[DetectedEvent], + gt: list[GroundTruthEvent], + threshold: float, +) -> tuple[int, set[int]]: + """Greedy IoU matching. Returns (true_positive_count, matched_det_indices).""" + used_det: set[int] = set() + used_gt: set[int] = set() + tp = 0 + + # Build all (iou, det_idx, gt_idx) pairs above threshold, sort descending + pairs: list[tuple[float, int, int]] = [] + for di, d in enumerate(det): + for gi, g in enumerate(gt): + score = _iou(d, g) + if score >= threshold: + pairs.append((score, di, gi)) + pairs.sort(key=lambda x: x[0], reverse=True) + + for score, di, gi in pairs: + if di in used_det or gi in used_gt: + continue + used_det.add(di) + used_gt.add(gi) + tp += 1 + + return tp, used_det + + +def _safe_div(num: float, denom: float) -> float: + return num / denom if denom > 0 else 0.0 + + +def _f1(precision: float, recall: float) -> float: + return _safe_div(2 * precision * recall, precision + recall) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_evaluator.py b/tests/test_evaluator.py new file mode 100644 index 0000000..8e9abd2 --- /dev/null +++ b/tests/test_evaluator.py @@ -0,0 +1,293 @@ +"""Tests for src.eval.evaluator — no ML dependencies required.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from src.eval.evaluator import ( + DetectedEvent, + EvalReport, + GroundTruthEvent, + LabelMetrics, + _iou, + _match, + evaluate, + load_ground_truth, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _det(label: str, start: float, end: float, conf: float = 0.9) -> DetectedEvent: + return DetectedEvent(label=label, start_s=start, end_s=end, confidence=conf) + + +def _gt(label: str, start: float, end: float) -> GroundTruthEvent: + return GroundTruthEvent(label=label, start_s=start, end_s=end) + + +# --------------------------------------------------------------------------- +# _iou +# --------------------------------------------------------------------------- + +class TestIou: + def test_perfect_overlap(self): + a = _det("[alarm]", 0.0, 2.0) + b = _gt("[alarm]", 0.0, 2.0) + assert _iou(a, b) == pytest.approx(1.0) + + def test_no_overlap(self): + a = _det("[alarm]", 0.0, 1.0) + b = _gt("[alarm]", 2.0, 3.0) + assert _iou(a, b) == pytest.approx(0.0) + + def test_partial_overlap(self): + # a=[0,2), b=[1,3) → inter=1, union=3 + a = _det("[alarm]", 0.0, 2.0) + b = _gt("[alarm]", 1.0, 3.0) + assert _iou(a, b) == pytest.approx(1.0 / 3.0) + + def test_contained_event(self): + # b fully inside a → inter=1, union=3 + a = _det("[alarm]", 0.0, 3.0) + b = _gt("[alarm]", 1.0, 2.0) + assert _iou(a, b) == pytest.approx(1.0 / 3.0) + + def test_touching_boundary(self): + # a ends exactly where b starts → inter=0 + a = _det("[alarm]", 0.0, 1.0) + b = _gt("[alarm]", 1.0, 2.0) + assert _iou(a, b) == pytest.approx(0.0) + + +# --------------------------------------------------------------------------- +# _match +# --------------------------------------------------------------------------- + +class TestMatch: + def test_perfect_single_match(self): + tp, _ = _match([_det("[alarm]", 0.0, 2.0)], [_gt("[alarm]", 0.0, 2.0)], 0.5) + assert tp == 1 + + def test_no_match_below_threshold(self): + # IoU = 1/3 ≈ 0.33 < 0.5 + tp, _ = _match([_det("[alarm]", 0.0, 2.0)], [_gt("[alarm]", 1.0, 3.0)], 0.5) + assert tp == 0 + + def test_match_above_threshold(self): + # IoU = 0.75 ≥ 0.5 + tp, _ = _match([_det("[alarm]", 0.0, 4.0)], [_gt("[alarm]", 1.0, 4.0)], 0.5) + assert tp == 1 + + def test_greedy_avoids_double_matching(self): + # Two detections competing for one GT: only one match allowed + det = [_det("[alarm]", 0.0, 2.0), _det("[alarm]", 0.0, 2.0)] + gt = [_gt("[alarm]", 0.0, 2.0)] + tp, _ = _match(det, gt, 0.5) + assert tp == 1 + + def test_two_distinct_events_both_match(self): + det = [_det("[alarm]", 0.0, 2.0), _det("[gunshot]", 5.0, 6.0)] + # _match is label-filtered upstream, so pass same-label only + det_alarm = [_det("[alarm]", 0.0, 2.0)] + gt_alarm = [_gt("[alarm]", 0.0, 2.0)] + tp, _ = _match(det_alarm, gt_alarm, 0.5) + assert tp == 1 + + +# --------------------------------------------------------------------------- +# evaluate — basic correctness +# --------------------------------------------------------------------------- + +class TestEvaluate: + def test_perfect_detection(self): + detected = [_det("[alarm]", 0.0, 2.0), _det("[gunshot]", 5.0, 6.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0), _gt("[gunshot]", 5.0, 6.0)] + report = evaluate(detected, ground_truth) + assert report.precision == pytest.approx(1.0) + assert report.recall == pytest.approx(1.0) + assert report.f1 == pytest.approx(1.0) + assert report.true_positives == 2 + assert report.false_positives == 0 + assert report.false_negatives == 0 + + def test_all_false_positives(self): + detected = [_det("[alarm]", 0.0, 1.0), _det("[alarm]", 5.0, 6.0)] + ground_truth = [_gt("[alarm]", 20.0, 22.0)] + report = evaluate(detected, ground_truth) + assert report.precision == pytest.approx(0.0) + assert report.recall == pytest.approx(0.0) + assert report.false_positives == 2 + assert report.false_negatives == 1 + + def test_all_false_negatives(self): + ground_truth = [_gt("[alarm]", 0.0, 2.0), _gt("[gunshot]", 5.0, 6.0)] + report = evaluate([], ground_truth) + assert report.precision == pytest.approx(0.0) + assert report.recall == pytest.approx(0.0) + assert report.false_negatives == 2 + assert report.false_positives == 0 + + def test_no_events_at_all(self): + report = evaluate([], []) + assert report.precision == pytest.approx(0.0) + assert report.recall == pytest.approx(0.0) + assert report.f1 == pytest.approx(0.0) + assert report.total_detected == 0 + assert report.total_ground_truth == 0 + + def test_partial_recall(self): + # Two GT events, only one detected + detected = [_det("[alarm]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0), _gt("[alarm]", 10.0, 12.0)] + report = evaluate(detected, ground_truth) + assert report.recall == pytest.approx(0.5) + assert report.precision == pytest.approx(1.0) + + def test_overcaption_rate(self): + # 2 detections, 1 TP and 1 FP → overcaption = 0.5 + detected = [_det("[alarm]", 0.0, 2.0), _det("[alarm]", 50.0, 52.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0)] + report = evaluate(detected, ground_truth) + assert report.overcaption_rate == pytest.approx(0.5) + + def test_overcaption_zero_when_no_detections(self): + report = evaluate([], [_gt("[alarm]", 0.0, 1.0)]) + assert report.overcaption_rate == pytest.approx(0.0) + + def test_f1_is_harmonic_mean(self): + detected = [_det("[alarm]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0), _gt("[alarm]", 10.0, 12.0)] + report = evaluate(detected, ground_truth) + expected_f1 = 2 * report.precision * report.recall / (report.precision + report.recall) + assert report.f1 == pytest.approx(expected_f1, abs=1e-6) + + +# --------------------------------------------------------------------------- +# evaluate — per-label breakdown +# --------------------------------------------------------------------------- + +class TestPerLabelMetrics: + def test_per_label_keys_match_all_labels(self): + detected = [_det("[alarm]", 0.0, 2.0), _det("[gunshot]", 5.0, 6.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0)] + report = evaluate(detected, ground_truth) + assert "[alarm]" in report.per_label + assert "[gunshot]" in report.per_label + + def test_per_label_perfect_match(self): + detected = [_det("[alarm]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0)] + report = evaluate(detected, ground_truth) + m = report.per_label["[alarm]"] + assert m.precision == pytest.approx(1.0) + assert m.recall == pytest.approx(1.0) + assert m.f1 == pytest.approx(1.0) + + def test_per_label_fp_only(self): + detected = [_det("[music]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0)] + report = evaluate(detected, ground_truth) + assert report.per_label["[music]"].false_positives == 1 + assert report.per_label["[alarm]"].false_negatives == 1 + + def test_india_label_tracked_correctly(self): + detected = [_det("[firecrackers]", 3.0, 4.5)] + ground_truth = [_gt("[firecrackers]", 3.0, 4.5)] + report = evaluate(detected, ground_truth) + assert "[firecrackers]" in report.per_label + assert report.per_label["[firecrackers]"].true_positives == 1 + + def test_iou_threshold_respected_per_label(self): + # IoU = 1/3 < 0.5 → no match + detected = [_det("[alarm]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 1.0, 3.0)] + report = evaluate(detected, ground_truth, iou_threshold=0.5) + assert report.per_label["[alarm]"].true_positives == 0 + + +# --------------------------------------------------------------------------- +# evaluate — custom IoU threshold +# --------------------------------------------------------------------------- + +class TestIouThreshold: + def test_loose_threshold_accepts_partial_overlap(self): + # IoU = 1/3 ≈ 0.33 — passes at 0.3, fails at 0.5 + detected = [_det("[alarm]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 1.0, 3.0)] + strict = evaluate(detected, ground_truth, iou_threshold=0.5) + loose = evaluate(detected, ground_truth, iou_threshold=0.3) + assert strict.true_positives == 0 + assert loose.true_positives == 1 + + def test_exact_threshold_boundary(self): + # Perfect overlap → IoU = 1.0, should always match regardless of threshold + detected = [_det("[alarm]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0)] + report = evaluate(detected, ground_truth, iou_threshold=0.99) + assert report.true_positives == 1 + + +# --------------------------------------------------------------------------- +# EvalReport.to_dict +# --------------------------------------------------------------------------- + +class TestEvalReportToDict: + def test_keys_present(self): + report = evaluate([], []) + d = report.to_dict() + for key in ("precision", "recall", "f1", "overcaption_rate", "per_label"): + assert key in d + + def test_values_rounded(self): + detected = [_det("[alarm]", 0.0, 2.0)] + ground_truth = [_gt("[alarm]", 0.0, 2.0), _gt("[alarm]", 10.0, 12.0)] + d = evaluate(detected, ground_truth).to_dict() + # Values should be at most 4 decimal places + for key in ("precision", "recall", "f1"): + assert isinstance(d[key], float) + assert len(str(d[key]).split(".")[-1]) <= 4 + + def test_json_serialisable(self): + import json + report = evaluate( + [_det("[alarm]", 0.0, 2.0)], + [_gt("[alarm]", 0.0, 2.0)], + ) + json.dumps(report.to_dict()) # should not raise + + +# --------------------------------------------------------------------------- +# load_ground_truth +# --------------------------------------------------------------------------- + +class TestLoadGroundTruth: + def test_load_valid_file(self, tmp_path): + data = [ + {"label": "[alarm]", "start_s": 1.0, "end_s": 4.5}, + {"label": "[gunshot]", "start_s": 10.0, "end_s": 11.0}, + ] + p = tmp_path / "gt.json" + p.write_text(json.dumps(data), encoding="utf-8") + events = load_ground_truth(str(p)) + assert len(events) == 2 + assert events[0].label == "[alarm]" + assert events[0].start_s == pytest.approx(1.0) + assert events[1].label == "[gunshot]" + + def test_load_empty_file(self, tmp_path): + p = tmp_path / "gt.json" + p.write_text("[]", encoding="utf-8") + assert load_ground_truth(str(p)) == [] + + def test_string_timestamps_coerced(self, tmp_path): + data = [{"label": "[alarm]", "start_s": "1.5", "end_s": "3.0"}] + p = tmp_path / "gt.json" + p.write_text(json.dumps(data), encoding="utf-8") + events = load_ground_truth(str(p)) + assert events[0].start_s == pytest.approx(1.5)