diff --git a/examples/models/diar_streaming_sortformer/CMakeLists.txt b/examples/models/diar_streaming_sortformer/CMakeLists.txt
new file mode 100644
index 00000000000..36364eb5f34
--- /dev/null
+++ b/examples/models/diar_streaming_sortformer/CMakeLists.txt
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.24)
+project(diar_streaming_sortformer_runner)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+
+# Let files say "include <executorch/...>"
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+
+# Need this for gflags
+set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
+find_package(gflags REQUIRED)
+
+# Find executorch libraries
+list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
+executorch_target_link_options_shared_lib(executorch)
+
+set(link_libraries executorch gflags)
+
+# Common ops for all builds
+list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas)
+executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)
+
+# CPU-only Windows builds need quantized and custom ops
+if(NOT EXECUTORCH_BUILD_CUDA AND MSVC)
+  list(APPEND link_libraries quantized_ops_lib custom_ops)
+  executorch_target_link_options_shared_lib(quantized_ops_lib)
+  executorch_target_link_options_shared_lib(custom_ops)
+endif()
+
+# XNNPACK
+if(TARGET xnnpack_backend)
+  set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod)
+  if(TARGET kleidiai)
+    list(APPEND xnnpack_backend_libs kleidiai)
+  endif()
+  list(APPEND link_libraries ${xnnpack_backend_libs})
+  executorch_target_link_options_shared_lib(xnnpack_backend)
+endif()
+
+# Needed for cpuinfo where it uses android specific log lib
+if(ANDROID)
+  list(APPEND link_libraries log)
+endif()
+
+# Add the required ExecuTorch extensions
+list(
+  APPEND
+  link_libraries
+  extension_llm_runner
+  extension_module
+  extension_data_loader
+  extension_tensor
+  extension_flat_tensor
+)
+
+# Link CUDA backend (optional for future delegate use)
+if(EXECUTORCH_BUILD_CUDA)
+  find_package(CUDAToolkit REQUIRED)
+  list(APPEND link_libraries aoti_cuda_backend)
+  if(NOT MSVC)
+    executorch_target_link_options_shared_lib(aoti_cuda_backend)
+  endif()
+endif()
+
+if(EXECUTORCH_BUILD_METAL)
+  list(APPEND link_libraries metal_backend)
+  executorch_target_link_options_shared_lib(metal_backend)
+endif()
+
+add_executable(diar_streaming_sortformer_runner main.cpp)
+if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
+  target_link_options_gc_sections(diar_streaming_sortformer_runner)
+  if(NOT APPLE AND NOT MSVC)
+    target_link_options(diar_streaming_sortformer_runner PRIVATE "LINKER:-s")
+  endif()
+endif()
+
+target_include_directories(
+  diar_streaming_sortformer_runner PUBLIC ${_common_include_directories}
+)
+target_link_libraries(diar_streaming_sortformer_runner PUBLIC ${link_libraries})
+target_compile_options(
+  diar_streaming_sortformer_runner PUBLIC ${_common_compile_options}
+)
+
+# On Windows, copy required DLLs to the executable directory
+if(MSVC AND EXECUTORCH_BUILD_CUDA)
+  add_custom_command(
+    TARGET diar_streaming_sortformer_runner
+    POST_BUILD
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:aoti_cuda_shims>
+            $<TARGET_FILE_DIR:diar_streaming_sortformer_runner>
+    COMMENT
"Copying aoti_cuda_shims.dll to diar_streaming_sortformer_runner directory" + ) +endif() + diff --git a/examples/models/diar_streaming_sortformer/CMakePresets.json b/examples/models/diar_streaming_sortformer/CMakePresets.json new file mode 100644 index 00000000000..d33cfb77f77 --- /dev/null +++ b/examples/models/diar_streaming_sortformer/CMakePresets.json @@ -0,0 +1,45 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "diar_streaming_sortformer-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/diar_streaming_sortformer", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "diar_streaming_sortformer-cpu", + "displayName": "Streaming Sortformer diarization runner (CPU)", + "inherits": ["diar_streaming_sortformer-base"] + } + ], + "buildPresets": [ + { + "name": "diar_streaming_sortformer-cpu", + "displayName": "Build Streaming Sortformer diarization runner (CPU)", + "configurePreset": "diar_streaming_sortformer-cpu", + "targets": ["diar_streaming_sortformer_runner"] + } + ], + "workflowPresets": [ + { + "name": "diar_streaming_sortformer-cpu", + "displayName": "Configure and build Streaming Sortformer diarization runner (CPU)", + "steps": [ + { + "type": "configure", + "name": "diar_streaming_sortformer-cpu" + }, + { + "type": "build", + "name": "diar_streaming_sortformer-cpu" + } + ] + } + ] +} + diff --git a/examples/models/diar_streaming_sortformer/README.md b/examples/models/diar_streaming_sortformer/README.md new file mode 100644 index 00000000000..12c62346ce3 --- /dev/null +++ b/examples/models/diar_streaming_sortformer/README.md @@ -0,0 +1,98 @@ +# Streaming Sortformer diarization (ExecuTorch C++) + +This example exports `nvidia/diar_streaming_sortformer_4spk-v2.1` to an ExecuTorch `.pte` (portable ops only), +then runs offline **streaming-style** diarization in C++ (chunk-by-chunk) using `model_step`. + +## Export (portable ops only) + +From `executorch/`: + +```bash +python examples/models/diar_streaming_sortformer/export_diar_streaming_sortformer.py \ + --output-dir ./sortformer_diar_exports +``` + +Artifacts: +- `./sortformer_diar_exports/model.pte` + +## Build + run the C++ runner + +This uses the same pattern as `examples/models/parakeet`. + +```bash +# Build ExecuTorch + this runner +make diar-streaming-sortformer-cpu + +./cmake-out/examples/models/diar_streaming_sortformer/diar_streaming_sortformer_runner \ + --model_path ./sortformer_diar_exports/model.pte \ + --audio_path /path/to/mono_16khz.wav \ + --threshold 0.5 +``` + +Notes: +- The WAV loader expects **mono** audio and does **not** resample. +- This runner implements a simplified cache update (keeps the most recent cache frames) and does not implement + NeMo's speaker-cache compression logic. + + +## Notes about difference between simplified cache and NeMo: +What “cache update is simplified” means (in our C++ runner) + +- We only keep embedding buffers fifo and spkcache and update them as plain “recent history” arrays. See examples/models/ + diar_streaming_sortformer/main.cpp:277 and the update block at examples/models/diar_streaming_sortformer/main.cpp:384. +- When a new chunk arrives, we append chunk_embs into fifo. If fifo would overflow, we “pop” the oldest frames and append them + into spkcache (and cap spkcache by just dropping the oldest frames). See examples/models/diar_streaming_sortformer/main.cpp:389. 
+- We do not track the extra streaming state NeMo uses to manage and compress its caches:
+  - no spkcache_preds / fifo_preds (posteriors aligned to cached frames),
+  - no mean_sil_emb / n_sil_frames (silence profile),
+  - no speaker-cache compression based on per-speaker importance scores.
+- Practical impact: `spkcache` becomes "whatever frames happened most recently", not a balanced, speaker-representative
+  memory. This usually hurts long-form diarization and speaker re-entry stability.
+
+### What NeMo does (the missing pieces)
+
+NeMo's real streaming update is in `nemo/collections/asr/modules/sortformer_modules.py:395`:
+
+- It maintains a richer streaming state: spkcache, spkcache_lengths, spkcache_preds, fifo, fifo_lengths, fifo_preds,
+  plus mean_sil_emb / n_sil_frames (`init_streaming_state` is at nemo/collections/asr/modules/sortformer_modules.py:360).
+- Every step, it refreshes fifo_preds from the newly computed preds for the [spkcache + fifo + chunk] sequence
+  (`streaming_update` shows the slice logic clearly at nemo/collections/asr/modules/sortformer_modules.py:562).
+- When the FIFO overflows, it:
+  - pops frames from the FIFO and appends them into the speaker cache,
+  - updates the silence profile based on popped-frame posteriors (`_get_silence_profile` at
+    nemo/collections/asr/modules/sortformer_modules.py:636),
+  - and, if the speaker cache is too large, runs `_compress_spkcache` to select a fixed-size set of "important" frames
+    per speaker (using log-scores, thresholds, boosts, forced silence frames, and topk/sort/gather) at
+    nemo/collections/asr/modules/sortformer_modules.py:838.
+
+### Next steps to match NeMo (portable-only ExecuTorch inference)
+
+1. Export more info from `model_step` so C++ can implement NeMo's cache logic:
+
+- Add at least fifo_preds (posteriors for the existing FIFO frames, padded to fifo_len) as an additional output. Today
+  we only export chunk_preds/chunk_embs/chunk_pred_len
+  (examples/models/diar_streaming_sortformer/export_diar_streaming_sortformer.py:169).
+- Potentially also export a few more constants (as `constant_methods`) used by `_compress_spkcache`: sil_threshold,
+  pred_score_threshold, spkcache_sil_frames_per_spk, scores_boost_latest, strong_boost_rate, weak_boost_rate,
+  min_pos_scores_rate, etc.
+
+2. Track the full NeMo streaming state in C++:
+
+- Add buffers for fifo_preds and spkcache_preds (and their lengths), plus mean_sil_emb and n_sil_frames.
+
+3. Implement NeMo's cache update + compression in C++:
+
+- Port the `streaming_update_async` (or the simpler `streaming_update`) logic for batch=1, including:
+  - the exact pop-out length rules,
+  - the silence profile update (`_get_silence_profile`),
+  - speaker-cache compression (`_compress_spkcache` and helpers).
+- This is the main reason the cache update stays in C++: `_compress_spkcache` depends on ops like topk/sort that are
+  often painful to guarantee in "portable kernels only" graphs.
+
+4. Validation:
+
+- Write a small "step-by-step parity" harness: run NeMo's `forward_streaming_step()` and the C++ runner on the same
+  audio/chunking parameters, and compare per-step chunk_preds and the final accumulated posteriors (before any VAD
+  postprocessing).
+
+A faithful C++ port of `streaming_update_async` also requires pinning down exactly which extra tensors `model_step`
+must output (names and fixed shapes), so the model does not have to run twice per step.
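+
+### Frame-rate arithmetic (worked example)
+
+Each diar frame covers `window_stride * subsampling_factor` seconds of audio; both values are exported as constant
+methods, so the runner reads the real numbers at runtime. Assuming the typical NeMo values of `window_stride = 0.01` s
+and a FastConformer `subsampling_factor` of 8, one diar frame is 0.08 s. With the default `--chunk-len 6`, each
+streaming step then emits posteriors for 6 * 0.08 = 0.48 s of audio, and the runner feeds the model
+(1 + 6 + 7) * 8 = 112 feature frames per step (left context + chunk + right context, times the subsampling factor).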
diff --git a/examples/models/diar_streaming_sortformer/export_diar_streaming_sortformer.py b/examples/models/diar_streaming_sortformer/export_diar_streaming_sortformer.py new file mode 100644 index 00000000000..f996d7d8fb0 --- /dev/null +++ b/examples/models/diar_streaming_sortformer/export_diar_streaming_sortformer.py @@ -0,0 +1,404 @@ +""" +Export nvidia/diar_streaming_sortformer_4spk-v2.1 to ExecuTorch (portable ops only). + +This exports two runtime methods into a single `model.pte`: + - `preprocessor(audio_1d, audio_len) -> (features, features_len)` + where `features` is time-major: [1, T_feat, feat_dim] + - `model_step(chunk, chunk_len, spkcache, spkcache_len, fifo, fifo_len, lc, rc) + -> (chunk_preds, chunk_embs, chunk_pred_len)` + where: + - `chunk` is time-major log-mel features for one step: [1, T_chunk, feat_dim] + - `spkcache` / `fifo` are embedding caches: [1, L, emb_dim] + - `lc` / `rc` are left/right context in diar frames (post-subsampling) + - `chunk_preds` is fixed-size [1, chunk_len_max, n_spk] (padded with zeros) + - `chunk_embs` is fixed-size [1, chunk_len_max, emb_dim] (padded with zeros) + - `chunk_pred_len` is the number of valid diar frames in this step. + +The exported program also includes constant metadata methods (via `constant_methods`) so C++ +doesn't need to hardcode model parameters. +""" + +from __future__ import annotations + +import argparse +import os +from dataclasses import dataclass +from typing import Dict, Tuple + +import torch +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass +from torch.export import Dim, export + + +@dataclass(frozen=True) +class StreamingConfig: + spkcache_len: int + fifo_len: int + spkcache_update_period: int + chunk_len: int + chunk_left_context: int + chunk_right_context: int + + +def _load_model(model_name_or_path: str): + try: + from nemo.collections.asr.models.sortformer_diar_models import ( + SortformerEncLabelModel, + ) + except Exception as e: # pragma: no cover + raise RuntimeError( + "Failed to import NeMo. Install NeMo (or set PYTHONPATH) to use this exporter." + ) from e + + if model_name_or_path.endswith(".nemo"): + model = SortformerEncLabelModel.restore_from( + restore_path=model_name_or_path, map_location="cpu", strict=False + ) + else: + model = SortformerEncLabelModel.from_pretrained( + model_name_or_path, map_location="cpu" + ) + model.eval() + model.freeze() + + # Streaming-safe preprocessor defaults (match NeMo streaming examples). + if hasattr(model, "preprocessor") and hasattr(model.preprocessor, "featurizer"): + model.preprocessor.featurizer.dither = 0.0 + model.preprocessor.featurizer.pad_to = 0 + + return model + + +class PreprocessorWrapper(torch.nn.Module): + def __init__(self, preprocessor): + super().__init__() + self.preprocessor = preprocessor + + def forward( + self, audio: torch.Tensor, length: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # NeMo preprocessors expect (B, T). Export contract uses 1D audio. + audio_signal = audio.unsqueeze(0) + feats, feat_len = self.preprocessor(input_signal=audio_signal, length=length) + # Convert to time-major (B, T, F) for easier slicing in C++. + feats = feats.transpose(1, 2) + return feats, feat_len + + +class SortformerStreamingStep(torch.nn.Module): + """One streaming inference step of Sortformer, with fixed-shape outputs. + + This wrapper keeps the neural network in ExecuTorch and leaves cache update logic to C++. 
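+
+    With the exporter's default streaming parameters (chunk_len=6, chunk_left_context=1,
+    chunk_right_context=7), the chunk input spans (1 + 6 + 7) * subsampling_factor feature
+    frames per step; for the typical FastConformer subsampling factor of 8, that is 112 frames.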
+ """ + + def __init__( + self, + diar_model, + cfg: StreamingConfig, + ): + super().__init__() + self.diar_model = diar_model + self.cfg = cfg + + self.spkcache_max_len = int(cfg.spkcache_len) + self.fifo_max_len = int(cfg.fifo_len) + self.chunk_len_max = int(cfg.chunk_len) + self.chunk_total_diar = int( + cfg.chunk_left_context + cfg.chunk_len + cfg.chunk_right_context + ) + self.total_max_len = int( + self.spkcache_max_len + self.fifo_max_len + self.chunk_total_diar + ) + + def _pack_spkcache_fifo_chunk( + self, + spkcache: torch.Tensor, + spkcache_len: torch.Tensor, + fifo: torch.Tensor, + fifo_len: torch.Tensor, + chunk_pre_encode: torch.Tensor, + chunk_pre_encode_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # All tensors are batch=1 with fixed max shapes: + # spkcache: [1, spkcache_max_len, emb_dim], spkcache_len: [1] + # fifo: [1, fifo_max_len, emb_dim], fifo_len: [1] + # chunk: [1, chunk_total_diar, emb_dim], chunk_len: [1] + # + # We need a packed tensor where *valid* frames are contiguous: + # [spkcache[:spk_len], fifo[:fifo_len], chunk[:chunk_len], padding...] + + combined = torch.cat([spkcache, fifo, chunk_pre_encode], dim=1) + emb_dim = combined.size(-1) + + # Length scalars (0-dim tensors) + spk_len0 = spkcache_len[0] + fifo_len0 = fifo_len[0] + chunk_len0 = chunk_pre_encode_len[0] + total_len0 = spk_len0 + fifo_len0 + chunk_len0 + + p = torch.arange(self.total_max_len, device=combined.device, dtype=torch.long) + idx = torch.zeros_like(p) + + # spkcache part + idx = torch.where(p < spk_len0, p, idx) + + # fifo part (packed immediately after spk_len0) + fifo_mask = (p >= spk_len0) & (p < (spk_len0 + fifo_len0)) + fifo_src = (p - spk_len0) + self.spkcache_max_len + idx = torch.where(fifo_mask, fifo_src, idx) + + # chunk part (packed immediately after spk_len0 + fifo_len0) + chunk_mask = (p >= (spk_len0 + fifo_len0)) & (p < total_len0) + chunk_src = (p - spk_len0 - fifo_len0) + self.spkcache_max_len + self.fifo_max_len + idx = torch.where(chunk_mask, chunk_src, idx) + + idx_exp = idx.view(1, -1, 1).expand(1, -1, emb_dim) + packed = torch.gather(combined, dim=1, index=idx_exp) + + valid = (p < total_len0).view(1, -1, 1) + packed = packed * valid + + total_len = (spkcache_len + fifo_len + chunk_pre_encode_len).to(torch.int64) + return packed, total_len + + def forward( + self, + chunk: torch.Tensor, + chunk_len: torch.Tensor, + spkcache: torch.Tensor, + spkcache_len: torch.Tensor, + fifo: torch.Tensor, + fifo_len: torch.Tensor, + lc: torch.Tensor, + rc: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Pre-encode chunk (subsample to diar frames). + chunk_pre_encode, chunk_pre_encode_len = self.diar_model.encoder.pre_encode( + x=chunk, lengths=chunk_len + ) + chunk_pre_encode_len = chunk_pre_encode_len.to(torch.int64) + + # Pack [spkcache, fifo, chunk] into a contiguous sequence for the neural net. + packed, packed_len = self._pack_spkcache_fifo_chunk( + spkcache=spkcache, + spkcache_len=spkcache_len, + fifo=fifo, + fifo_len=fifo_len, + chunk_pre_encode=chunk_pre_encode, + chunk_pre_encode_len=chunk_pre_encode_len, + ) + + # Run FastConformer encoder (bypass pre_encode because `packed` is already in embedding space). + emb_seq, emb_seq_len = self.diar_model.frontend_encoder( + processed_signal=packed, + processed_signal_length=packed_len, + bypass_pre_encode=True, + ) + preds = self.diar_model.forward_infer(emb_seq=emb_seq, emb_seq_length=emb_seq_len) + + # Produce fixed-size per-step outputs (padded with zeros). 
chunk_pred_len = torch.clamp(
+            (chunk_pre_encode_len - lc - rc), min=0, max=self.chunk_len_max
+        ).to(torch.int64)
+
+        # Indices 0..chunk_len_max-1
+        i = torch.arange(self.chunk_len_max, device=chunk.device, dtype=torch.long)
+        # Masks to avoid relying on dynamic slicing.
+        valid_i = (i < chunk_pred_len[0]).view(1, -1, 1)
+
+        # Chunk embeddings used to update caches in C++: chunk_pre_encode[:, lc:lc+chunk_pred_len]
+        pos_chunk = (i + lc[0]).to(torch.long)
+        pos_chunk = torch.clamp(pos_chunk, min=0, max=self.chunk_total_diar - 1)
+        pos_chunk_exp = pos_chunk.view(1, -1, 1).expand(1, -1, chunk_pre_encode.size(-1))
+        chunk_embs = torch.gather(chunk_pre_encode, dim=1, index=pos_chunk_exp) * valid_i
+
+        # Chunk speaker posteriors:
+        # preds[:, spk_len+fifo_len+lc : spk_len+fifo_len+lc+chunk_pred_len]
+        base = (spkcache_len + fifo_len + lc).to(torch.int64)
+        pos_pred = (i + base[0]).to(torch.long)
+        pos_pred = torch.clamp(pos_pred, min=0, max=self.total_max_len - 1)
+        pos_pred_exp = pos_pred.view(1, -1, 1).expand(1, -1, preds.size(-1))
+        chunk_preds = torch.gather(preds, dim=1, index=pos_pred_exp) * valid_i
+
+        return chunk_preds, chunk_embs, chunk_pred_len
+
+
+def _export_programs(model, cfg: StreamingConfig, max_audio_sec: int) -> Tuple[Dict, Dict]:
+    programs: Dict[str, torch.export.ExportedProgram] = {}
+
+    sample_rate = int(model._cfg.preprocessor.sample_rate)
+    window_stride = float(model._cfg.preprocessor.window_stride)
+    subsampling_factor = int(model.encoder.subsampling_factor)
+    feat_dim = int(model._cfg.preprocessor.features)
+    emb_dim = int(model._cfg.sortformer_modules.fc_d_model)
+    n_spk = int(model.sortformer_modules.n_spk)
+    negative_init_val = float(getattr(model, "negative_init_val", -99.0))
+
+    max_audio_samples = int(sample_rate * int(max_audio_sec))
+
+    # Export preprocessor
+    preprocessor_wrapper = PreprocessorWrapper(model.preprocessor)
+    preprocessor_wrapper.eval()
+
+    sample_audio = torch.randn(max_audio_samples, dtype=torch.float32)
+    sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64)
+
+    # NeMo feature extractors sometimes branch on CUDA availability (data-dependent paths).
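+    # Patch torch.cuda.is_available to report False so export traces the CPU path only;
+    # the original function is restored immediately after export.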
+ old_cuda_is_available = torch.cuda.is_available + torch.cuda.is_available = lambda: False + programs["preprocessor"] = export( + preprocessor_wrapper, + (sample_audio, sample_length), + dynamic_shapes={ + "audio": {0: Dim("audio_len", min=1600, max=max_audio_samples)}, + "length": {}, + }, + strict=False, + ) + torch.cuda.is_available = old_cuda_is_available + + # Export model_step (fixed shapes) + step = SortformerStreamingStep(model, cfg) + step.eval() + + max_chunk_feat_frames = ( + (cfg.chunk_left_context + cfg.chunk_len + cfg.chunk_right_context) + * subsampling_factor + ) + + sample_chunk = torch.randn(1, max_chunk_feat_frames, feat_dim, dtype=torch.float32) + sample_chunk_len = torch.tensor([max_chunk_feat_frames], dtype=torch.int64) + + sample_spkcache = torch.zeros(1, cfg.spkcache_len, emb_dim, dtype=torch.float32) + sample_spkcache_len = torch.tensor([0], dtype=torch.int64) + + sample_fifo = torch.zeros(1, cfg.fifo_len, emb_dim, dtype=torch.float32) + sample_fifo_len = torch.tensor([0], dtype=torch.int64) + + sample_lc = torch.tensor([cfg.chunk_left_context], dtype=torch.int64) + sample_rc = torch.tensor([cfg.chunk_right_context], dtype=torch.int64) + + programs["model_step"] = export( + step, + ( + sample_chunk, + sample_chunk_len, + sample_spkcache, + sample_spkcache_len, + sample_fifo, + sample_fifo_len, + sample_lc, + sample_rc, + ), + strict=False, + ) + + metadata = { + "sample_rate": sample_rate, + "window_stride": window_stride, + "subsampling_factor": subsampling_factor, + "feat_dim": feat_dim, + "emb_dim": emb_dim, + "n_spk": n_spk, + "negative_init_val": negative_init_val, + "spkcache_len": int(cfg.spkcache_len), + "fifo_len": int(cfg.fifo_len), + "spkcache_update_period": int(cfg.spkcache_update_period), + "chunk_len": int(cfg.chunk_len), + "chunk_left_context": int(cfg.chunk_left_context), + "chunk_right_context": int(cfg.chunk_right_context), + "max_chunk_feat_frames": int(max_chunk_feat_frames), + } + + return programs, metadata + + +def _lower_to_executorch(programs: Dict, metadata: Dict): + constant_methods = dict(metadata) + et_prog = to_edge_transform_and_lower( + programs, + partitioner=[], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=constant_methods, + ) + return et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=False, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="nvidia/diar_streaming_sortformer_4spk-v2.1", + help="NeMo model name (from_pretrained) or path to a .nemo file.", + ) + parser.add_argument("--output-dir", type=str, default="./sortformer_diar_exports") + parser.add_argument( + "--max-audio-sec", + type=int, + default=60, + help="Max audio duration (sec) used to bound preprocessor dynamic shapes during export.", + ) + + # Streaming parameters (match NeMo examples by default) + parser.add_argument("--spkcache-len", type=int, default=188) + parser.add_argument("--fifo-len", type=int, default=188) + parser.add_argument("--spkcache-update-period", type=int, default=144) + parser.add_argument("--chunk-len", type=int, default=6) + parser.add_argument("--chunk-left-context", type=int, default=1) + parser.add_argument("--chunk-right-context", type=int, default=7) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + print(f"Loading diarization model: {args.model}") + model = 
_load_model(args.model)
+
+    cfg = StreamingConfig(
+        spkcache_len=int(args.spkcache_len),
+        fifo_len=int(args.fifo_len),
+        spkcache_update_period=int(args.spkcache_update_period),
+        chunk_len=int(args.chunk_len),
+        chunk_left_context=int(args.chunk_left_context),
+        chunk_right_context=int(args.chunk_right_context),
+    )
+
+    # Apply streaming parameter overrides on the NeMo model (for metadata consistency).
+    model.sortformer_modules.spkcache_len = cfg.spkcache_len
+    model.sortformer_modules.fifo_len = cfg.fifo_len
+    model.sortformer_modules.spkcache_update_period = cfg.spkcache_update_period
+    model.sortformer_modules.chunk_len = cfg.chunk_len
+    model.sortformer_modules.chunk_left_context = cfg.chunk_left_context
+    model.sortformer_modules.chunk_right_context = cfg.chunk_right_context
+    model.sortformer_modules._check_streaming_parameters()
+
+    print("Exporting methods...")
+    programs, metadata = _export_programs(model, cfg, max_audio_sec=int(args.max_audio_sec))
+
+    print("Lowering to ExecuTorch (portable ops only)...")
+    et = _lower_to_executorch(programs, metadata)
+
+    pte_path = os.path.join(args.output_dir, "model.pte")
+    with open(pte_path, "wb") as f:
+        et.write_to_file(f)
+    print(f"Saved: {pte_path}")
+    print(f"Size: {os.path.getsize(pte_path) / (1024 * 1024):.1f} MB")
+
+    print("Done.")
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/examples/models/diar_streaming_sortformer/main.cpp b/examples/models/diar_streaming_sortformer/main.cpp
new file mode 100644
index 00000000000..34e8026842a
--- /dev/null
+++ b/examples/models/diar_streaming_sortformer/main.cpp
@@ -0,0 +1,493 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include <gflags/gflags.h>
+
+#include <executorch/extension/llm/runner/wav_loader.h>
+#include <executorch/extension/module/module.h>
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/platform/log.h>
+
+DEFINE_string(model_path, "model.pte", "Path to diarization model (.pte).");
+DEFINE_string(audio_path, "", "Path to input audio file (.wav).");
+DEFINE_string(
+    data_path,
+    "",
+    "Path to data file (.ptd) for delegate data (optional).");
+DEFINE_double(
+    threshold,
+    0.5,
+    "Speaker activity threshold in [0,1] used to form segments.");
+
+using ::executorch::extension::from_blob;
+using ::executorch::extension::Module;
+using ::executorch::runtime::EValue;
+
+namespace {
+
+struct Segment {
+  int speaker = -1;
+  double start_sec = 0.0;
+  double end_sec = 0.0;
+};
+
+int64_t get_int_constant(Module& model, const char* name) {
+  std::vector<EValue> empty_inputs;
+  auto r = model.execute(name, empty_inputs);
+  if (!r.ok()) {
+    throw std::runtime_error(std::string("Failed to query constant method: ") + name);
+  }
+  return r.get()[0].toInt();
+}
+
+double get_double_constant(Module& model, const char* name) {
+  std::vector<EValue> empty_inputs;
+  auto r = model.execute(name, empty_inputs);
+  if (!r.ok()) {
+    throw std::runtime_error(std::string("Failed to query constant method: ") + name);
+  }
+  return r.get()[0].toDouble();
+}
+
+void append_to_fixed_cache(
+    std::vector<float>& cache,
+    int64_t& cache_len,
+    int64_t cache_max_len,
+    int64_t emb_dim,
+    const float* frames,
+    int64_t n_frames) {
+  if (n_frames <= 0) {
+    return;
+  }
+
+  if (n_frames >= cache_max_len) {
+    // Keep only the most recent cache_max_len frames.
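+    // The incoming block alone fills (or overflows) the cache, so the previous
+    // contents are discarded entirely.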
const float* src = frames + (n_frames - cache_max_len) * emb_dim;
+    std::memcpy(cache.data(), src, cache_max_len * emb_dim * sizeof(float));
+    cache_len = cache_max_len;
+    return;
+  }
+
+  const int64_t total = cache_len + n_frames;
+  if (total <= cache_max_len) {
+    std::memcpy(
+        cache.data() + cache_len * emb_dim,
+        frames,
+        n_frames * emb_dim * sizeof(float));
+    cache_len = total;
+    return;
+  }
+
+  // Drop the oldest `overflow` frames, shift left, append new.
+  const int64_t overflow = total - cache_max_len;
+  const int64_t remain = cache_len - overflow;
+  if (remain > 0) {
+    std::memmove(
+        cache.data(),
+        cache.data() + overflow * emb_dim,
+        remain * emb_dim * sizeof(float));
+  }
+  std::memcpy(
+      cache.data() + remain * emb_dim,
+      frames,
+      n_frames * emb_dim * sizeof(float));
+  cache_len = cache_max_len;
+}
+
+std::vector<Segment> segments_from_posteriors(
+    const std::vector<float>& posteriors,
+    int64_t num_frames,
+    int64_t n_spk,
+    double frame_sec,
+    double threshold) {
+  std::vector<Segment> out;
+  if (num_frames <= 0 || n_spk <= 0) {
+    return out;
+  }
+
+  for (int64_t spk = 0; spk < n_spk; ++spk) {
+    bool in_seg = false;
+    int64_t start = 0;
+    for (int64_t t = 0; t < num_frames; ++t) {
+      const float p = posteriors[static_cast<size_t>(t * n_spk + spk)];
+      const bool active = static_cast<double>(p) >= threshold;
+      if (active && !in_seg) {
+        in_seg = true;
+        start = t;
+      } else if (!active && in_seg) {
+        in_seg = false;
+        out.push_back(
+            Segment{static_cast<int>(spk), start * frame_sec, t * frame_sec});
+      }
+    }
+    if (in_seg) {
+      out.push_back(Segment{
+          static_cast<int>(spk), start * frame_sec, num_frames * frame_sec});
+    }
+  }
+
+  std::sort(out.begin(), out.end(), [](const Segment& a, const Segment& b) {
+    if (a.start_sec != b.start_sec) {
+      return a.start_sec < b.start_sec;
+    }
+    if (a.end_sec != b.end_sec) {
+      return a.end_sec < b.end_sec;
+    }
+    return a.speaker < b.speaker;
+  });
+  return out;
+}
+
+} // namespace
+
+int main(int argc, char* argv[]) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  if (FLAGS_audio_path.empty()) {
+    ET_LOG(Error, "--audio_path is required.");
+    return 1;
+  }
+
+  try {
+    ET_LOG(Info, "Loading model from: %s", FLAGS_model_path.c_str());
+    auto model = std::make_unique<Module>(FLAGS_model_path, FLAGS_data_path);
+
+    const int64_t model_sample_rate = get_int_constant(*model, "sample_rate");
+    const double window_stride = get_double_constant(*model, "window_stride");
+    const int64_t subsampling_factor = get_int_constant(*model, "subsampling_factor");
+    const int64_t chunk_len = get_int_constant(*model, "chunk_len");
+    const int64_t chunk_left_context = get_int_constant(*model, "chunk_left_context");
+    const int64_t chunk_right_context = get_int_constant(*model, "chunk_right_context");
+    const int64_t spkcache_max_len = get_int_constant(*model, "spkcache_len");
+    const int64_t fifo_max_len = get_int_constant(*model, "fifo_len");
+    const int64_t spkcache_update_period =
+        get_int_constant(*model, "spkcache_update_period");
+    const int64_t emb_dim = get_int_constant(*model, "emb_dim");
+    const int64_t n_spk = get_int_constant(*model, "n_spk");
+    const double negative_init_val = get_double_constant(*model, "negative_init_val");
+
+    const int64_t chunk_feat_frames = chunk_len * subsampling_factor;
+    const int64_t left_feat_frames = chunk_left_context * subsampling_factor;
+    const int64_t right_feat_frames = chunk_right_context * subsampling_factor;
+    const int64_t max_chunk_feat_frames =
+        (chunk_left_context + chunk_len + chunk_right_context) *
+        subsampling_factor;
+
+    const double frame_sec =
window_stride * static_cast<double>(subsampling_factor);
+
+    ET_LOG(
+        Info,
+        "Model metadata: sample_rate=%lld, window_stride=%.6f, subsampling_factor=%lld, frame_sec=%.4f, "
+        "n_spk=%lld, emb_dim=%lld, chunk_len=%lld, lc=%lld, rc=%lld, max_chunk_feat_frames=%lld, "
+        "spkcache_len=%lld, fifo_len=%lld, spkcache_update_period=%lld",
+        static_cast<long long>(model_sample_rate),
+        window_stride,
+        static_cast<long long>(subsampling_factor),
+        frame_sec,
+        static_cast<long long>(n_spk),
+        static_cast<long long>(emb_dim),
+        static_cast<long long>(chunk_len),
+        static_cast<long long>(chunk_left_context),
+        static_cast<long long>(chunk_right_context),
+        static_cast<long long>(max_chunk_feat_frames),
+        static_cast<long long>(spkcache_max_len),
+        static_cast<long long>(fifo_max_len),
+        static_cast<long long>(spkcache_update_period));
+
+    // Load WAV and validate format.
+    auto header = executorch::extension::llm::load_wav_header(FLAGS_audio_path);
+    if (header.get() == nullptr) {
+      ET_LOG(Error, "Failed to load WAV header: %s", FLAGS_audio_path.c_str());
+      return 1;
+    }
+    if (header->NumOfChan != 1) {
+      ET_LOG(
+          Error,
+          "Only mono WAV is supported. Got NumOfChan=%d",
+          static_cast<int>(header->NumOfChan));
+      return 1;
+    }
+    if (static_cast<int64_t>(header->SamplesPerSec) != model_sample_rate) {
+      ET_LOG(
+          Error,
+          "WAV sample rate (%d) != model sample rate (%lld). Resample the WAV first.",
+          static_cast<int>(header->SamplesPerSec),
+          static_cast<long long>(model_sample_rate));
+      return 1;
+    }
+
+    ET_LOG(Info, "Loading WAV audio samples...");
+    std::vector<float> audio =
+        executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path);
+
+    // Run preprocessor once for the full audio file.
+    ET_LOG(Info, "Running preprocessor...");
+    std::vector<int64_t> audio_len_vec = {static_cast<int64_t>(audio.size())};
+    auto audio_tensor = from_blob(
+        audio.data(),
+        {static_cast<::executorch::aten::SizesType>(audio.size())},
+        ::executorch::aten::ScalarType::Float);
+    auto audio_len_tensor = from_blob(
+        audio_len_vec.data(), {1}, ::executorch::aten::ScalarType::Long);
+
+    auto prep_result = model->execute(
+        "preprocessor",
+        std::vector<EValue>{audio_tensor, audio_len_tensor});
+    if (!prep_result.ok()) {
+      ET_LOG(Error, "preprocessor failed.");
+      return 1;
+    }
+    auto& prep_out = prep_result.get();
+    auto features = prep_out[0].toTensor(); // [1, T_feat, feat_dim]
+    int64_t feat_len = prep_out[1].toTensor().const_data_ptr<int64_t>()[0];
+
+    const int64_t feat_dim = static_cast<int64_t>(features.sizes()[2]);
+    if (features.scalar_type() != ::executorch::aten::ScalarType::Float) {
+      ET_LOG(Error, "Expected float features from preprocessor.");
+      return 1;
+    }
+
+    ET_LOG(
+        Info,
+        "Features shape: [1, %lld, %lld], feat_len=%lld",
+        static_cast<long long>(static_cast<int64_t>(features.sizes()[1])),
+        static_cast<long long>(feat_dim),
+        static_cast<long long>(feat_len));
+
+    const float* feat_ptr = features.const_data_ptr<float>();
+
+    // Streaming caches (embeddings, not posteriors).
+    std::vector<float> spkcache(
+        static_cast<size_t>(spkcache_max_len * emb_dim), 0.0f);
+    std::vector<float> fifo(static_cast<size_t>(fifo_max_len * emb_dim), 0.0f);
+    int64_t spkcache_len = 0;
+    int64_t fifo_len = 0;
+
+    // Accumulate diar posteriors per diar frame.
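+    // One row of n_spk activity probabilities is appended per diar frame across
+    // all steps, so the full timeline can be segmented once streaming finishes.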
+    std::vector<float> posteriors; // row-major: [T_diar, n_spk]
+
+    std::vector<float> chunk_feat_buf(
+        static_cast<size_t>(max_chunk_feat_frames * feat_dim), 0.0f);
+
+    int64_t stt_feat = 0;
+    int step_idx = 0;
+    while (stt_feat < feat_len) {
+      const int64_t left_offset =
+          std::min(left_feat_frames, stt_feat);
+      const int64_t end_feat =
+          std::min(stt_feat + chunk_feat_frames, feat_len);
+      const int64_t right_offset =
+          std::min(right_feat_frames, feat_len - end_feat);
+
+      const int64_t chunk_start = stt_feat - left_offset;
+      const int64_t chunk_valid_frames =
+          (end_feat + right_offset) - chunk_start;
+
+      // Prepare fixed-size chunk input with padding.
+      std::fill(
+          chunk_feat_buf.begin(),
+          chunk_feat_buf.end(),
+          static_cast<float>(negative_init_val));
+      for (int64_t i = 0; i < chunk_valid_frames; ++i) {
+        const float* src = feat_ptr + (chunk_start + i) * feat_dim;
+        float* dst = chunk_feat_buf.data() + i * feat_dim;
+        std::memcpy(dst, src, static_cast<size_t>(feat_dim) * sizeof(float));
+      }
+
+      std::vector<int64_t> chunk_len_vec = {chunk_valid_frames};
+      auto chunk_tensor = from_blob(
+          chunk_feat_buf.data(),
+          {static_cast<::executorch::aten::SizesType>(1),
+           static_cast<::executorch::aten::SizesType>(max_chunk_feat_frames),
+           static_cast<::executorch::aten::SizesType>(feat_dim)},
+          ::executorch::aten::ScalarType::Float);
+      auto chunk_len_tensor = from_blob(
+          chunk_len_vec.data(),
+          {static_cast<::executorch::aten::SizesType>(1)},
+          ::executorch::aten::ScalarType::Long);
+
+      auto spkcache_tensor = from_blob(
+          spkcache.data(),
+          {static_cast<::executorch::aten::SizesType>(1),
+           static_cast<::executorch::aten::SizesType>(spkcache_max_len),
+           static_cast<::executorch::aten::SizesType>(emb_dim)},
+          ::executorch::aten::ScalarType::Float);
+      std::vector<int64_t> spkcache_len_vec = {spkcache_len};
+      auto spkcache_len_tensor = from_blob(
+          spkcache_len_vec.data(),
+          {static_cast<::executorch::aten::SizesType>(1)},
+          ::executorch::aten::ScalarType::Long);
+
+      auto fifo_tensor = from_blob(
+          fifo.data(),
+          {static_cast<::executorch::aten::SizesType>(1),
+           static_cast<::executorch::aten::SizesType>(fifo_max_len),
+           static_cast<::executorch::aten::SizesType>(emb_dim)},
+          ::executorch::aten::ScalarType::Float);
+      std::vector<int64_t> fifo_len_vec = {fifo_len};
+      auto fifo_len_tensor = from_blob(
+          fifo_len_vec.data(),
+          {static_cast<::executorch::aten::SizesType>(1)},
+          ::executorch::aten::ScalarType::Long);
+
+      // lc/rc in diar frames (post-subsampling), derived from feature-frame offsets.
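+      // Left context is rounded to the nearest diar frame, while right context is
+      // rounded up so every remaining feature frame stays covered.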
const int64_t lc = static_cast<int64_t>(
+          std::llround(static_cast<double>(left_offset) /
+                       static_cast<double>(subsampling_factor)));
+      const int64_t rc = (right_offset + subsampling_factor - 1) / subsampling_factor;
+      std::vector<int64_t> lc_vec = {lc};
+      std::vector<int64_t> rc_vec = {rc};
+      auto lc_tensor =
+          from_blob(
+              lc_vec.data(),
+              {static_cast<::executorch::aten::SizesType>(1)},
+              ::executorch::aten::ScalarType::Long);
+      auto rc_tensor =
+          from_blob(
+              rc_vec.data(),
+              {static_cast<::executorch::aten::SizesType>(1)},
+              ::executorch::aten::ScalarType::Long);
+
+      ET_LOG(Info, "Processing step %d: feat[%lld:%lld] lc=%lld rc=%lld", step_idx,
+             static_cast<long long>(chunk_start),
+             static_cast<long long>(end_feat),
+             static_cast<long long>(lc),
+             static_cast<long long>(rc));
+      auto step_result = model->execute(
+          "model_step",
+          std::vector<EValue>{
+              chunk_tensor,
+              chunk_len_tensor,
+              spkcache_tensor,
+              spkcache_len_tensor,
+              fifo_tensor,
+              fifo_len_tensor,
+              lc_tensor,
+              rc_tensor,
+          });
+      if (!step_result.ok()) {
+        ET_LOG(Error, "model_step failed at step %d.", step_idx);
+        return 1;
+      }
+
+      auto& step_out = step_result.get();
+      auto chunk_preds = step_out[0].toTensor(); // [1, chunk_len_max, n_spk]
+      auto chunk_embs = step_out[1].toTensor(); // [1, chunk_len_max, emb_dim]
+      const int64_t chunk_pred_len =
+          step_out[2].toTensor().const_data_ptr<int64_t>()[0];
+
+      if (chunk_pred_len > 0) {
+        const float* preds_ptr = chunk_preds.const_data_ptr<float>();
+        const float* embs_ptr = chunk_embs.const_data_ptr<float>();
+        posteriors.resize(
+            posteriors.size() + static_cast<size_t>(chunk_pred_len * n_spk));
+        float* out_ptr =
+            posteriors.data() + (posteriors.size() - static_cast<size_t>(chunk_pred_len * n_spk));
+        std::memcpy(
+            out_ptr,
+            preds_ptr,
+            static_cast<size_t>(chunk_pred_len * n_spk) * sizeof(float));
+
+        // Update FIFO + speaker cache (embedding-only). This approximates NeMo streaming
+        // cache management but does not implement Sortformer speaker-cache compression.
+        const int64_t fifo_len_before = fifo_len;
+        const int64_t fifo_len_after = fifo_len_before + chunk_pred_len;
+
+        if (fifo_len_after <= fifo_max_len) {
+          std::memcpy(
+              fifo.data() + fifo_len_before * emb_dim,
+              embs_ptr,
+              static_cast<size_t>(chunk_pred_len * emb_dim) * sizeof(float));
+          fifo_len = fifo_len_after;
+        } else {
+          // Build a temporary FIFO buffer: [old_fifo_valid, new_chunk]
+          std::vector<float> fifo_tmp(
+              static_cast<size_t>(fifo_len_after * emb_dim), 0.0f);
+          if (fifo_len_before > 0) {
+            std::memcpy(
+                fifo_tmp.data(),
+                fifo.data(),
+                static_cast<size_t>(fifo_len_before * emb_dim) * sizeof(float));
+          }
+          std::memcpy(
+              fifo_tmp.data() + fifo_len_before * emb_dim,
+              embs_ptr,
+              static_cast<size_t>(chunk_pred_len * emb_dim) * sizeof(float));
+
+          int64_t pop_out_len = spkcache_update_period;
+          pop_out_len = std::max(
+              pop_out_len, chunk_pred_len - fifo_max_len + fifo_len_before);
+          pop_out_len = std::min(pop_out_len, fifo_len_after);
+
+          // Move the oldest `pop_out_len` frames from FIFO into speaker cache.
+          append_to_fixed_cache(
+              spkcache,
+              spkcache_len,
+              spkcache_max_len,
+              emb_dim,
+              fifo_tmp.data(),
+              pop_out_len);
+
+          // Keep the remaining FIFO frames.
+          const int64_t new_fifo_len = fifo_len_after - pop_out_len;
+          if (new_fifo_len > 0) {
+            std::memcpy(
+                fifo.data(),
+                fifo_tmp.data() + pop_out_len * emb_dim,
+                static_cast<size_t>(new_fifo_len * emb_dim) * sizeof(float));
+          }
+          // Zero out tail for cleanliness.
+          if (new_fifo_len < fifo_max_len) {
+            std::fill(
+                fifo.begin() + static_cast<size_t>(new_fifo_len * emb_dim),
+                fifo.end(),
+                0.0f);
+          }
+          fifo_len = new_fifo_len;
+        }
+      }
+
+      stt_feat = end_feat;
+      step_idx += 1;
+    }
+
+    const int64_t num_diar_frames =
+        static_cast<int64_t>(posteriors.size() / static_cast<size_t>(n_spk));
+    ET_LOG(
+        Info,
+        "Produced %lld diar frames (%.2f sec)",
+        static_cast<long long>(num_diar_frames),
+        num_diar_frames * frame_sec);
+
+    auto segments = segments_from_posteriors(
+        posteriors, num_diar_frames, n_spk, frame_sec, FLAGS_threshold);
+
+    std::cout << "Segments (threshold=" << FLAGS_threshold << ")\n";
+    for (const auto& seg : segments) {
+      std::cout << "speaker_" << seg.speaker << " "
+                << seg.start_sec << " " << seg.end_sec << "\n";
+    }
+
+    return 0;
+  } catch (const std::exception& e) {
+    ET_LOG(Error, "Exception: %s", e.what());
+    return 1;
+  }
+}