diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..8d7a250 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,27 @@ +name: tests + +on: + push: + branches: ["main", "feat/**", "fix/**", "chore/**"] + pull_request: + branches: ["main"] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: install test dependencies + run: pip install -r requirements-dev.txt + + - name: run tests + run: python -m pytest tests/ -v --tb=short diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..5b1f97b --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1 @@ +pytest>=7.4.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fusion/__init__.py b/src/fusion/__init__.py new file mode 100644 index 0000000..6358b25 --- /dev/null +++ b/src/fusion/__init__.py @@ -0,0 +1,19 @@ +"""CC Decision Engine — Goal 3 of the intelligent CC suggestion pipeline.""" + +from src.fusion.engine import ( + AudioSignal, + CCDecision, + FusionConfig, + VisualSignal, + batch_decide, + decide, +) + +__all__ = [ + "AudioSignal", + "VisualSignal", + "CCDecision", + "FusionConfig", + "decide", + "batch_decide", +] diff --git a/src/fusion/engine.py b/src/fusion/engine.py new file mode 100644 index 0000000..23a9ebf --- /dev/null +++ b/src/fusion/engine.py @@ -0,0 +1,299 @@ +"""CC Decision Engine — Goal 3. + +Combines audio event signals and visual reaction signals into a +CC / no-CC decision. The key design principle is *category-aware fusion*: +the balance between audio and visual evidence depends on what kind of +sound was detected. 
+ + HIGH_IMPACT events (gunshot, explosion, alarm, siren, glass breaking) + → audio confidence alone is usually sufficient; visual reaction is + a bonus that can rescue lower-confidence detections. + + AMBIENT events (music, rain, wind, traffic) + → these sounds play for long stretches without warranting a caption. + Visual reaction must confirm that the sound is actually affecting + the scene before the engine fires. + + GENERAL events (applause, crying, dog barking, crowd, etc.) + → weighted fusion; audio leads but a strong visual reaction can + push borderline detections over the threshold. + +No ML dependencies — all decisions are deterministic arithmetic on the +confidence scores produced by upstream modules. +""" + +from __future__ import annotations + +import dataclasses + +# --------------------------------------------------------------------------- +# Category sets (mirrors src/audio/labels.py — duplicated so this module +# is standalone and does not depend on Goal 1 being merged first). +# --------------------------------------------------------------------------- + +_HIGH_IMPACT: frozenset[str] = frozenset({ + "[gunshot]", + "[explosion]", + "[alarm]", + "[siren]", + "[glass breaking]", +}) + +_AMBIENT: frozenset[str] = frozenset({ + "[music]", + "[rain]", + "[wind]", + "[traffic]", +}) + + +# --------------------------------------------------------------------------- +# Data types +# --------------------------------------------------------------------------- + +@dataclasses.dataclass +class AudioSignal: + """Audio event emitted by the sound event detection module.""" + + label: str + start_s: float + end_s: float + confidence: float + + +@dataclasses.dataclass +class VisualSignal: + """Visual reaction score emitted by the speaker reaction module. + + *reaction_score* is a value in [0, 1] where 0 means no detectable + reaction and 1 means a strong, unambiguous reaction to the audio event. 
+ Pass ``reaction_score=0.0`` when no visual analysis was performed. + """ + + reaction_score: float + + +@dataclasses.dataclass +class CCDecision: + """Output of the fusion engine for a single audio event.""" + + accepted: bool + label: str + start_s: float + end_s: float + audio_confidence: float + reaction_score: float + combined_score: float + # Plain-English explanation of why the event was accepted or rejected. + reason: str + + +@dataclasses.dataclass +class FusionConfig: + """Threshold configuration for the decision engine. + + All values are in [0, 1]. The defaults are tuned to keep precision + high at the cost of some recall — better to miss a marginal caption + than to flood a video with ambient-noise captions. + """ + + # HIGH_IMPACT: audio weight, visual weight, minimum combined score. + high_impact_audio_w: float = 0.80 + high_impact_visual_w: float = 0.20 + high_impact_min_score: float = 0.40 + + # AMBIENT: visual weight is dominant; also gate on minimum reaction. + ambient_audio_w: float = 0.35 + ambient_visual_w: float = 0.65 + ambient_min_reaction: float = 0.35 + ambient_min_score: float = 0.55 + + # GENERAL: audio leads but visual reaction can tip borderline events. + general_audio_w: float = 0.60 + general_visual_w: float = 0.40 + general_min_score: float = 0.45 + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +_DEFAULT_CONFIG = FusionConfig() + + +def decide( + audio: AudioSignal, + visual: VisualSignal, + config: FusionConfig = _DEFAULT_CONFIG, +) -> CCDecision: + """Return a CC/no-CC decision for one audio event. + + Parameters + ---------- + audio: + Event from the sound event detection module. + visual: + Reaction score from the speaker reaction module. Use + ``VisualSignal(reaction_score=0.0)`` when no visual analysis + was performed. + config: + Threshold configuration. Defaults are tuned for high precision. 
+ """ + category = _category(audio.label) + + if category == "HIGH_IMPACT": + return _decide_high_impact(audio, visual, config) + if category == "AMBIENT": + return _decide_ambient(audio, visual, config) + return _decide_general(audio, visual, config) + + +def batch_decide( + pairs: list[tuple[AudioSignal, VisualSignal]], + config: FusionConfig = _DEFAULT_CONFIG, +) -> list[CCDecision]: + """Run :func:`decide` over a list of (audio, visual) pairs.""" + return [decide(a, v, config) for a, v in pairs] + + +# --------------------------------------------------------------------------- +# Category routing +# --------------------------------------------------------------------------- + +def _category(label: str) -> str: + if label in _HIGH_IMPACT: + return "HIGH_IMPACT" + if label in _AMBIENT: + return "AMBIENT" + return "GENERAL" + + +# --------------------------------------------------------------------------- +# Per-category decision logic +# --------------------------------------------------------------------------- + +def _decide_high_impact( + audio: AudioSignal, visual: VisualSignal, cfg: FusionConfig +) -> CCDecision: + combined = ( + audio.confidence * cfg.high_impact_audio_w + + visual.reaction_score * cfg.high_impact_visual_w + ) + combined = round(combined, 4) + accepted = combined >= cfg.high_impact_min_score + + if accepted: + reason = ( + f"high-impact event: combined score {combined:.2f} " + f"(audio {audio.confidence:.2f} × {cfg.high_impact_audio_w} " + f"+ reaction {visual.reaction_score:.2f} × {cfg.high_impact_visual_w})" + ) + else: + reason = ( + f"high-impact event below threshold: combined {combined:.2f} " + f"< {cfg.high_impact_min_score} " + f"(audio confidence {audio.confidence:.2f} too low)" + ) + + return CCDecision( + accepted=accepted, + label=audio.label, + start_s=audio.start_s, + end_s=audio.end_s, + audio_confidence=audio.confidence, + reaction_score=visual.reaction_score, + combined_score=combined, + reason=reason, + ) + + +def 
_decide_ambient( + audio: AudioSignal, visual: VisualSignal, cfg: FusionConfig +) -> CCDecision: + # Gate first: ambient sounds require a minimum visible reaction. + if visual.reaction_score < cfg.ambient_min_reaction: + combined = round( + audio.confidence * cfg.ambient_audio_w + + visual.reaction_score * cfg.ambient_visual_w, + 4, + ) + return CCDecision( + accepted=False, + label=audio.label, + start_s=audio.start_s, + end_s=audio.end_s, + audio_confidence=audio.confidence, + reaction_score=visual.reaction_score, + combined_score=combined, + reason=( + f"ambient sound rejected: reaction score {visual.reaction_score:.2f} " + f"below minimum {cfg.ambient_min_reaction} " + "(no visible scene response — likely background noise)" + ), + ) + + combined = round( + audio.confidence * cfg.ambient_audio_w + + visual.reaction_score * cfg.ambient_visual_w, + 4, + ) + accepted = combined >= cfg.ambient_min_score + + if accepted: + reason = ( + f"ambient event confirmed by visual reaction: combined {combined:.2f} " + f"(audio {audio.confidence:.2f} × {cfg.ambient_audio_w} " + f"+ reaction {visual.reaction_score:.2f} × {cfg.ambient_visual_w})" + ) + else: + reason = ( + f"ambient event: combined score {combined:.2f} " + f"below threshold {cfg.ambient_min_score} " + f"despite reaction {visual.reaction_score:.2f}" + ) + + return CCDecision( + accepted=accepted, + label=audio.label, + start_s=audio.start_s, + end_s=audio.end_s, + audio_confidence=audio.confidence, + reaction_score=visual.reaction_score, + combined_score=combined, + reason=reason, + ) + + +def _decide_general( + audio: AudioSignal, visual: VisualSignal, cfg: FusionConfig +) -> CCDecision: + combined = round( + audio.confidence * cfg.general_audio_w + + visual.reaction_score * cfg.general_visual_w, + 4, + ) + accepted = combined >= cfg.general_min_score + + if accepted: + reason = ( + f"accepted: combined score {combined:.2f} " + f"(audio {audio.confidence:.2f} × {cfg.general_audio_w} " + f"+ reaction 
{visual.reaction_score:.2f} × {cfg.general_visual_w})" + ) + else: + reason = ( + f"rejected: combined score {combined:.2f} " + f"below threshold {cfg.general_min_score} " + f"(audio {audio.confidence:.2f}, reaction {visual.reaction_score:.2f})" + ) + + return CCDecision( + accepted=accepted, + label=audio.label, + start_s=audio.start_s, + end_s=audio.end_s, + audio_confidence=audio.confidence, + reaction_score=visual.reaction_score, + combined_score=combined, + reason=reason, + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_fusion_engine.py b/tests/test_fusion_engine.py new file mode 100644 index 0000000..77f9fc7 --- /dev/null +++ b/tests/test_fusion_engine.py @@ -0,0 +1,290 @@ +"""Tests for src.fusion.engine — no ML dependencies required.""" + +from __future__ import annotations + +import pytest + +from src.fusion.engine import ( + AudioSignal, + CCDecision, + FusionConfig, + VisualSignal, + _category, + batch_decide, + decide, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _audio(label: str, conf: float, start: float = 0.0, end: float = 1.0) -> AudioSignal: + return AudioSignal(label=label, start_s=start, end_s=end, confidence=conf) + + +def _visual(score: float) -> VisualSignal: + return VisualSignal(reaction_score=score) + + +# --------------------------------------------------------------------------- +# _category +# --------------------------------------------------------------------------- + +class TestCategory: + def test_high_impact_labels(self): + for label in ("[gunshot]", "[explosion]", "[alarm]", "[siren]", "[glass breaking]"): + assert _category(label) == "HIGH_IMPACT" + + def test_ambient_labels(self): + for label in ("[music]", "[rain]", "[wind]", "[traffic]"): + assert _category(label) == "AMBIENT" + + def test_general_fallback(self): + for 
"""Tests for src.fusion.engine — no ML dependencies required."""

from __future__ import annotations

import pytest

from src.fusion.engine import (
    AudioSignal,
    CCDecision,
    FusionConfig,
    VisualSignal,
    _category,
    batch_decide,
    decide,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _audio(label: str, conf: float, start: float = 0.0, end: float = 1.0) -> AudioSignal:
    """Build an AudioSignal with a default one-second window."""
    return AudioSignal(label=label, start_s=start, end_s=end, confidence=conf)


def _visual(score: float) -> VisualSignal:
    """Build a VisualSignal carrying only a reaction score."""
    return VisualSignal(reaction_score=score)


# ---------------------------------------------------------------------------
# _category
# ---------------------------------------------------------------------------

class TestCategory:
    def test_high_impact_labels(self):
        for label in ("[gunshot]", "[explosion]", "[alarm]", "[siren]", "[glass breaking]"):
            assert _category(label) == "HIGH_IMPACT"

    def test_ambient_labels(self):
        for label in ("[music]", "[rain]", "[wind]", "[traffic]"):
            assert _category(label) == "AMBIENT"

    def test_general_fallback(self):
        for label in ("[applause]", "[crying]", "[dog barking]", "[crowd noise]", "[tabla]"):
            assert _category(label) == "GENERAL"

    def test_unknown_label_is_general(self):
        assert _category("[helicopter]") == "GENERAL"
        assert _category("[firecrackers]") == "GENERAL"


# ---------------------------------------------------------------------------
# HIGH_IMPACT decisions
# ---------------------------------------------------------------------------

class TestHighImpact:
    def test_strong_audio_no_visual_accepted(self):
        # 0.8 * 0.80 + 0.0 * 0.20 = 0.64 >= 0.40
        result = decide(_audio("[gunshot]", 0.8), _visual(0.0))
        assert result.accepted is True

    def test_minimum_audio_exactly_at_threshold(self):
        # Need combined >= 0.40. With visual=0: audio >= 0.40 / 0.80 = 0.50
        result = decide(_audio("[alarm]", 0.50), _visual(0.0))
        assert result.accepted is True

    def test_low_audio_with_strong_visual_accepted(self):
        # audio=0.30 → 0.30*0.80 + 0.90*0.20 = 0.24 + 0.18 = 0.42 >= 0.40
        result = decide(_audio("[explosion]", 0.30), _visual(0.90))
        assert result.accepted is True

    def test_very_low_audio_no_visual_rejected(self):
        # 0.20 * 0.80 + 0.0 = 0.16 < 0.40
        result = decide(_audio("[siren]", 0.20), _visual(0.0))
        assert result.accepted is False

    def test_reason_contains_combined_score(self):
        # 0.9 * 0.80 + 0.0 * 0.20 = 0.72; the reason must quote that score.
        result = decide(_audio("[gunshot]", 0.9), _visual(0.0))
        assert "combined score 0.72" in result.reason.lower()

    def test_timestamps_preserved(self):
        result = decide(_audio("[alarm]", 0.9, start=4.32, end=6.72), _visual(0.5))
        assert result.start_s == pytest.approx(4.32)
        assert result.end_s == pytest.approx(6.72)

    def test_label_preserved(self):
        result = decide(_audio("[glass breaking]", 0.8), _visual(0.0))
        assert result.label == "[glass breaking]"

    def test_combined_score_is_weighted_average(self):
        cfg = FusionConfig()
        a, v = 0.70, 0.50
        expected = round(a * cfg.high_impact_audio_w + v * cfg.high_impact_visual_w, 4)
        result = decide(_audio("[alarm]", a), _visual(v), cfg)
        assert result.combined_score == pytest.approx(expected, abs=1e-4)


# ---------------------------------------------------------------------------
# AMBIENT decisions
# ---------------------------------------------------------------------------

class TestAmbient:
    def test_strong_audio_no_visual_rejected(self):
        # Music always playing, no scene reaction → reject to avoid overcaptioning
        result = decide(_audio("[music]", 0.95), _visual(0.0))
        assert result.accepted is False

    def test_reaction_below_minimum_gate_rejected(self):
        # reaction_score 0.20 < ambient_min_reaction 0.35 → gated out immediately
        result = decide(_audio("[rain]", 0.90), _visual(0.20))
        assert result.accepted is False

    def test_reason_mentions_background_noise_on_gate(self):
        result = decide(_audio("[music]", 0.90), _visual(0.10))
        assert result.accepted is False
        assert "background noise" in result.reason.lower()

    def test_strong_audio_strong_visual_accepted(self):
        # 0.90*0.35 + 0.80*0.65 = 0.315 + 0.52 = 0.835 >= 0.55
        result = decide(_audio("[music]", 0.90), _visual(0.80))
        assert result.accepted is True

    def test_just_below_combined_threshold_rejected(self):
        # reaction=0.40 clears the gate (>=0.35), but combined might still fail
        # 0.50*0.35 + 0.40*0.65 = 0.175 + 0.26 = 0.435 < 0.55
        result = decide(_audio("[wind]", 0.50), _visual(0.40))
        assert result.accepted is False

    def test_exact_combined_threshold_accepted(self):
        # Build audio+visual so combined == exactly ambient_min_score (0.55).
        cfg = FusionConfig()
        # With visual = 0.70 (well above the reaction gate):
        #   combined = audio*0.35 + 0.455; need audio*0.35 = 0.095 → audio ≈ 0.271
        audio_conf = (cfg.ambient_min_score - 0.70 * cfg.ambient_visual_w) / cfg.ambient_audio_w
        result = decide(_audio("[traffic]", audio_conf), _visual(0.70), cfg)
        assert result.combined_score == pytest.approx(cfg.ambient_min_score, abs=1e-4)
        # The threshold is inclusive (>=), so landing exactly on it must accept.
        assert result.accepted is True

    def test_india_ambient_tabla_treated_as_general_not_ambient(self):
        # [tabla] is not in AMBIENT set — should use GENERAL rules not AMBIENT rules
        result_tabla = decide(_audio("[tabla]", 0.80), _visual(0.0))
        result_music = decide(_audio("[music]", 0.80), _visual(0.0))
        # tabla can be accepted (GENERAL), music cannot without visual (AMBIENT)
        assert result_music.accepted is False
        # tabla with 0.80 audio: 0.80*0.60 + 0.0*0.40 = 0.48 >= 0.45 → accepted
        assert result_tabla.accepted is True


# ---------------------------------------------------------------------------
# GENERAL decisions
# ---------------------------------------------------------------------------

class TestGeneral:
    def test_strong_audio_no_visual_accepted(self):
        # 0.80*0.60 + 0.0*0.40 = 0.48 >= 0.45
        result = decide(_audio("[applause]", 0.80), _visual(0.0))
        assert result.accepted is True

    def test_weak_audio_no_visual_rejected(self):
        # 0.40*0.60 + 0.0 = 0.24 < 0.45
        result = decide(_audio("[dog barking]", 0.40), _visual(0.0))
        assert result.accepted is False

    def test_borderline_audio_strong_visual_accepted(self):
        # audio=0.40, visual=0.60: 0.40*0.60 + 0.60*0.40 = 0.24 + 0.24 = 0.48 >= 0.45
        result = decide(_audio("[crying]", 0.40), _visual(0.60))
        assert result.accepted is True

    def test_reason_contains_threshold_on_rejection(self):
        result = decide(_audio("[crowd noise]", 0.30), _visual(0.10))
        assert result.accepted is False
        reason = result.reason.lower()
        assert "rejected" in reason
        assert "threshold" in reason

    def test_india_label_tabla_accepted_at_high_confidence(self):
        result = decide(_audio("[tabla]", 0.85), _visual(0.0))
        assert result.accepted is True

    def test_india_label_firecrackers_accepted_with_reaction(self):
        # [firecrackers] is GENERAL (not in HIGH_IMPACT despite being explosive in nature)
        result = decide(_audio("[firecrackers]", 0.60), _visual(0.50))
        # 0.60*0.60 + 0.50*0.40 = 0.36 + 0.20 = 0.56 >= 0.45 → accepted
        assert result.accepted is True

    def test_combined_score_computed_correctly(self):
        cfg = FusionConfig()
        a, v = 0.65, 0.55
        expected = round(a * cfg.general_audio_w + v * cfg.general_visual_w, 4)
        result = decide(_audio("[cheering]", a), _visual(v), cfg)
        assert result.combined_score == pytest.approx(expected, abs=1e-4)


# ---------------------------------------------------------------------------
# CCDecision fields
# ---------------------------------------------------------------------------

class TestCCDecisionFields:
    def test_all_fields_populated(self):
        result = decide(_audio("[alarm]", 0.75), _visual(0.50))
        assert isinstance(result.accepted, bool)
        assert isinstance(result.label, str)
        assert isinstance(result.start_s, float)
        assert isinstance(result.end_s, float)
        assert isinstance(result.audio_confidence, float)
        assert isinstance(result.reaction_score, float)
        assert isinstance(result.combined_score, float)
        assert isinstance(result.reason, str)
        assert len(result.reason) > 0

    def test_audio_confidence_preserved(self):
        result = decide(_audio("[gunshot]", 0.73), _visual(0.60))
        assert result.audio_confidence == pytest.approx(0.73)

    def test_reaction_score_preserved(self):
        result = decide(_audio("[gunshot]", 0.80), _visual(0.45))
        assert result.reaction_score == pytest.approx(0.45)

    def test_combined_score_range(self):
        for conf in (0.0, 0.35, 0.70, 1.0):
            for react in (0.0, 0.50, 1.0):
                result = decide(_audio("[alarm]", conf), _visual(react))
                assert 0.0 <= result.combined_score <= 1.0


# ---------------------------------------------------------------------------
# batch_decide
# ---------------------------------------------------------------------------

class TestBatchDecide:
    def test_empty_list(self):
        assert batch_decide([]) == []

    def test_single_pair(self):
        result = batch_decide([(_audio("[alarm]", 0.9), _visual(0.0))])
        assert len(result) == 1
        assert isinstance(result[0], CCDecision)

    def test_order_preserved(self):
        pairs = [
            (_audio("[gunshot]", 0.9), _visual(0.0)),
            (_audio("[music]", 0.9), _visual(0.0)),
            (_audio("[applause]", 0.8), _visual(0.0)),
        ]
        results = batch_decide(pairs)
        assert results[0].label == "[gunshot]"
        assert results[1].label == "[music]"
        assert results[2].label == "[applause]"

    def test_mixed_accept_reject(self):
        pairs = [
            (_audio("[alarm]", 0.9), _visual(0.0)),     # HIGH_IMPACT → accepted
            (_audio("[music]", 0.9), _visual(0.0)),     # AMBIENT no visual → rejected
            (_audio("[applause]", 0.8), _visual(0.0)),  # GENERAL strong audio → accepted
        ]
        results = batch_decide(pairs)
        assert results[0].accepted is True
        assert results[1].accepted is False
        assert results[2].accepted is True

    def test_custom_config_propagated(self):
        strict = FusionConfig(high_impact_min_score=0.99)
        result = batch_decide([(_audio("[alarm]", 0.9), _visual(0.0))], config=strict)
        assert result[0].accepted is False


# ---------------------------------------------------------------------------
# Custom FusionConfig
# ---------------------------------------------------------------------------

class TestFusionConfig:
    def test_stricter_threshold_rejects_borderline(self):
        default_result = decide(_audio("[alarm]", 0.50), _visual(0.0))
        strict = FusionConfig(high_impact_min_score=0.50)
        strict_result = decide(_audio("[alarm]", 0.50), _visual(0.0), strict)
        # Default (0.40) accepts: 0.50*0.80 = 0.40 >= 0.40
        assert default_result.accepted is True
        # Strict (0.50): 0.50*0.80 = 0.40 < 0.50
        assert strict_result.accepted is False

    def test_looser_ambient_reaction_gate(self):
        loose = FusionConfig(ambient_min_reaction=0.10, ambient_min_score=0.40)
        # With default, reaction=0.15 is gated out; with loose config, it passes the gate
        default_result = decide(_audio("[music]", 0.90), _visual(0.15))
        loose_result = decide(_audio("[music]", 0.90), _visual(0.15), loose)
        assert default_result.accepted is False
        # loose: gate passes (0.15 >= 0.10), combined = 0.90*0.35 + 0.15*0.65 = 0.315+0.0975=0.4125 >= 0.40
        assert loose_result.accepted is True

    def test_config_weights_sum_check(self):
        cfg = FusionConfig()
        assert cfg.high_impact_audio_w + cfg.high_impact_visual_w == pytest.approx(1.0)
        assert cfg.ambient_audio_w + cfg.ambient_visual_w == pytest.approx(1.0)
        assert cfg.general_audio_w + cfg.general_visual_w == pytest.approx(1.0)