Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Runtime dependencies
tensorflow>=2.13.0
tensorflow-hub>=0.14.0
soundfile>=0.12.1
numpy>=1.24.0
opencv-python>=4.8.0
mediapipe>=0.10.0

# Dev / test
pytest>=7.4.0
Empty file added src/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions src/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from src.audio.detector import AudioEvent, SoundEventDetector

__all__ = ["AudioEvent", "SoundEventDetector"]
138 changes: 138 additions & 0 deletions src/audio/detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Sound Event Detection using YAMNet (Goal 1).

Heavy ML imports (tensorflow_hub, soundfile) are deferred to method bodies
so the module is importable without any ML stack installed — useful for
running unit tests and importing constants in lightweight contexts.
"""

from __future__ import annotations

import dataclasses

from src.audio.labels import SPEECH_LABELS, to_cc_label

# YAMNet produces one score vector per ~0.48 s of audio.
_YAMNET_HOP_S: float = 0.48

# Adjacent same-label events closer than this are merged into one.
_MERGE_GAP_S: float = 0.5

# Default minimum confidence for an event to be kept.
DEFAULT_CONFIDENCE_THRESHOLD: float = 0.35


@dataclasses.dataclass
class AudioEvent:
"""A single detected non-speech audio event."""

label: str
start_s: float
end_s: float
confidence: float


def _merge_adjacent(events: list[AudioEvent]) -> list[AudioEvent]:
"""Merge same-label events whose gap is within *_MERGE_GAP_S*."""
if not events:
return []
merged = [events[0]]
for ev in events[1:]:
prev = merged[-1]
if ev.label == prev.label and (ev.start_s - prev.end_s) <= _MERGE_GAP_S:
merged[-1] = dataclasses.replace(
prev,
end_s=ev.end_s,
confidence=max(prev.confidence, ev.confidence),
)
else:
merged.append(ev)
return merged


class SoundEventDetector:
"""Detects non-speech audio events in a 16 kHz mono WAV file via YAMNet.

Usage::

detector = SoundEventDetector()
events = detector.detect("extracted_audio.wav")
for ev in events:
print(ev.label, ev.start_s, ev.end_s, ev.confidence)
"""

def __init__(self, confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD) -> None:
self.confidence_threshold = confidence_threshold
self._model = None
self._class_names: list[str] | None = None

# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------

def detect(self, audio_path: str) -> list[AudioEvent]:
"""Run YAMNet on *audio_path* and return filtered, merged events.

*audio_path* must be a 16 kHz mono WAV file. Use ``ffmpeg`` or
``librosa`` to resample if the source has a different sample rate.

Raises
------
ValueError
If *audio_path* is not 16 kHz.
"""
import numpy as np
import soundfile as sf

self._load_model()

waveform, sr = sf.read(audio_path, dtype="float32", always_2d=False)
if sr != 16000:
raise ValueError(
f"Expected 16 kHz audio, got {sr} Hz. "
"Resample with ffmpeg: ffmpeg -i input.wav -ar 16000 out.wav"
)

scores, _, _ = self._model(waveform)
scores = scores.numpy() # shape: (n_frames, 521)

events: list[AudioEvent] = []
for frame_idx, frame_scores in enumerate(scores):
top_idx = int(frame_scores.argmax())
confidence = float(frame_scores[top_idx])
if confidence < self.confidence_threshold:
continue
class_name = self._class_names[top_idx]
if class_name in SPEECH_LABELS:
continue
label = to_cc_label(class_name)
if label is None:
continue
start_s = frame_idx * _YAMNET_HOP_S
events.append(
AudioEvent(
label=label,
start_s=round(start_s, 3),
end_s=round(start_s + _YAMNET_HOP_S, 3),
confidence=round(confidence, 4),
)
)

return _merge_adjacent(events)

# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

def _load_model(self) -> None:
"""Load YAMNet from TF-Hub (idempotent)."""
if self._model is not None:
return
import csv
import io

import tensorflow_hub as hub

self._model = hub.load("https://tfhub.dev/google/yamnet/1")
class_map_bytes: bytes = self._model.class_map_path().numpy()
reader = csv.DictReader(io.StringIO(class_map_bytes.decode()))
self._class_names = [row["display_name"] for row in reader]
110 changes: 110 additions & 0 deletions src/audio/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""AudioSet label utilities for the CC suggestion pipeline."""

from __future__ import annotations

# YAMNet class names that represent speech — always suppressed.
SPEECH_LABELS: frozenset[str] = frozenset({
"Speech",
"Male speech, man speaking",
"Female speech, woman speaking",
"Child speech, kid speaking",
"Conversation",
"Narration, monologue",
"Babbling",
"Speech synthesizer",
"Shout",
"Bellow",
"Whoop",
"Yell",
"Children shouting",
"Screaming",
"Whispering",
"Laughter",
"Baby laughter",
"Giggling",
"Snicker",
"Breathing",
"Wheeze",
"Snoring",
"Cough",
"Sneeze",
"Gasp",
"Sigh",
})

# Maps YAMNet display names to human-readable CC labels.
# Includes India-specific sounds absent from generic AudioSet mappings.
LABEL_MAP: dict[str, str] = {
"Gunshot, gunfire": "[gunshot]",
"Machine gun": "[gunshot]",
"Explosion": "[explosion]",
"Blowing up": "[explosion]",
"Fire alarm": "[alarm]",
"Alarm": "[alarm]",
"Smoke detector, smoke alarm": "[alarm]",
"Car alarm": "[alarm]",
"Siren": "[siren]",
"Civil defense siren": "[siren]",
"Ambulance (siren)": "[siren]",
"Police car (siren)": "[siren]",
"Glass": "[glass breaking]",
"Breaking": "[glass breaking]",
"Applause": "[applause]",
"Crowd": "[crowd noise]",
"Cheering": "[cheering]",
"Baby cry, infant cry": "[baby crying]",
"Crying, sobbing": "[crying]",
"Dog": "[dog barking]",
"Bark": "[dog barking]",
"Cat": "[cat meowing]",
"Meow": "[cat meowing]",
# India-specific: Diwali crackers map from AudioSet "Fireworks"
"Fireworks": "[firecrackers]",
# India-specific classical/folk percussion
"Tabla": "[tabla]",
"Dhol": "[dhol]",
"Temple bells": "[temple bells]",
"Knock": "[knocking]",
"Telephone": "[phone ringing]",
"Bell": "[bell]",
"Thunder": "[thunder]",
"Rain": "[rain]",
"Wind": "[wind]",
"Traffic noise, roadway noise": "[traffic]",
"Honk": "[honking]",
"Music": "[music]",
"Musical instrument": "[music]",
}

# High-impact events: audio confidence alone is usually sufficient.
HIGH_IMPACT: frozenset[str] = frozenset({
"[gunshot]",
"[explosion]",
"[alarm]",
"[siren]",
"[glass breaking]",
})

# Ambient events: require strong visual confirmation to avoid overcaptioning.
AMBIENT: frozenset[str] = frozenset({
"[music]",
"[rain]",
"[wind]",
"[traffic]",
})


def to_cc_label(yamnet_class: str) -> str | None:
"""Return a CC label string for *yamnet_class*, or ``None`` for speech.

Falls back to ``[<lowercased class name>]`` for unmapped non-speech events.
Matching is case-insensitive and uses substring search so partial YAMNet
class names (e.g. "Bark" matching "Dog bark") resolve correctly.
"""
if yamnet_class in SPEECH_LABELS:
return None
lower = yamnet_class.lower()
for key, label in LABEL_MAP.items():
if key.lower() in lower or lower in key.lower():
return label
return f"[{lower}]"
Empty file added tests/__init__.py
Empty file.
Loading