diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c951c91 --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# Python +__pycache__/ +*.py[cod] +*.pyo +*.pyd +.Python +*.egg +*.egg-info/ +dist/ +build/ + +# Virtual environments +.venv/ +venv/ +env/ + +# Project outputs +outputs/ + +# Large model files +*.h5 +*.pb +*.tflite +*.onnx + +# Secrets and Environment +.env +.env.local +.env.*.local diff --git a/README.md b/README.md new file mode 100644 index 0000000..0f5b85f --- /dev/null +++ b/README.md @@ -0,0 +1,69 @@ +# Intelligent CC Suggestion Tool — Goal 1 + +## Sound Event Detection Module +This module automatically detects and classifies non-speech audio events (like honking, laughter, music) from a video file. + +--- + +## 🛠 Setup Instructions + +### 1. Prerequisites +- **Python 3.10+** +- **FFmpeg**: Must be installed and available in your system's PATH. + - *Windows*: `winget install ffmpeg` + - *Linux*: `sudo apt install ffmpeg` + +### 2. Environment Setup +Clone the repository and set up a virtual environment: + +```bash +# Clone the repository +git clone https://github.com/Siddharth-732/Intelligent-cc-generation.git +cd Intelligent-cc-generation + +# Create a virtual environment +python -m venv .venv + +# Activate the environment +# Windows: +.\.venv\Scripts\activate +# Linux/Mac: +source .venv/bin/activate + +# Install dependencies +pip install -r requirements.txt +``` + +### 3. Usage +Verify the installation by running the test script on a video file: + +```bash +python test_goal_1.py "path/to/your/video.mp4" +``` + +--- + +## 💻 Programmatic Usage +```python +from cc_tool.audio import extract_audio, SoundEventDetector + +# 1. Extract audio from video +wav_path = extract_audio("video.mp4") + +# 2. Initialize detector +detector = SoundEventDetector(confidence_threshold=0.3) + +# 3. 
class SoundEventDetector:
    """Detect non-speech sound events in a WAV file with the YAMNet model.

    The audio is split into overlapping windows, each window is scored by
    YAMNet, the per-class scores are summed into canonical caption groups
    (see ``get_canonical_label``) and every group whose summed score clears
    ``confidence_threshold`` becomes an ``AudioEvent``. Overlapping events
    that share a label are then merged into one longer event.
    """

    def __init__(self, confidence_threshold=0.25):
        """Create a detector.

        Args:
            confidence_threshold: Minimum summed group score for a window to
                emit an event. The score is a sum over all classes mapped to
                a group, so it is not bounded by 1.0.
        """
        self.confidence_threshold = confidence_threshold
        self._model = None        # loaded lazily on first detect()
        self._class_names = []    # YAMNet display names, indexed by class id

    def _load_model(self):
        """Load YAMNet and its class map once; later calls are no-ops."""
        if self._model is not None:
            return

        self._model = hub.load("https://tfhub.dev/google/yamnet/1")
        class_map_path = self._model.class_map_path().numpy().decode()

        if os.path.exists(class_map_path):
            with open(class_map_path, "r", encoding="utf-8", newline="") as f:
                reader = csv.DictReader(f)
                self._class_names = [row["display_name"] for row in reader]
        else:
            # The class map is not a local file: fetch it as a URL instead.
            import urllib.request
            with urllib.request.urlopen(class_map_path) as f:
                reader = csv.DictReader(line.decode("utf-8") for line in f)
                self._class_names = [row["display_name"] for row in reader]

    def detect(self, wav_path):
        """Return the merged list of non-speech events found in ``wav_path``.

        Args:
            wav_path: Path to a WAV file readable by ``soundfile``.

        Returns:
            A list of ``AudioEvent`` ordered by start time, with overlapping
            same-label windows merged.
        """
        self._load_model()
        waveform, sr = sf.read(wav_path, dtype="float32")
        # NOTE(review): `sr` is not validated. extract_audio() writes 16 kHz
        # mono, which is what YAMNet is documented to expect — confirm before
        # feeding WAVs produced elsewhere.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)  # down-mix to mono
        waveform = normalize_waveform(waveform)

        raw_events = []
        for start_sec, end_sec, chunk in chunk_audio(waveform, sr):
            scores, _, _ = self._model(chunk)
            mean_scores = scores.numpy().mean(axis=0)

            # Multi-label grouped scoring: sum the scores of every class that
            # maps to the same canonical group.
            group_scores = {}
            for idx, score in enumerate(mean_scores):
                if idx in SPEECH_INDICES:
                    continue
                canonical = get_canonical_label(self._class_names[idx])
                if canonical in _IGNORE_GROUPS:
                    # Catch-all bucket for unmapped background classes.
                    continue
                group_scores[canonical] = group_scores.get(canonical, 0) + score

            # Emit EVERY group that clears the threshold, not only the
            # strongest one, so co-occurring sounds are all reported.
            for group, score in group_scores.items():
                if score >= self.confidence_threshold:
                    raw_events.append(AudioEvent(
                        label=group,
                        confidence=float(score),
                        start_sec=start_sec,
                        end_sec=end_sec,
                    ))

        return self._merge_events(raw_events)

    def _merge_events(self, events):
        """Merge overlapping or touching events that share a label.

        Because several labels can be emitted for the same window, events of
        different labels interleave once sorted by start time. Each event is
        therefore compared with the most recent merged event *of its own
        label*, not with whatever event happens to be last in the output.

        Args:
            events: Objects with ``label``, ``confidence``, ``start_sec`` and
                ``end_sec`` attributes. The list itself is not reordered, but
                the event that absorbs a merge is updated in place.

        Returns:
            A new list ordered by the start time of each merged event. The
            merged confidence is the maximum of its parts.
        """
        if not events:
            return []

        merged = []
        latest_by_label = {}
        for curr in sorted(events, key=lambda e: e.start_sec):
            prev = latest_by_label.get(curr.label)
            if prev is not None and curr.start_sec <= prev.end_sec:
                prev.end_sec = max(prev.end_sec, curr.end_sec)
                prev.confidence = max(prev.confidence, curr.confidence)
            else:
                merged.append(curr)
                latest_by_label[curr.label] = curr
        return merged
Ensure FFmpeg is installed and in your PATH.\nError: {result.stderr}" + ) + + return str(output_wav) diff --git a/cc_tool/audio/mapping.py b/cc_tool/audio/mapping.py new file mode 100644 index 0000000..bac2cf1 --- /dev/null +++ b/cc_tool/audio/mapping.py @@ -0,0 +1,147 @@ +""" +Semantic mapping for YAMNet classes to canonical CC labels. +Groups 521 specific classes into ~20 stable categories. +""" + +# Map of YAMNet Display Names → Canonical CC Labels. +# Keys must match the 'display_name' field from YAMNet's class map CSV exactly. +LABEL_GROUPS = { + # --- Impacts & Explosions --- + "Explosion": "[Gunshot/Explosion]", + "Gunshot, gunfire": "[Gunshot/Explosion]", + "Machine gun": "[Gunshot/Explosion]", + "Fusillade": "[Gunshot/Explosion]", + "Firecracker": "[Gunshot/Explosion]", + "Fireworks": "[Gunshot/Explosion]", + "Artillery fire": "[Gunshot/Explosion]", + "Cap gun": "[Gunshot/Explosion]", + "Burst, pop": "[Impact/Pop]", + "Boom": "[Impact/Pop]", + "Thud": "[Impact/Pop]", + "Slam": "[Impact/Pop]", + "Hammer": "[Metallic Impact]", + "Clang": "[Metallic Impact]", + "Clatter": "[Metallic Impact]", + "Dishes, pots, and pans": "[Metallic Impact]", + "Cutlery, silverware": "[Metallic Impact]", + "Glass": "[Glass Breaking]", + "Shatter": "[Glass Breaking]", + "Breaking": "[Glass Breaking]", + "Chink and clink": "[Glass Breaking]", + "Ding-dong": "[Bell/Chime]", + "Bell": "[Bell/Chime]", + "Church bell": "[Bell/Chime]", + "Cowbell": "[Bell/Chime]", + + # --- Vehicles & Mechanical --- + "Motor vehicle (road)": "[Vehicle]", + "Car": "[Vehicle]", + "Truck": "[Vehicle]", + "Bus": "[Vehicle]", + "Engine": "[Vehicle]", + "Motorcycle": "[Vehicle]", + "Race car, auto racing": "[Vehicle]", + "Car alarm": "[Car Alarm]", + "Horn": "[Horn/Honking]", + "Car passing by": "[Vehicle]", + "Vehicle horn, car horn, honking": "[Horn/Honking]", + "Bicycle": "[Mechanical]", + "Skateboard": "[Mechanical]", + "Tools": "[Mechanical]", + "Drill": "[Mechanical]", + "Chainsaw": "[Mechanical]", + 
"Power tool": "[Mechanical]", + + # --- Nature --- + "Rain": "[Rain]", + "Raindrop": "[Rain]", + "Heavy rain": "[Rain]", + "Wind": "[Wind]", + "Rustling leaves": "[Wind]", + "Thunderstorm": "[Thunder]", + "Thunder": "[Thunder]", + "Lightning": "[Thunder]", + "Ocean": "[Water/Ocean]", + "Water": "[Water/Ocean]", + "Stream": "[Water/Ocean]", + "Waterfall": "[Water/Ocean]", + "Fire": "[Fire]", + "Crackle": "[Fire]", + + # --- Animals --- + "Dog": "[Animal Sound]", + "Bark": "[Animal Sound]", + "Howl": "[Animal Sound]", + "Growling": "[Animal Sound]", + "Cat": "[Animal Sound]", + "Meow": "[Animal Sound]", + "Purr": "[Animal Sound]", + "Caterwaul": "[Animal Sound]", + "Roar": "[Animal Sound]", + "Animal": "[Animal Sound]", + "Bird": "[Bird Sound]", + "Crow": "[Bird Sound]", + "Chirp, tweet": "[Bird Sound]", + "Birdsong": "[Bird Sound]", + "Squawk": "[Bird Sound]", + # Snake / reptile sounds — this was the missing category + "Hiss": "[Snake/Hiss]", + "Snake": "[Snake/Hiss]", + "Rattle": "[Snake/Hiss]", + "Rattlesnake": "[Snake/Hiss]", + "Insect": "[Insect Sound]", + "Cricket": "[Insect Sound]", + "Mosquito": "[Insect Sound]", + + # --- Human Non-Speech --- + "Crowd": "[Crowd]", + "Cheering": "[Crowd]", + "Applause": "[Applause]", + "Clapping": "[Applause]", + "Laughter": "[Laughter]", + "Chuckle, chortle": "[Laughter]", + "Giggle": "[Laughter]", + "Crying, sobbing": "[Crying]", + "Whimper": "[Crying]", + "Screaming": "[Scream]", + "Shout": "[Shout]", + "Whistling": "[Whistle]", + "Walk, footsteps": "[Footsteps]", + "Run": "[Footsteps]", + "Gasp": "[Gasp]", + "Groan": "[Groan]", + "Snoring": "[Snoring]", + "Cough": "[Cough]", + "Sneeze": "[Sneeze]", + + # --- Emergency --- + "Siren": "[Siren]", + "Emergency vehicle": "[Siren]", + "Police car (siren)": "[Siren]", + "Ambulance (siren)": "[Siren]", + "Fire engine, fire truck (siren)": "[Siren]", + "Alarm": "[Alarm]", + "Smoke detector, smoke alarm": "[Alarm]", + "Beeping": "[Alarm]", + + # --- Music --- + "Music": "[Music]", 
def chunk_audio(waveform, sample_rate, window_sec=1.0, stride_sec=0.5):
    """Split audio into overlapping fixed-length windows for YAMNet.

    Args:
        waveform: 1-D sequence of samples (anything supporting ``len`` and
            slicing, e.g. a NumPy array).
        sample_rate: Samples per second of ``waveform``.
        window_sec: Length of each window in seconds.
        stride_sec: Distance between the starts of consecutive windows.

    Returns:
        A list of ``(start_sec, end_sec, samples)`` tuples. Only full windows
        are returned: a trailing remainder shorter than ``window_sec`` is
        dropped, and input shorter than one window yields an empty list.

    Raises:
        ValueError: If the window or the stride is shorter than one sample.
            A zero-sample stride would otherwise never advance.
    """
    window_samples = int(window_sec * sample_rate)
    stride_samples = int(stride_sec * sample_rate)
    if window_samples <= 0:
        raise ValueError("window_sec * sample_rate must be at least one sample")
    if stride_samples <= 0:
        raise ValueError("stride_sec * sample_rate must be at least one sample")

    last_start = len(waveform) - window_samples
    return [
        (
            start / sample_rate,
            (start + window_samples) / sample_rate,
            waveform[start:start + window_samples],
        )
        for start in range(0, last_start + 1, stride_samples)
    ]
def extract_frames_at(video_path: str, timestamps: List[float], output_dir: str = "outputs/frames") -> Dict[float, List[str]]:
    """Save frames from a window around each timestamp to catch late reactions.

    For every timestamp, frames are grabbed at -0.1 s, +0.2 s, +0.5 s,
    +0.8 s and +1.2 s relative to it (people often take about 0.5 s to
    react) and written as JPEG files into ``output_dir``.

    Args:
        video_path: Path of the video to read.
        timestamps: Event times in seconds. A repeated timestamp is decoded
            only once.
        output_dir: Directory for the JPEG files; created if missing.

    Returns:
        A dict mapping each timestamp to the paths of the frames that were
        actually written, in offset order. Frames that could not be read or
        written are left out.

    Raises:
        FileNotFoundError: If the video cannot be opened and ``video_path``
            does not exist.
        RuntimeError: If the video exists but OpenCV cannot open it.
    """
    offsets = (-0.1, 0.2, 0.5, 0.8, 1.2)

    os.makedirs(output_dir, exist_ok=True)
    video_name = Path(video_path).stem

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")
        raise RuntimeError(f"Could not open video: {video_path}")

    results: Dict[float, List[str]] = {}
    try:
        for ts in timestamps:
            if ts in results:
                # Several events can start at the same time; the frames (and
                # file names) would be identical, so do not decode them twice.
                continue

            frame_paths = []
            for offset in offsets:
                target_ts = max(0, ts + offset)  # never seek before the start
                cap.set(cv2.CAP_PROP_POS_MSEC, target_ts * 1000)
                success, frame = cap.read()
                if not success:
                    continue

                filename = f"{video_name}_{ts:0.2f}_offset_{offset:0.1f}.jpg"
                path = os.path.join(output_dir, filename)
                # Only report paths of files that were really written.
                if cv2.imwrite(path, frame):
                    frame_paths.append(path)

            results[ts] = frame_paths
    finally:
        # Release the capture even if reading or writing raised.
        cap.release()

    return results
class ReactionDetector:
    """Estimate whether a person visibly reacts across a short frame sequence.

    Landmarks are taken from MediaPipe Pose, falling back to FaceMesh, and
    the displacement of each later frame is measured against the first
    (pre-event) frame. Use as a context manager so the MediaPipe models are
    closed on exit.
    """

    def __init__(
        self,
        startle_threshold: float = 0.015,
        significant_threshold: float = 0.06,
        confidence_gain: float = 15.0,
    ):
        """Create the MediaPipe models and store the decision thresholds.

        Args:
            startle_threshold: Mean landmark displacement above which the
                result is ``"reaction/startle"``.
            significant_threshold: Mean landmark displacement above which the
                result is ``"significant_movement"``.
            confidence_gain: Factor applied to the displacement to obtain the
                confidence, which is then capped at 1.0.

        Displacements are in the units of the landmark x/y values returned by
        MediaPipe (presumably normalized image coordinates — confirm against
        the MediaPipe documentation).
        """
        self.startle_threshold = startle_threshold
        self.significant_threshold = significant_threshold
        self.confidence_gain = confidence_gain

        self.mp_pose = mp.solutions.pose
        self.mp_face_mesh = mp.solutions.face_mesh

        self.pose = self.mp_pose.Pose(static_image_mode=True, model_complexity=1)
        self.face_mesh = self.mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close both MediaPipe models; exceptions are not suppressed.
        self.pose.close()
        self.face_mesh.close()

    def _get_baseline_landmarks(self, frame_path) -> Optional[np.ndarray]:
        """Return the tracked landmark coordinates for one frame.

        Tries Pose first (shoulders and hips, indices 11, 12, 23, 24) and
        falls back to FaceMesh (nose and eyes, indices 1, 33, 263). Despite
        the name, this is used for every frame, not only the baseline.

        Returns:
            An ``(n, 2)`` array of x/y values, or ``None`` if the image
            cannot be read or no person/face is found. ``n`` is 4 for Pose
            and 3 for FaceMesh.
        """
        image = cv2.imread(frame_path)
        if image is None:
            return None
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # 1. Try Pose (full/half body): shoulders and hips.
        pose_res = self.pose.process(rgb)
        if pose_res.pose_landmarks:
            lm = pose_res.pose_landmarks.landmark
            return np.array([[lm[i].x, lm[i].y] for i in [11, 12, 23, 24]])

        # 2. Fall back to FaceMesh (head only): nose and eyes.
        face_res = self.face_mesh.process(rgb)
        if face_res.multi_face_landmarks:
            lm = face_res.multi_face_landmarks[0].landmark
            return np.array([[lm[i].x, lm[i].y] for i in [1, 33, 263]])

        return None

    def analyze_frames(self, event_idx: int, frame_paths: List[str]) -> ReactionResult:
        """Classify the movement in ``frame_paths`` relative to the first frame.

        Args:
            event_idx: Index of the audio event these frames belong to.
            frame_paths: Frame image paths in time order; the first one is
                the pre-event baseline.

        Returns:
            A ``ReactionResult``. With fewer than two frames, or when no
            landmarks are found in the baseline frame, the result is
            ``"none"`` with confidence 0.0 and no frame path. Otherwise the
            middle frame's path is reported, whatever the reaction type.
        """
        if len(frame_paths) < 2:
            return ReactionResult(event_idx, "none", 0.0, None)

        # Baseline: the state of the person before/at the start of the sound.
        baseline = self._get_baseline_landmarks(frame_paths[0])
        if baseline is None:
            return ReactionResult(event_idx, "none", 0.0, None)

        max_movement = 0.0
        for path in frame_paths[1:]:
            current = self._get_baseline_landmarks(path)
            # Shapes differ when one frame used Pose and the other FaceMesh;
            # such frames cannot be compared and are skipped.
            if current is not None and current.shape == baseline.shape:
                dist = np.linalg.norm(current - baseline, axis=1)
                max_movement = max(max_movement, float(np.mean(dist)))

        # Defaults: > 0.015 is a posture shift / startle, > 0.06 is a
        # significant movement; a displacement of 1/15 or more saturates the
        # confidence at 1.0.
        confidence = min(1.0, max_movement * self.confidence_gain)

        reaction_type = "none"
        if max_movement > self.significant_threshold:
            reaction_type = "significant_movement"
        elif max_movement > self.startle_threshold:
            reaction_type = "reaction/startle"

        return ReactionResult(
            event_index=event_idx,
            reaction_type=reaction_type,
            confidence=confidence,
            frame_path=frame_paths[len(frame_paths) // 2],
        )
def test_goal_2(video_path):
    """Run audio event detection, then look for a visual reaction to each event.

    Args:
        video_path: Path of the video to analyse.

    Returns:
        A list of ``(audio_event, reaction)`` pairs, one per detected audio
        event; empty when no audio event was found.
    """
    print("--- Starting Goal 2 Analysis ---")
    print(f"Video: {video_path}")

    # 1. Detect audio events (Goal 1), reusing an already extracted WAV.
    # extract_audio() names its output after the video's stem, i.e. only the
    # last extension is removed, so the cache lookup must do the same
    # ("my.clip.mp4" -> "my.clip.wav", not "my.wav").
    # NOTE: an existing WAV is reused even if the video changed since.
    stem = os.path.splitext(os.path.basename(video_path))[0]
    wav_path = f"outputs/audio/{stem}.wav"
    if not os.path.exists(wav_path):
        from cc_tool.audio.extractor import extract_audio
        wav_path = extract_audio(video_path)

    detector = SoundEventDetector(confidence_threshold=0.3)
    audio_events = detector.detect(wav_path)

    if not audio_events:
        print("No audio events found to analyze.")
        return []

    # 2. Extract frames around those events.
    print(f"Found {len(audio_events)} audio events. Extracting frames...")
    timestamps = [e.start_sec for e in audio_events]
    frames_map = extract_frames_at(video_path, timestamps)

    # 3. Analyze reactions.
    print("Analyzing visual reactions...")
    results = []
    with ReactionDetector() as reaction_detector:
        for idx, event in enumerate(audio_events):
            frame_paths = frames_map.get(event.start_sec, [])
            reaction = reaction_detector.analyze_frames(idx, frame_paths)
            results.append((event, reaction))

            print(f"\n[{event.start_sec:0.2f}s] Sound: {event.label}")
            print(f" Visual Reaction: {reaction.reaction_type} (Conf: {reaction.confidence:0.2f})")

    return results
result.min() >= -1.0 + + +def test_normalize_waveform_silent(): + """Silent waveform should return all zeros without division error.""" + waveform = np.zeros(1000, dtype=np.float32) + result = normalize_waveform(waveform) + assert np.all(result == 0.0) diff --git a/tests/test_audio/test_extractor.py b/tests/test_audio/test_extractor.py new file mode 100644 index 0000000..ee3b1a6 --- /dev/null +++ b/tests/test_audio/test_extractor.py @@ -0,0 +1,23 @@ +"""Tests for audio extractor — PR 1""" +import os +import struct +import wave +import pytest + +from cc_tool.audio.extractor import extract_audio + + +def _make_dummy_wav(path: str, duration_sec: float = 2.0, sample_rate: int = 16000) -> None: + """Create a minimal silent WAV file for testing.""" + n_samples = int(duration_sec * sample_rate) + with wave.open(path, "w") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) # 16-bit + wf.setframerate(sample_rate) + wf.writeframes(struct.pack("<" + "h" * n_samples, *([0] * n_samples))) + + +def test_extract_audio_missing_file(tmp_path): + """Should raise FileNotFoundError for non-existent input.""" + with pytest.raises(FileNotFoundError): + extract_audio(str(tmp_path / "nonexistent.mp4")) diff --git a/tests/test_decision/__init__.py b/tests/test_decision/__init__.py new file mode 100644 index 0000000..75dce4d --- /dev/null +++ b/tests/test_decision/__init__.py @@ -0,0 +1 @@ +# test packages diff --git a/tests/test_decision/test_combiner.py b/tests/test_decision/test_combiner.py new file mode 100644 index 0000000..d38799a --- /dev/null +++ b/tests/test_decision/test_combiner.py @@ -0,0 +1,69 @@ +"""Tests for the CC Decision Combiner — PR 3""" +import pytest + +from cc_tool.audio.models import AudioEvent +from cc_tool.vision.models import ReactionResult +from cc_tool.decision.combiner import combine_and_decide + + +def _make_event(label: str, conf: float, start: float = 0.0, end: float = 1.0) -> AudioEvent: + return AudioEvent(label=label, confidence=conf, start_sec=start, 
end_sec=end) + + +def _make_reaction(idx: int, conf: float) -> ReactionResult: + return ReactionResult(event_index=idx, reaction_type="head_turn", confidence=conf, frame_path=None) + + +def test_above_threshold_emits_cc(): + events = [_make_event("Honking", 0.9, 1.0, 2.0)] + reactions = [_make_reaction(0, 0.8)] + annotations = combine_and_decide(events, reactions, cc_threshold=0.55) + assert len(annotations) == 1 + assert annotations[0].text == "[honking]" + assert annotations[0].index == 1 + + +def test_below_threshold_no_cc(): + # audio=0.3, reaction=0.2 → final = 0.6*0.3 + 0.4*0.2 = 0.18 + 0.08 = 0.26 + events = [_make_event("Rain", 0.3, 5.0, 6.0)] + reactions = [_make_reaction(0, 0.2)] + annotations = combine_and_decide(events, reactions, cc_threshold=0.55) + assert len(annotations) == 0 + + +def test_no_reactions_uses_audio_weight_only(): + # audio=0.9, reaction=0 → final = 0.6*0.9 = 0.54 (below 0.55) + events = [_make_event("Gunshot", 0.9, 2.0, 3.0)] + annotations = combine_and_decide(events, [], cc_threshold=0.55) + assert len(annotations) == 0 # 0.54 < 0.55 + + +def test_no_reactions_above_threshold_with_lower_gate(): + # audio=1.0, reaction=0 → final = 0.6 >= 0.5 + events = [_make_event("Explosion", 1.0, 3.0, 4.0)] + annotations = combine_and_decide(events, [], cc_threshold=0.5) + assert len(annotations) == 1 + + +def test_multiple_events_mixed(): + events = [ + _make_event("Honking", 0.9, 1.0, 2.0), # should pass + _make_event("Wind", 0.2, 5.0, 6.0), # should fail + _make_event("Alarm", 0.8, 10.0, 11.0), # should pass + ] + reactions = [ + _make_reaction(0, 0.7), # strong reaction + _make_reaction(1, 0.1), # weak reaction + _make_reaction(2, 0.0), # no reaction + ] + annotations = combine_and_decide(events, reactions, cc_threshold=0.55) + assert len(annotations) == 2 + assert annotations[0].text == "[honking]" + assert annotations[1].text == "[alarm]" + + +def test_sequential_indices(): + events = [_make_event("Laughter", 0.8, i, i+1) for i in range(3)] + 
reactions = [_make_reaction(i, 0.7) for i in range(3)] + annotations = combine_and_decide(events, reactions, cc_threshold=0.4) + assert [ann.index for ann in annotations] == [1, 2, 3] diff --git a/tests/test_decision/test_srt_writer.py b/tests/test_decision/test_srt_writer.py new file mode 100644 index 0000000..752c1cf --- /dev/null +++ b/tests/test_decision/test_srt_writer.py @@ -0,0 +1,47 @@ +"""Tests for SRT writer — PR 3""" +import os +import pytest + +from cc_tool.decision.models import CCAnnotation +from cc_tool.export.srt_writer import write_srt, _sec_to_srt_time + + +def _make_ann(idx: int, start: float, end: float, text: str) -> CCAnnotation: + return CCAnnotation( + index=idx, start_sec=start, end_sec=end, text=text, + audio_conf=0.8, reaction_conf=0.7, final_score=0.74, + ) + + +def test_sec_to_srt_time_basic(): + assert _sec_to_srt_time(0.0) == "00:00:00,000" + assert _sec_to_srt_time(61.5) == "00:01:01,500" + assert _sec_to_srt_time(3661.123) == "01:01:01,123" + + +def test_write_srt_creates_file(tmp_path): + anns = [ + _make_ann(1, 4.2, 5.1, "[honking]"), + _make_ann(2, 12.0, 13.5, "[laughter]"), + ] + out = str(tmp_path / "output.srt") + write_srt(anns, out) + assert os.path.exists(out) + + +def test_write_srt_content(tmp_path): + anns = [_make_ann(1, 4.2, 5.1, "[honking]")] + out = str(tmp_path / "test.srt") + write_srt(anns, out) + + content = open(out, encoding="utf-8").read() + assert "1" in content + assert "00:00:04,200 --> 00:00:05,100" in content + assert "[honking]" in content + + +def test_write_srt_empty(tmp_path): + out = str(tmp_path / "empty.srt") + write_srt([], out) + assert os.path.exists(out) + assert open(out).read().strip() == ""