Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 159 additions & 27 deletions noise-canceller.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@
from livekit.plugins.ai_coustics import EnhancerModel
from dotenv import load_dotenv

SAMPLERATE = 48000
DEFAULT_SAMPLERATE = 48_000
CHUNK_DURATION_MS = 10 # 10ms chunks
SAMPLES_PER_CHUNK = int(SAMPLERATE * CHUNK_DURATION_MS / 1000)
CHANNELS = 1

load_dotenv()
Expand Down Expand Up @@ -188,6 +187,8 @@ async def entrypoint(ctx: JobContext):
filter_key=fc["filter"],
use_webrtc=fc["use_webrtc"],
silent=silent,
direct=_config.get("direct", False),
sample_rate=_config.get("sample_rate", DEFAULT_SAMPLERATE),
)

orig_stream: SttStream | None = None
Expand Down Expand Up @@ -336,6 +337,8 @@ def __init__(
filter_key: str,
use_webrtc=False,
silent=False,
direct=False,
sample_rate: int = DEFAULT_SAMPLERATE,
):
self.room = room
self.noise_filter = noise_filter
Expand All @@ -344,6 +347,9 @@ def __init__(
self.processed_frames: list[bytes] = []
self.original_audio: np.ndarray | None = None
self.silent = silent
self.direct = direct
self.sample_rate = sample_rate
self.samples_per_chunk = int(sample_rate * CHUNK_DURATION_MS / 1000)

async def process_file(
self,
Expand Down Expand Up @@ -416,6 +422,12 @@ async def process_file(
original_stt=original_stt,
processed_stt=processed_stt,
)
elif self.direct:
await self._process_direct(
audio_data,
progress=progress,
bar_ids=bar_ids,
)
else:
await self._process_with_noise_cancellation(
audio_data,
Expand All @@ -442,8 +454,8 @@ async def _process_with_webrtc_apm(
processed_stt: "SttStream | None" = None,
):
"""Process audio data using WebRTC AudioProcessingModule"""
chunk_count = len(audio_data) // SAMPLES_PER_CHUNK
if len(audio_data) % SAMPLES_PER_CHUNK != 0:
chunk_count = len(audio_data) // self.samples_per_chunk
if len(audio_data) % self.samples_per_chunk != 0:
chunk_count += 1

if not self.silent:
Expand Down Expand Up @@ -483,21 +495,21 @@ async def _process_with_webrtc_apm(
prog.update(tid, total=chunk_count)

for i in range(chunk_count):
start_idx = i * SAMPLES_PER_CHUNK
end_idx = min(start_idx + SAMPLES_PER_CHUNK, len(audio_data))
start_idx = i * self.samples_per_chunk
end_idx = min(start_idx + self.samples_per_chunk, len(audio_data))
chunk = audio_data[start_idx:end_idx]

if len(chunk) < SAMPLES_PER_CHUNK:
if len(chunk) < self.samples_per_chunk:
chunk = np.concatenate(
[
chunk,
np.zeros(SAMPLES_PER_CHUNK - len(chunk), dtype=np.int16),
np.zeros(self.samples_per_chunk - len(chunk), dtype=np.int16),
]
)

audio_frame = rtc.AudioFrame(
data=chunk.tobytes(),
sample_rate=SAMPLERATE,
sample_rate=self.sample_rate,
num_channels=CHANNELS,
samples_per_channel=len(chunk),
)
Expand All @@ -512,9 +524,9 @@ async def _process_with_webrtc_apm(
if processed_stt is not None:
processed_frame = rtc.AudioFrame(
data=processed_bytes,
sample_rate=SAMPLERATE,
sample_rate=self.sample_rate,
num_channels=CHANNELS,
samples_per_channel=SAMPLES_PER_CHUNK,
samples_per_channel=self.samples_per_chunk,
)
processed_stt.push_frame(processed_frame)

Expand All @@ -530,6 +542,95 @@ async def _process_with_webrtc_apm(
f"Successfully processed {len(self.processed_frames)} frames with WebRTC APM"
)

async def _process_direct(
    self,
    audio_data,
    progress=None,
    bar_ids: dict[str, int] | None = None,
) -> None:
    """Process audio directly through the FrameProcessor, bypassing the SFU.

    This avoids Opus encode/decode and produces output identical to direct
    plugin FFI processing. Useful for bit-exact comparison testing.

    Args:
        audio_data: 1-D sample buffer sliced into fixed-size chunks; padded
            with ``np.int16`` zeros, so presumably an int16 numpy array —
            TODO confirm against ``_load_audio_file``.
        progress: Optional existing progress display to render into. When
            ``None``, a new one is created (or a silent stub when
            ``self.silent`` is set).
        bar_ids: Optional mapping of caller-owned progress task ids to
            advance instead of creating a new task.

    Side effects:
        Appends one processed frame's bytes per chunk to
        ``self.processed_frames``.
    """
    # Ceiling division: a trailing partial chunk still counts as one chunk.
    chunk_count = len(audio_data) // self.samples_per_chunk
    if len(audio_data) % self.samples_per_chunk != 0:
        chunk_count += 1

    # Set up credentials on the FrameProcessor so the underlying Enhancer
    # can authenticate with the ai-coustics service.
    # NOTE(review): a room-join token is minted even though no room is
    # actually joined here — presumably the plugin requires a valid grant
    # for the target room name; verify against the plugin's auth flow.
    token = (
        api.AccessToken(
            os.environ["LIVEKIT_API_KEY"],
            os.environ["LIVEKIT_API_SECRET"],
        )
        .with_identity("noise-canceller-direct")
        .with_grants(api.VideoGrants(room_join=True, room=self.room.name))
        .to_jwt()
    )
    # These are private plugin hooks normally invoked by the SDK when a
    # track is published; called manually here because we bypass the SFU.
    self.noise_filter._on_credentials_updated(
        token=token, url=os.environ["LIVEKIT_URL"]
    )
    self.noise_filter._on_stream_info_updated(
        room_name=self.room.name,
        participant_identity="direct-processing",
        publication_sid="direct-sid",
    )

    # Build our own progress display only when the caller did not pass one;
    # otherwise wrap theirs in nullcontext so the `with` below does not
    # close a display the caller still owns.
    if progress is None:
        progress_class = NullProgress if self.silent else Progress
        ctx = progress_class(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            console=console,
        )
    else:
        ctx = nullcontext(progress)

    ids = bar_ids or {}

    with ctx as prog:
        if not ids:
            # No caller-supplied bars: create our own task.
            ids["nc"] = prog.add_task(
                " 🎤 Processing directly (no SFU)", total=chunk_count
            )
        else:
            # Reuse the caller's bars; just (re)set their totals.
            for tid in ids.values():
                prog.update(tid, total=chunk_count)

        for i in range(chunk_count):
            start_idx = i * self.samples_per_chunk
            # min() guards the final, possibly short, chunk.
            end_idx = min(start_idx + self.samples_per_chunk, len(audio_data))
            chunk = audio_data[start_idx:end_idx]

            # Zero-pad the last chunk so every frame has a uniform size.
            if len(chunk) < self.samples_per_chunk:
                chunk = np.concatenate(
                    [
                        chunk,
                        np.zeros(self.samples_per_chunk - len(chunk), dtype=np.int16),
                    ]
                )

            audio_frame = rtc.AudioFrame(
                data=chunk.tobytes(),
                sample_rate=self.sample_rate,
                num_channels=CHANNELS,
                samples_per_channel=len(chunk),
            )

            # Synchronous per-frame call into the plugin's private
            # processing entry point — the core of the "direct" path.
            processed_frame = self.noise_filter._process(audio_frame)
            self.processed_frames.append(processed_frame.data)

            for tid in ids.values():
                prog.update(tid, advance=1)

    logger.info(
        "Direct processing: %d frames processed", len(self.processed_frames)
    )

async def _process_with_noise_cancellation(
self,
audio_data,
Expand All @@ -550,8 +651,8 @@ async def _process_with_noise_cancellation(
|
CapturingAudioInput
"""
chunk_count = len(audio_data) // SAMPLES_PER_CHUNK
if len(audio_data) % SAMPLES_PER_CHUNK != 0:
chunk_count = len(audio_data) // self.samples_per_chunk
if len(audio_data) % self.samples_per_chunk != 0:
chunk_count += 1

publisher_room: rtc.Room | None = None
Expand All @@ -572,7 +673,7 @@ async def _process_with_noise_cancellation(
await publisher_room.connect(os.environ["LIVEKIT_URL"], publisher_token)
logger.debug("Publisher connected to room %s", self.room.name)

file_source = FileAudioSource(audio_data, SAMPLERATE, CHANNELS)
file_source = FileAudioSource(audio_data, self.sample_rate, CHANNELS)
input_track = rtc.LocalAudioTrack.create_audio_track(
"raw-input", file_source
)
Expand Down Expand Up @@ -601,7 +702,7 @@ async def _process_with_noise_cancellation(
room_options=room_io.RoomOptions(
audio_input=room_io.AudioInputOptions(
noise_cancellation=self.noise_filter,
sample_rate=SAMPLERATE,
sample_rate=self.sample_rate,
num_channels=CHANNELS,
frame_size_ms=CHUNK_DURATION_MS,
),
Expand Down Expand Up @@ -718,23 +819,23 @@ async def _feed_audio_data_with_progress(
):
"""Feed audio data to the source with precise timing and progress updates."""
ids = task_ids if isinstance(task_ids, list) else [task_ids]
chunk_duration = SAMPLES_PER_CHUNK / SAMPLERATE
chunk_duration = self.samples_per_chunk / self.sample_rate
loop = asyncio.get_running_loop()
start_time = loop.time()

for i in range(chunk_count):
start_idx = i * SAMPLES_PER_CHUNK
end_idx = min(start_idx + SAMPLES_PER_CHUNK, len(audio_data))
start_idx = i * self.samples_per_chunk
end_idx = min(start_idx + self.samples_per_chunk, len(audio_data))
chunk = audio_data[start_idx:end_idx]

if len(chunk) < SAMPLES_PER_CHUNK:
if len(chunk) < self.samples_per_chunk:
chunk = np.concatenate(
[chunk, np.zeros(SAMPLES_PER_CHUNK - len(chunk), dtype=np.int16)]
[chunk, np.zeros(self.samples_per_chunk - len(chunk), dtype=np.int16)]
)

audio_frame = rtc.AudioFrame(
data=chunk.tobytes(),
sample_rate=SAMPLERATE,
sample_rate=self.sample_rate,
num_channels=CHANNELS,
samples_per_channel=len(chunk),
)
Expand Down Expand Up @@ -787,11 +888,11 @@ def _load_audio_file(self, input_path: Path):
audio_array = audio_data

# Resample to 48kHz mono if needed
if sample_rate != SAMPLERATE or channels != CHANNELS:
if sample_rate != self.sample_rate or channels != CHANNELS:
audio_array = self._resample_audio(audio_array, sample_rate, channels)
if not self.silent:
console.print(
f"🔄 [yellow]Resampled to: {SAMPLERATE}Hz, {CHANNELS} channel(s)[/yellow]"
f"🔄 [yellow]Resampled to: {self.sample_rate}Hz, {CHANNELS} channel(s)[/yellow]"
)
console.print()

Expand Down Expand Up @@ -822,10 +923,10 @@ def _resample_audio(self, audio_array, original_rate, original_channels):
audio_array = stereo.mean(axis=1).astype(np.int16)

# Resample if needed
if original_rate != SAMPLERATE:
if original_rate != self.sample_rate:
resampler = rtc.AudioResampler(
input_rate=original_rate,
output_rate=SAMPLERATE,
output_rate=self.sample_rate,
num_channels=1,
quality=rtc.AudioResamplerQuality.VERY_HIGH,
)
Expand Down Expand Up @@ -863,7 +964,7 @@ def _save_output(self, output_path: Path):
with wave.open(str(output_path), "wb") as wav_file:
wav_file.setnchannels(CHANNELS)
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(SAMPLERATE)
wav_file.setframerate(self.sample_rate)

for frame_data in self.processed_frames:
wav_file.writeframes(frame_data)
Expand All @@ -872,7 +973,7 @@ def _save_output(self, output_path: Path):
class FileAudioSource(rtc.AudioSource):
"""Custom audio source that streams from file data"""

def __init__(self, audio_data, sample_rate=SAMPLERATE, num_channels=CHANNELS):
def __init__(self, audio_data, sample_rate=DEFAULT_SAMPLERATE, num_channels=CHANNELS):
super().__init__(sample_rate, num_channels)
self.audio_data = audio_data

Expand Down Expand Up @@ -1243,9 +1344,38 @@ def main():
help="LiveKit Inference STT model (default: deepgram/nova-3:en). "
"Format: provider/model[:language]",
)
parser.add_argument(
"--sample-rate",
type=int,
default=DEFAULT_SAMPLERATE,
help=f"Output sample rate in Hz (default: {DEFAULT_SAMPLERATE}). "
"Input audio is resampled to this rate before processing.",
)
parser.add_argument(
"--direct",
action="store_true",
help="Process audio directly through the plugin's FrameProcessor "
"without routing through the LiveKit SFU. Bypasses Opus "
"encode/decode so output is bit-exact with direct FFI processing. "
"Only compatible with ai-coustics filters (aic-quail-l, aic-quail-vfl).",
)

args = parser.parse_args()

# --direct is only meaningful for ai-coustics FrameProcessor filters.
_AIC_FILTERS = {"aic-quail-l", "aic-quail-vfl"}
if args.direct:
if args.filter == "all":
parser.error(
"--direct cannot be used with --filter all (it only supports "
"ai-coustics filters: aic-quail-l, aic-quail-vfl)"
)
if args.filter not in _AIC_FILTERS:
parser.error(
f"--direct is only supported with ai-coustics filters "
f"({', '.join(sorted(_AIC_FILTERS))}), not '{args.filter}'"
)

# Setup console for silent mode
if args.silent:
console = NullConsole()
Expand Down Expand Up @@ -1339,6 +1469,8 @@ def main():
"silent": args.silent,
"transcript": args.transcript,
"stt": args.stt,
"direct": args.direct,
"sample_rate": args.sample_rate,
}
)

Expand Down