diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..1a5cf45
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+# Runtime dependencies
+tensorflow>=2.13.0
+tensorflow-hub>=0.14.0
+soundfile>=0.12.1
+numpy>=1.24.0
+opencv-python>=4.8.0
+mediapipe>=0.10.0
+
+# Dev / test
+pytest>=7.4.0
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/audio/__init__.py b/src/audio/__init__.py
new file mode 100644
index 0000000..87e6133
--- /dev/null
+++ b/src/audio/__init__.py
@@ -0,0 +1,3 @@
+from src.audio.detector import AudioEvent, SoundEventDetector
+
+__all__ = ["AudioEvent", "SoundEventDetector"]
diff --git a/src/audio/detector.py b/src/audio/detector.py
new file mode 100644
index 0000000..b875eaa
--- /dev/null
+++ b/src/audio/detector.py
@@ -0,0 +1,159 @@
+"""Sound Event Detection using YAMNet (Goal 1).
+
+Heavy ML imports (tensorflow_hub, soundfile) are deferred to method bodies
+so the module is importable without any ML stack installed — useful for
+running unit tests and importing constants in lightweight contexts.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+
+from src.audio.labels import SPEECH_LABELS, to_cc_label
+
+# YAMNet produces one score vector per ~0.48 s of audio.
+_YAMNET_HOP_S: float = 0.48
+
+# Adjacent same-label events closer than this are merged into one.
+_MERGE_GAP_S: float = 0.5
+
+# Default minimum confidence for an event to be kept.
+DEFAULT_CONFIDENCE_THRESHOLD: float = 0.35
+
+
+@dataclasses.dataclass
+class AudioEvent:
+    """A single detected non-speech audio event."""
+
+    label: str  # CC label, e.g. "[alarm]"
+    start_s: float  # seconds from the start of the audio
+    end_s: float  # seconds from the start of the audio
+    confidence: float  # highest frame score merged into the event
+
+
+def _merge_adjacent(
+    events: list[AudioEvent], max_gap_s: float = _MERGE_GAP_S
+) -> list[AudioEvent]:
+    """Merge consecutive same-label events separated by at most *max_gap_s*.
+
+    Only neighbours in *events* are compared, so the list must be ordered by
+    start time.  A merged event spans both inputs and keeps the higher
+    confidence; the input events themselves are not mutated.
+    """
+    if not events:
+        return []
+    merged = [events[0]]
+    for ev in events[1:]:
+        prev = merged[-1]
+        if ev.label == prev.label and (ev.start_s - prev.end_s) <= max_gap_s:
+            merged[-1] = dataclasses.replace(
+                prev,
+                # max(): an event nested inside *prev* must not shorten it.
+                end_s=max(prev.end_s, ev.end_s),
+                confidence=max(prev.confidence, ev.confidence),
+            )
+        else:
+            merged.append(ev)
+    return merged
+
+
+class SoundEventDetector:
+    """Detects non-speech audio events in a 16 kHz mono WAV file via YAMNet.
+
+    Usage::
+
+        detector = SoundEventDetector()
+        events = detector.detect("extracted_audio.wav")
+        for ev in events:
+            print(ev.label, ev.start_s, ev.end_s, ev.confidence)
+    """
+
+    def __init__(self, confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD) -> None:
+        self.confidence_threshold = confidence_threshold
+        # Model and class names are loaded lazily by _load_model().
+        self._model = None
+        self._class_names: list[str] | None = None
+
+    # ------------------------------------------------------------------
+    # Public API
+    # ------------------------------------------------------------------
+
+    def detect(self, audio_path: str) -> list[AudioEvent]:
+        """Run YAMNet on *audio_path* and return filtered, merged events.
+
+        *audio_path* must be a 16 kHz mono WAV file.  Use ``ffmpeg`` or
+        ``librosa`` to resample if the source has a different sample rate.
+        Only the top-scoring class of each score frame is considered: the
+        frame is skipped when that score is below ``confidence_threshold``
+        or the class is speech.
+
+        Raises
+        ------
+        ValueError
+            If *audio_path* is not 16 kHz or is not mono.
+        """
+        import soundfile as sf
+
+        self._load_model()
+
+        waveform, sr = sf.read(audio_path, dtype="float32", always_2d=False)
+        if sr != 16000:
+            raise ValueError(
+                f"Expected 16 kHz audio, got {sr} Hz. "
+                "Resample with ffmpeg: ffmpeg -i input.wav -ar 16000 out.wav"
+            )
+        # always_2d=False yields a 2-D array for multi-channel files; YAMNet is
+        # documented to take a 1-D waveform, so reject anything else up front.
+        if waveform.ndim != 1:
+            raise ValueError(
+                f"Expected mono audio, got an array of shape {waveform.shape}. "
+                "Downmix with ffmpeg: ffmpeg -i input.wav -ac 1 out.wav"
+            )
+
+        scores, _, _ = self._model(waveform)
+        scores = scores.numpy()  # shape: (n_frames, 521)
+
+        events: list[AudioEvent] = []
+        for frame_idx, frame_scores in enumerate(scores):
+            top_idx = int(frame_scores.argmax())
+            confidence = float(frame_scores[top_idx])
+            if confidence < self.confidence_threshold:
+                continue
+            class_name = self._class_names[top_idx]
+            if class_name in SPEECH_LABELS:
+                continue
+            label = to_cc_label(class_name)
+            if label is None:
+                continue
+            start_s = frame_idx * _YAMNET_HOP_S
+            events.append(
+                AudioEvent(
+                    label=label,
+                    start_s=round(start_s, 3),
+                    end_s=round(start_s + _YAMNET_HOP_S, 3),
+                    confidence=round(confidence, 4),
+                )
+            )
+
+        return _merge_adjacent(events)
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+
+    def _load_model(self) -> None:
+        """Load YAMNet and its class names from TF-Hub (idempotent)."""
+        if self._model is not None:
+            return
+        import csv
+
+        import tensorflow_hub as hub
+
+        model = hub.load("https://tfhub.dev/google/yamnet/1")
+        # class_map_path() gives the *path* of the class-map CSV, not its text.
+        # NOTE(review): plain open() assumes a local TF-Hub cache — confirm.
+        class_map_path = model.class_map_path().numpy().decode("utf-8")
+        with open(class_map_path, newline="", encoding="utf-8") as csv_file:
+            class_names = [row["display_name"] for row in csv.DictReader(csv_file)]
+        # Publish both together so a failed load leaves the detector unloaded.
+        self._model, self._class_names = model, class_names
diff --git a/src/audio/labels.py b/src/audio/labels.py
new file mode 100644
index 0000000..ae0fdf7
--- /dev/null
+++ b/src/audio/labels.py
@@ -0,0 +1,110 @@
+"""AudioSet label utilities for the CC suggestion pipeline."""
+
+from __future__ import annotations
+
+# YAMNet class names for speech and other human vocal sounds — always suppressed.
+SPEECH_LABELS: frozenset[str] = frozenset({  # also lists non-speech vocal sounds
+    "Speech",
+    "Male speech, man speaking",
+    "Female speech, woman speaking",
+    "Child speech, kid speaking",
+    "Conversation",
+    "Narration, monologue",
+    "Babbling",
+    "Speech synthesizer",
+    "Shout",
+    "Bellow",
+    "Whoop",
+    "Yell",
+    "Children shouting",
+    "Screaming",
+    "Whispering",
+    "Laughter",
+    "Baby laughter",
+    "Giggling",
+    "Snicker",
+    "Breathing",
+    "Wheeze",
+    "Snoring",
+    "Cough",
+    "Sneeze",
+    "Gasp",
+    "Sigh",
+})
+
+# Maps YAMNet display names to human-readable CC labels.
+# Includes India-specific sounds absent from generic AudioSet mappings.
+LABEL_MAP: dict[str, str] = {  # to_cc_label() scans entries in insertion order
+    "Gunshot, gunfire": "[gunshot]",
+    "Machine gun": "[gunshot]",
+    "Explosion": "[explosion]",
+    "Blowing up": "[explosion]",
+    "Fire alarm": "[alarm]",
+    "Alarm": "[alarm]",
+    "Smoke detector, smoke alarm": "[alarm]",
+    "Car alarm": "[alarm]",
+    "Siren": "[siren]",
+    "Civil defense siren": "[siren]",
+    "Ambulance (siren)": "[siren]",
+    "Police car (siren)": "[siren]",
+    "Glass": "[glass breaking]",  # NOTE(review): may also fire on non-breaking glass sounds
+    "Breaking": "[glass breaking]",
+    "Applause": "[applause]",
+    "Crowd": "[crowd noise]",
+    "Cheering": "[cheering]",
+    "Baby cry, infant cry": "[baby crying]",
+    "Crying, sobbing": "[crying]",
+    "Dog": "[dog barking]",
+    "Bark": "[dog barking]",
+    "Cat": "[cat meowing]",
+    "Meow": "[cat meowing]",
+    # India-specific: Diwali crackers map from AudioSet "Fireworks"
+    "Fireworks": "[firecrackers]",
+    # India-specific classical/folk percussion
+    "Tabla": "[tabla]",
+    "Dhol": "[dhol]",  # NOTE(review): confirm this is a YAMNet display name
+    "Temple bells": "[temple bells]",  # NOTE(review): confirm this one as well
+    "Knock": "[knocking]",
+    "Telephone": "[phone ringing]",
+    "Bell": "[bell]",
+    "Thunder": "[thunder]",
+    "Rain": "[rain]",
+    "Wind": "[wind]",
+    "Traffic noise, roadway noise": "[traffic]",
+    "Honk": "[honking]",
+    "Music": "[music]",
+    "Musical instrument": "[music]",
+}
+
+# High-impact events: audio confidence alone is usually sufficient.
+HIGH_IMPACT: frozenset[str] = frozenset({
+    "[gunshot]",
+    "[explosion]",
+    "[alarm]",
+    "[siren]",
+    "[glass breaking]",
+})
+
+# Ambient events: require strong visual confirmation to avoid overcaptioning.
+AMBIENT: frozenset[str] = frozenset({
+    "[music]",
+    "[rain]",
+    "[wind]",
+    "[traffic]",
+})
+
+
+def to_cc_label(yamnet_class: str) -> str | None:
+    """Return the CC label for *yamnet_class*, or ``None`` for speech.
+
+    Case-insensitive.  An exact ``LABEL_MAP`` key wins; otherwise the first key
+    found in the name at the start of a word; else ``[<lowercased name>]``.
+    """
+    import re  # function-scope import: the module's import block is unchanged
+    lower = yamnet_class.lower()
+    if lower in {name.lower() for name in SPEECH_LABELS}:
+        return None
+    keys = {key.lower(): label for key, label in LABEL_MAP.items()}
+    partial = (v for k, v in keys.items() if re.search(rf"(?<!\w){re.escape(k)}", lower))
+    # No reverse test (name inside key): it turned "Fire" into "[gunshot]".
+    return keys.get(lower) or next(partial, f"[{lower}]")
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_audio_detector.py b/tests/test_audio_detector.py
new file mode 100644
index 0000000..da83d7a
--- /dev/null
+++ b/tests/test_audio_detector.py
@@ -0,0 +1,178 @@
+"""Tests for src.audio.detector — all ML calls are mocked."""
+
+from __future__ import annotations
+
+import sys
+from types import ModuleType
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+
+from src.audio.detector import (
+    DEFAULT_CONFIDENCE_THRESHOLD,
+    AudioEvent,
+    SoundEventDetector,
+    _merge_adjacent,
+)
+
+
+def _mock_soundfile_module(waveform: np.ndarray, sr: int) -> ModuleType:
+    """Return a fake soundfile module whose read() returns (waveform, sr)."""
+    sf = ModuleType("soundfile")
+    sf.read = MagicMock(return_value=(waveform, sr))  # type: ignore[attr-defined]
+    return sf
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _ev(label: str, start: float, end: float, conf: float = 0.8) -> AudioEvent:
+    return AudioEvent(label=label, start_s=start, end_s=end, confidence=conf)
+
+
+def _make_detector() -> SoundEventDetector:
+    return SoundEventDetector(confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD)
+
+
+def _patch_detector(detector: SoundEventDetector, class_names: list[str], scores_array: np.ndarray):
+    """Inject a fake loaded model into *detector* without hitting TF-Hub."""
+    fake_model = MagicMock()
+    fake_scores = MagicMock()
+    fake_scores.numpy.return_value = scores_array
+    fake_model.return_value = (fake_scores, None, None)
+    detector._model = fake_model
+    detector._class_names = class_names
+
+
+# ---------------------------------------------------------------------------
+# _merge_adjacent
+# ---------------------------------------------------------------------------
+
+class TestMergeAdjacent:
+    def test_empty_returns_empty(self):
+        assert _merge_adjacent([]) == []
+
+    def test_single_event_unchanged(self):
+        ev = _ev("[alarm]", 0.0, 0.48)
+        assert _merge_adjacent([ev]) == [ev]
+
+    def test_different_labels_not_merged(self):
+        events = [_ev("[alarm]", 0.0, 0.48), _ev("[explosion]", 0.5, 0.96)]
+        result = _merge_adjacent(events)
+        assert len(result) == 2
+
+    def test_same_label_within_gap_merged(self):
+        events = [_ev("[alarm]", 0.0, 0.48), _ev("[alarm]", 0.6, 1.08)]
+        result = _merge_adjacent(events)
+        assert len(result) == 1
+        assert result[0].start_s == pytest.approx(0.0)
+        assert result[0].end_s == pytest.approx(1.08)
+
+    def test_same_label_beyond_gap_not_merged(self):
+        events = [_ev("[alarm]", 0.0, 0.48), _ev("[alarm]", 2.0, 2.48)]
+        assert len(_merge_adjacent(events)) == 2
+
+    def test_merge_keeps_max_confidence(self):
+        events = [_ev("[alarm]", 0.0, 0.48, conf=0.6), _ev("[alarm]", 0.5, 0.96, conf=0.9)]
+        result = _merge_adjacent(events)
+        assert result[0].confidence == pytest.approx(0.9)
+
+    def test_chain_of_three_merged(self):
+        events = [
+            _ev("[siren]", 0.0, 0.48),
+            _ev("[siren]", 0.5, 0.96),
+            _ev("[siren]", 1.0, 1.44),
+        ]
+        result = _merge_adjacent(events)
+        assert len(result) == 1
+        assert result[0].end_s == pytest.approx(1.44)
+
+
+# ---------------------------------------------------------------------------
+# SoundEventDetector.detect
+# ---------------------------------------------------------------------------
+
+def _run_detect(detector: SoundEventDetector, waveform: np.ndarray, sr: int) -> list[AudioEvent]:
+    """Call detector.detect with soundfile mocked at sys.modules level."""
+    fake_sf = _mock_soundfile_module(waveform, sr)
+    with patch.dict(sys.modules, {"soundfile": fake_sf}):
+        return detector.detect("fake.wav")
+
+
+class TestSoundEventDetectorDetect:
+    def test_speech_event_suppressed(self):
+        detector = _make_detector()
+        # frame 0: "Speech" at 0.9 confidence → suppressed
+        # frame 1: "Gunshot, gunfire" at 0.8 confidence → kept
+        _patch_detector(
+            detector,
+            class_names=["Speech", "Gunshot, gunfire"],
+            scores_array=np.array([[0.9, 0.1], [0.1, 0.8]]),
+        )
+        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
+        assert len(events) == 1
+        assert events[0].label == "[gunshot]"
+
+    def test_low_confidence_filtered(self):
+        detector = _make_detector()
+        # confidence 0.2 < threshold 0.35 → no events
+        _patch_detector(
+            detector,
+            class_names=["Explosion"],
+            scores_array=np.array([[0.2]]),
+        )
+        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
+        assert events == []
+
+    def test_wrong_sample_rate_raises(self):
+        detector = _make_detector()
+        detector._model = MagicMock()
+        detector._class_names = []
+        with pytest.raises(ValueError, match="16 kHz"):
+            _run_detect(detector, np.zeros(22050, dtype="float32"), 22050)
+
+    def test_india_label_preserved(self):
+        detector = _make_detector()
+        _patch_detector(
+            detector,
+            class_names=["Fireworks"],
+            scores_array=np.array([[0.85]]),
+        )
+        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
+        assert events[0].label == "[firecrackers]"
+
+    def test_adjacent_events_merged(self):
+        detector = _make_detector()
+        # Two consecutive frames with the same label → merged into one event
+        _patch_detector(
+            detector,
+            class_names=["Alarm"],
+            scores_array=np.array([[0.8], [0.75]]),
+        )
+        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
+        assert len(events) == 1
+        assert events[0].confidence == pytest.approx(0.8)
+
+    def test_event_timestamps_are_correct(self):
+        detector = _make_detector()
+        # frame index 2 → start_s = 2 * 0.48 = 0.96
+        scores = np.zeros((3, 1), dtype="float32")
+        scores[2, 0] = 0.9
+        _patch_detector(detector, class_names=["Explosion"], scores_array=scores)
+        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
+        assert len(events) == 1
+        assert events[0].start_s == pytest.approx(0.96)
+
+    def test_multiple_different_events(self):
+        detector = _make_detector()
+        # frame 0: Gunshot, frame 3: Applause (gap > merge threshold)
+        scores = np.zeros((4, 2), dtype="float32")
+        scores[0, 0] = 0.9  # Gunshot
+        scores[3, 1] = 0.75  # Applause
+        _patch_detector(detector, class_names=["Gunshot, gunfire", "Applause"], scores_array=scores)
+        events = _run_detect(detector, np.zeros(16000, dtype="float32"), 16000)
+        assert len(events) == 2
+        labels = {ev.label for ev in events}
+        assert labels == {"[gunshot]", "[applause]"}
diff --git a/tests/test_audio_labels.py b/tests/test_audio_labels.py
new file mode 100644
index 0000000..25505fe
--- /dev/null
+++ b/tests/test_audio_labels.py
@@ -0,0 +1,80 @@
+"""Tests for src.audio.labels — no ML dependencies required."""
+
+import pytest
+
+from src.audio.labels import (
+    AMBIENT,
+    HIGH_IMPACT,
+    SPEECH_LABELS,
+    to_cc_label,
+)
+
+
+class TestSpeechLabels:
+    def test_common_speech_suppressed(self):
+        for label in ("Speech", "Male speech, man speaking", "Whispering", "Cough"):
+            assert label in SPEECH_LABELS
+
+    def test_non_speech_not_in_speech_labels(self):
+        assert "Gunshot, gunfire" not in SPEECH_LABELS
+        assert "Explosion" not in SPEECH_LABELS
+
+
+class TestToCCLabel:
+    def test_speech_returns_none(self):
+        assert to_cc_label("Speech") is None
+        assert to_cc_label("Baby laughter") is None
+
+    def test_exact_known_mapping(self):
+        assert to_cc_label("Gunshot, gunfire") == "[gunshot]"
+        assert to_cc_label("Explosion") == "[explosion]"
+        assert to_cc_label("Fire alarm") == "[alarm]"
+        assert to_cc_label("Applause") == "[applause]"
+
+    def test_india_specific_labels(self):
+        assert to_cc_label("Fireworks") == "[firecrackers]"
+        assert to_cc_label("Tabla") == "[tabla]"
+        assert to_cc_label("Dhol") == "[dhol]"
+        assert to_cc_label("Temple bells") == "[temple bells]"
+
+    def test_case_insensitive_match(self):
+        assert to_cc_label("gunshot, gunfire") == "[gunshot]"
+        assert to_cc_label("EXPLOSION") == "[explosion]"
+
+    def test_substring_match(self):
+        # "Bark" should match "Dog bark"-style entries
+        assert to_cc_label("Bark") == "[dog barking]"
+        assert to_cc_label("Vehicle horn, car horn, honking") == "[honking]"
+
+    def test_exact_key_beats_earlier_partial_key(self):
+        # "Bell" must not be captured by the earlier "Temple bells" entry
+        assert to_cc_label("Bell") == "[bell]"
+
+    def test_no_mid_word_or_reverse_match(self):
+        # "rain" inside "Train", "fire" inside "gunfire", "car" inside "Car alarm"
+        assert to_cc_label("Train") == "[train]"
+        assert to_cc_label("Fire") == "[fire]"
+        assert to_cc_label("Car") == "[car]"
+
+    def test_unmapped_non_speech_gets_fallback(self):
+        label = to_cc_label("Helicopter")
+        assert label == "[helicopter]"
+
+    def test_fallback_is_lowercased(self):
+        label = to_cc_label("Chainsaw")
+        assert label == label.lower()
+
+
+class TestHighImpactAndAmbient:
+    def test_gunshot_is_high_impact(self):
+        assert "[gunshot]" in HIGH_IMPACT
+        assert "[explosion]" in HIGH_IMPACT
+        assert "[alarm]" in HIGH_IMPACT
+
+    def test_music_is_ambient(self):
+        assert "[music]" in AMBIENT
+        assert "[rain]" in AMBIENT
+        assert "[traffic]" in AMBIENT
+
+    def test_no_overlap(self):
+        assert HIGH_IMPACT.isdisjoint(AMBIENT)