diff --git a/Sources/FluidAudio/ASR/AsrManager.swift b/Sources/FluidAudio/ASR/AsrManager.swift
index 0202baafc..bbe83b9e4 100644
--- a/Sources/FluidAudio/ASR/AsrManager.swift
+++ b/Sources/FluidAudio/ASR/AsrManager.swift
@@ -20,7 +20,7 @@ public final class AsrManager {
     internal var jointModel: MLModel?
 
     /// The AsrModels instance if initialized with models
-    private var asrModels: AsrModels?
+    internal var asrModels: AsrModels?
 
     internal let progressEmitter = ProgressEmitter()
 
@@ -88,14 +88,16 @@ public final class AsrManager {
     }
 
     public var isAvailable: Bool {
-        let baseModelsReady = encoderModel != nil && decoderModel != nil && jointModel != nil
-        guard baseModelsReady else { return false }
+        let decoderReady = decoderModel != nil && jointModel != nil
+        guard decoderReady else { return false }
 
         if asrModels?.usesSplitFrontend == true {
+            // Split frontend: need both preprocessor and encoder
+            return preprocessorModel != nil && encoderModel != nil
+        } else {
+            // Fused frontend: preprocessor contains encoder
             return preprocessorModel != nil
         }
-
-        return true
     }
 
     /// Initialize ASR Manager with pre-loaded models
@@ -110,7 +112,10 @@ public final class AsrManager {
         self.jointModel = models.joint
         self.vocabulary = models.vocabulary
 
-        logger.info("Token duration optimization model loaded successfully")
+        // Recreate decoder states with the correct layer count for this model version
+        let layers = models.version.decoderLayers
+        self.microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
+        self.systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
 
         logger.info("AsrManager initialized successfully with provided models")
     }
@@ -277,19 +282,22 @@ public final class AsrManager {
     }
 
     public func resetState() {
-        microphoneDecoderState = TdtDecoderState.make()
-        systemDecoderState = TdtDecoderState.make()
+        let layers = asrModels?.version.decoderLayers ?? 2
+        microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
+        systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
        Task { await sharedMLArrayCache.clear() }
    }

     public func cleanup() {
+        let layers = asrModels?.version.decoderLayers ?? 2
+        asrModels = nil
         preprocessorModel = nil
         encoderModel = nil
         decoderModel = nil
         jointModel = nil
         // Reset decoder states using fresh allocations for deterministic behavior
-        microphoneDecoderState = TdtDecoderState.make()
-        systemDecoderState = TdtDecoderState.make()
+        microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
+        systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
         // Release vocabulary boosting resources
         disableVocabularyBoosting()
         Task { await sharedMLArrayCache.clear() }
@@ -310,9 +318,25 @@ public final class AsrManager {
         guard let models = asrModels, let decoder_ = decoderModel, let joint = jointModel else {
             throw ASRError.notInitialized
         }
+
+        // Adapt config's encoderHiddenSize to match the loaded model version
+        // (e.g. default config uses 1024 but tdtCtc110m needs 512)
+        let adaptedConfig: ASRConfig
+        if config.encoderHiddenSize != models.version.encoderHiddenSize {
+            adaptedConfig = ASRConfig(
+                sampleRate: config.sampleRate,
+                tdtConfig: config.tdtConfig,
+                encoderHiddenSize: models.version.encoderHiddenSize,
+                streamingEnabled: config.streamingEnabled,
+                streamingThreshold: config.streamingThreshold
+            )
+        } else {
+            adaptedConfig = config
+        }
+
         switch models.version {
-        case .v2:
-            let decoder = TdtDecoderV2(config: config)
+        case .v2, .tdtCtc110m:
+            let decoder = TdtDecoderV2(config: adaptedConfig)
             return try await decoder.decodeWithTimings(
                 encoderOutput: encoderOutput,
                 encoderSequenceLength: encoderSequenceLength,
@@ -325,7 +349,7 @@ public final class AsrManager {
                 globalFrameOffset: globalFrameOffset
             )
         case .v3:
-            let decoder = TdtDecoderV3(config: config)
+            let decoder = TdtDecoderV3(config: adaptedConfig)
             return try await decoder.decodeWithTimings(
                 encoderOutput: encoderOutput,
                 encoderSequenceLength: encoderSequenceLength,
diff --git a/Sources/FluidAudio/ASR/AsrModels.swift b/Sources/FluidAudio/ASR/AsrModels.swift
index b1c2b9870..67129c6bd 100644
--- a/Sources/FluidAudio/ASR/AsrModels.swift
+++ b/Sources/FluidAudio/ASR/AsrModels.swift
@@ -6,11 +6,46 @@ import OSLog
 public enum AsrModelVersion: Sendable {
     case v2
     case v3
+    /// 110M parameter hybrid TDT-CTC model with fused preprocessor+encoder
+    case tdtCtc110m
 
     var repo: Repo {
         switch self {
         case .v2: return .parakeetV2
         case .v3: return .parakeet
+        case .tdtCtc110m: return .parakeetTdtCtc110m
         }
     }
+
+    /// Whether this model version uses a fused preprocessor+encoder (no separate Encoder model)
+    public var hasFusedEncoder: Bool {
+        switch self {
+        case .tdtCtc110m: return true
+        default: return false
+        }
+    }
+
+    /// Encoder hidden dimension for this model version
+    public var encoderHiddenSize: Int {
+        switch self {
+        case .tdtCtc110m: return 512
+        default: return 1024
+        }
+    }
+
+    /// Blank token ID for this model version
+    public var blankId: Int {
+        switch self {
+        case .v2, .tdtCtc110m: return 1024
+        case .v3: return 8192
+        }
+    }
+
+    /// Number of LSTM layers in the decoder prediction network
+    public var decoderLayers: Int {
+        switch self {
+        case .tdtCtc110m: return 1
+        default: return 2
+        }
+    }
 }
@@ -20,7 +55,8 @@ public struct AsrModels: Sendable {
     /// Required model names for ASR
     public static let requiredModelNames = ModelNames.ASR.requiredModels
 
-    public let encoder: MLModel
+    /// Separate encoder model (nil for fused models like tdtCtc110m where preprocessor includes encoder)
+    public let encoder: MLModel?
     public let preprocessor: MLModel
     public let decoder: MLModel
     public let joint: MLModel
@@ -31,7 +67,7 @@ public struct AsrModels: Sendable {
     private static let logger = AppLogger(category: "AsrModels")
 
     public init(
-        encoder: MLModel,
+        encoder: MLModel?,
         preprocessor: MLModel,
         decoder: MLModel,
         joint: MLModel,
@@ -48,8 +84,9 @@ public struct AsrModels: Sendable {
         self.version = version
     }
 
+    /// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m fused)
     public var usesSplitFrontend: Bool {
-        true
+        !version.hasFusedEncoder
     }
 }
 
@@ -60,7 +97,15 @@ extension AsrModels {
         let computeUnits: MLComputeUnits
     }
 
-    private static func createModelSpecs(using config: MLModelConfiguration) -> [ModelSpec] {
+    private static func createModelSpecs(
+        using config: MLModelConfiguration, version: AsrModelVersion
+    ) -> [ModelSpec] {
+        if version.hasFusedEncoder {
+            // Fused preprocessor+encoder runs on ANE (it contains the conformer encoder)
+            return [
+                ModelSpec(fileName: Names.preprocessorFile, computeUnits: config.computeUnits)
+            ]
+        }
         return [
             // Preprocessor ops map to CPU-only across all platforms. Xcode profiling shows
             // that 100% of the operations map to the CPU anyway.
@@ -78,7 +123,7 @@ extension AsrModels {
 
     private static func inferredVersion(from directory: URL) -> AsrModelVersion? {
         let directoryPath = directory.path.lowercased()
-        let knownVersions: [AsrModelVersion] = [.v2, .v3]
+        let knownVersions: [AsrModelVersion] = [.tdtCtc110m, .v2, .v3]
 
         for version in knownVersions {
             if directoryPath.contains(version.repo.folderName.lowercased()) {
@@ -118,7 +163,7 @@ extension AsrModels {
         let parentDirectory = directory.deletingLastPathComponent()
 
         // Load preprocessor and encoder first; decoder and joint are loaded below as well.
-        let specs = createModelSpecs(using: config)
+        let specs = createModelSpecs(using: config, version: version)
 
         var loadedModels: [String: MLModel] = [:]
 
@@ -138,10 +183,13 @@ extension AsrModels {
             }
         }
 
-        guard let preprocessorModel = loadedModels[Names.preprocessorFile],
-            let encoderModel = loadedModels[Names.encoderFile]
-        else {
-            throw AsrModelsError.loadingFailed("Failed to load preprocessor or encoder model")
+        guard let preprocessorModel = loadedModels[Names.preprocessorFile] else {
+            throw AsrModelsError.loadingFailed("Failed to load preprocessor model")
+        }
+        let encoderModel = loadedModels[Names.encoderFile]  // nil for fused models
+
+        if !version.hasFusedEncoder && encoderModel == nil {
+            throw AsrModelsError.loadingFailed("Failed to load encoder model (required for split frontend)")
         }
 
         // Load decoder and joint as well
@@ -185,18 +233,30 @@ extension AsrModels {
 
         do {
             let data = try Data(contentsOf: vocabPath)
-            let jsonDict = try JSONSerialization.jsonObject(with: data) as? [String: String] ?? [:]
+            let json = try JSONSerialization.jsonObject(with: data)
             var vocabulary: [Int: String] = [:]
-            for (key, value) in jsonDict {
-                if let tokenId = Int(key) {
-                    vocabulary[tokenId] = value
+            if let jsonArray = json as? [String] {
+                // Array format (110m hybrid): index = token ID
+                for (index, token) in jsonArray.enumerated() {
+                    vocabulary[index] = token
+                }
+            } else if let jsonDict = json as? [String: String] {
+                // Dictionary format (0.6B v2/v3): key = token ID string
+                for (key, value) in jsonDict {
+                    if let tokenId = Int(key) {
+                        vocabulary[tokenId] = value
+                    }
                 }
+            } else {
+                throw AsrModelsError.loadingFailed("Vocabulary file has unexpected format")
             }
 
             logger.info("Loaded vocabulary with \(vocabulary.count) tokens from \(vocabPath.path)")
             return vocabulary
+        } catch let error as AsrModelsError {
+            throw error
         } catch {
             logger.error(
                 "Failed to load or parse vocabulary file at \(vocabPath.path): \(error.localizedDescription)"
@@ -324,13 +384,23 @@ extension AsrModels {
 
         let defaultUnits = defaultConfiguration().computeUnits
 
-        let specs: [DownloadSpec] = [
-            // Preprocessor ops map to CPU-only across all platforms.
-            DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
-            DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
-            DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
-            DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
-        ]
+        let specs: [DownloadSpec]
+        if version.hasFusedEncoder {
+            specs = [
+                // Fused preprocessor+encoder runs on ANE
+                DownloadSpec(fileName: Names.preprocessorFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
+            ]
+        } else {
+            specs = [
+                // Preprocessor ops map to CPU-only across all platforms.
+                DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
+                DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
+                DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
+            ]
+        }
 
         for spec in specs {
             _ = try await DownloadUtils.loadModels(
@@ -365,7 +435,8 @@ extension AsrModels {
 
     public static func modelsExist(at directory: URL, version: AsrModelVersion) -> Bool {
         let fileManager = FileManager.default
-        let requiredFiles = ModelNames.ASR.requiredModels
+        let requiredFiles =
+            version.hasFusedEncoder ? ModelNames.ASR.requiredModelsFused : ModelNames.ASR.requiredModels
 
         // Check in the DownloadUtils repo structure
         let repoPath = repoPath(from: directory, version: version)
@@ -397,12 +468,14 @@ extension AsrModels {
         let config = MLModelConfiguration()
         config.computeUnits = .cpuOnly
 
-        let modelsToValidate = [
+        var modelsToValidate = [
             ("Preprocessor", ModelNames.ASR.preprocessorFile),
-            ("Encoder", ModelNames.ASR.encoderFile),
             ("Decoder", ModelNames.ASR.decoderFile),
             ("Joint", ModelNames.ASR.jointFile),
         ]
+        if !version.hasFusedEncoder {
+            modelsToValidate.insert(("Encoder", ModelNames.ASR.encoderFile), at: 1)
+        }
 
         for (name, fileName) in modelsToValidate {
             let modelPath = repoPath.appendingPathComponent(fileName)
diff --git a/Sources/FluidAudio/ASR/AsrTranscription.swift b/Sources/FluidAudio/ASR/AsrTranscription.swift
index 1a878d254..f101c1158 100644
--- a/Sources/FluidAudio/ASR/AsrTranscription.swift
+++ b/Sources/FluidAudio/ASR/AsrTranscription.swift
@@ -88,7 +88,7 @@ extension AsrManager {
         let preprocessorAudioArray = preprocessorInput.featureValue(for: "audio_signal")?.multiArrayValue
 
         do {
-            guard let preprocessorModel = preprocessorModel, let encoderModel = encoderModel else {
+            guard let preprocessorModel = preprocessorModel else {
                 throw ASRError.notInitialized
             }
 
@@ -98,17 +98,24 @@ extension AsrManager {
                 options: predictionOptions
             )
 
-            let encoderInput = try prepareEncoderInput(
-                encoder: encoderModel,
-                preprocessorOutput: preprocessorOutput,
-                originalInput: preprocessorInput
-            )
-
-            try Task.checkCancellation()
-            let encoderOutputProvider = try await encoderModel.compatPrediction(
-                from: encoderInput,
-                options: predictionOptions
-            )
+            let encoderOutputProvider: MLFeatureProvider
+            if let encoderModel = encoderModel {
+                // Split frontend: run separate encoder
+                let encoderInput = try prepareEncoderInput(
+                    encoder: encoderModel,
+                    preprocessorOutput: preprocessorOutput,
+                    originalInput: preprocessorInput
+                )
+
+                try Task.checkCancellation()
+                encoderOutputProvider = try await encoderModel.compatPrediction(
+                    from: encoderInput,
+                    options: predictionOptions
+                )
+            } else {
+                // Fused frontend: preprocessor output already contains encoder features
+                encoderOutputProvider = preprocessorOutput
+            }
 
             let rawEncoderOutput = try extractFeatureValue(
                 from: encoderOutputProvider, key: "encoder", errorMessage: "Invalid encoder output")
diff --git a/Sources/FluidAudio/ASR/AsrTypes.swift b/Sources/FluidAudio/ASR/AsrTypes.swift
index 80ad5b3d2..c4dcf2950 100644
--- a/Sources/FluidAudio/ASR/AsrTypes.swift
+++ b/Sources/FluidAudio/ASR/AsrTypes.swift
@@ -6,6 +6,9 @@ public struct ASRConfig: Sendable {
     public let sampleRate: Int
     public let tdtConfig: TdtConfig
 
+    /// Encoder hidden dimension (1024 for 0.6B, 512 for 110m)
+    public let encoderHiddenSize: Int
+
     /// Enable streaming mode for large files to reduce memory usage.
     /// When enabled, files larger than `streamingThreshold` samples will be processed
     /// using streaming to maintain constant memory usage.
@@ -21,11 +24,13 @@
     public init(
         sampleRate: Int = 16000,
         tdtConfig: TdtConfig = .default,
+        encoderHiddenSize: Int = ASRConstants.encoderHiddenSize,
         streamingEnabled: Bool = true,
         streamingThreshold: Int = 480_000
     ) {
         self.sampleRate = sampleRate
         self.tdtConfig = tdtConfig
+        self.encoderHiddenSize = encoderHiddenSize
         self.streamingEnabled = streamingEnabled
         self.streamingThreshold = streamingThreshold
     }
diff --git a/Sources/FluidAudio/ASR/ChunkProcessor.swift b/Sources/FluidAudio/ASR/ChunkProcessor.swift
index 29873d1a8..60a54b550 100644
--- a/Sources/FluidAudio/ASR/ChunkProcessor.swift
+++ b/Sources/FluidAudio/ASR/ChunkProcessor.swift
@@ -65,7 +65,9 @@ struct ChunkProcessor {
         var chunkStart = 0
         var chunkIndex = 0
 
-        var chunkDecoderState = TdtDecoderState.make()
+        var chunkDecoderState = TdtDecoderState.make(
+            decoderLayers: manager.asrModels?.version.decoderLayers ?? 2
+        )
 
         while chunkStart < totalSamples {
             try Task.checkCancellation()
diff --git a/Sources/FluidAudio/ASR/TDT/EncoderFrameView.swift b/Sources/FluidAudio/ASR/TDT/EncoderFrameView.swift
index fd797d638..4a7427589 100644
--- a/Sources/FluidAudio/ASR/TDT/EncoderFrameView.swift
+++ b/Sources/FluidAudio/ASR/TDT/EncoderFrameView.swift
@@ -16,7 +16,8 @@ struct EncoderFrameView {
     private let timeBaseOffset: Int
     private let basePointer: UnsafeMutablePointer<Float>
 
-    init(encoderOutput: MLMultiArray, validLength: Int) throws {
+    /// Initialize with explicit hidden size (for model-version-aware callers)
+    init(encoderOutput: MLMultiArray, validLength: Int, expectedHiddenSize: Int) throws {
         let shape = encoderOutput.shape.map { $0.intValue }
         guard shape.count == 3 else {
             throw ASRError.processingFailed("Invalid encoder output shape: \(shape)")
@@ -25,11 +26,11 @@ struct EncoderFrameView {
             throw ASRError.processingFailed("Unsupported batch dimension: \(shape[0])")
         }
 
-        let hiddenSize = ASRConstants.encoderHiddenSize
+        let hiddenSize = expectedHiddenSize
         let axis1MatchesHidden = shape[1] == hiddenSize
         let axis2MatchesHidden = shape[2] == hiddenSize
         guard axis1MatchesHidden || axis2MatchesHidden else {
-            throw ASRError.processingFailed("Encoder hidden size mismatch: \(shape)")
+            throw ASRError.processingFailed("Encoder hidden size mismatch: \(shape), expected \(hiddenSize)")
         }
 
         self.hiddenAxis = axis1MatchesHidden ? 1 : 2
@@ -61,6 +62,15 @@ struct EncoderFrameView {
         }
     }
 
+    /// Convenience initializer using default encoder hidden size from ASRConstants
+    init(encoderOutput: MLMultiArray, validLength: Int) throws {
+        try self.init(
+            encoderOutput: encoderOutput,
+            validLength: validLength,
+            expectedHiddenSize: ASRConstants.encoderHiddenSize
+        )
+    }
+
     func copyFrame(
         at index: Int,
         into destination: UnsafeMutablePointer<Float>,
diff --git a/Sources/FluidAudio/ASR/TDT/TdtDecoderState.swift b/Sources/FluidAudio/ASR/TDT/TdtDecoderState.swift
index 7016ef721..b110ea262 100644
--- a/Sources/FluidAudio/ASR/TDT/TdtDecoderState.swift
+++ b/Sources/FluidAudio/ASR/TDT/TdtDecoderState.swift
@@ -24,15 +24,15 @@ struct TdtDecoderState: Sendable {
     ///   - zero: Decoder exactly at the end of encoder frames
     var timeJump: Int?
 
-    init() throws {
+    init(decoderLayers: Int = 2) throws {
         // Use ANE-aligned arrays for optimal performance
         let decoderHiddenSize = ASRConstants.decoderHiddenSize
         hiddenState = try ANEOptimizer.createANEAlignedArray(
-            shape: [2, 1, NSNumber(value: decoderHiddenSize)],
+            shape: [NSNumber(value: decoderLayers), 1, NSNumber(value: decoderHiddenSize)],
             dataType: .float32
         )
         cellState = try ANEOptimizer.createANEAlignedArray(
-            shape: [2, 1, NSNumber(value: decoderHiddenSize)],
+            shape: [NSNumber(value: decoderLayers), 1, NSNumber(value: decoderHiddenSize)],
             dataType: .float32
         )
 
@@ -41,9 +41,9 @@ struct TdtDecoderState: Sendable {
         cellState.resetData(to: 0)
     }
 
-    static func make() -> TdtDecoderState {
+    static func make(decoderLayers: Int = 2) -> TdtDecoderState {
         do {
-            return try TdtDecoderState()
+            return try TdtDecoderState(decoderLayers: decoderLayers)
         } catch {
             fatalError("Failed to allocate decoder state: \(error)")
         }
diff --git a/Sources/FluidAudio/ASR/TDT/TdtDecoderV2.swift b/Sources/FluidAudio/ASR/TDT/TdtDecoderV2.swift
index 561567dd0..7037db7d2 100644
--- a/Sources/FluidAudio/ASR/TDT/TdtDecoderV2.swift
+++ b/Sources/FluidAudio/ASR/TDT/TdtDecoderV2.swift
@@ -66,6 +66,10 @@ internal struct TdtDecoderV2 {
             consecutiveBlankLimit: tdt.consecutiveBlankLimit
         )
 
-        return ASRConfig(sampleRate: config.sampleRate, tdtConfig: adaptedTdt)
+        return ASRConfig(
+            sampleRate: config.sampleRate,
+            tdtConfig: adaptedTdt,
+            encoderHiddenSize: config.encoderHiddenSize
+        )
     }
 }
diff --git a/Sources/FluidAudio/ASR/TDT/TdtDecoderV3.swift b/Sources/FluidAudio/ASR/TDT/TdtDecoderV3.swift
index 4d97267b1..336f2e5b9 100644
--- a/Sources/FluidAudio/ASR/TDT/TdtDecoderV3.swift
+++ b/Sources/FluidAudio/ASR/TDT/TdtDecoderV3.swift
@@ -111,9 +111,15 @@ internal struct TdtDecoderV3 {
             return TdtHypothesis(decState: decoderState)
         }
 
+        // Use encoder hidden size from config (512 for 110m, 1024 for 0.6B)
+        let expectedEncoderHidden = config.encoderHiddenSize
+
         // Build a stride-aware view so we can access encoder frames without extra copies
         let encoderFrames = try EncoderFrameView(
-            encoderOutput: encoderOutput, validLength: encoderSequenceLength)
+            encoderOutput: encoderOutput,
+            validLength: encoderSequenceLength,
+            expectedHiddenSize: expectedEncoderHidden
+        )
 
         var hypothesis = TdtHypothesis(decState: decoderState)
         hypothesis.lastToken = decoderState.lastToken
@@ -167,7 +173,7 @@ internal struct TdtDecoderV3 {
         reusableTargetLengthArray[0] = NSNumber(value: 1)
 
         // Preallocate joint input tensors and a reusable provider to avoid per-step allocations.
-        let encoderHidden = encoderFrames.hiddenSize
+        let encoderHidden = expectedEncoderHidden
         let decoderHidden = ASRConstants.decoderHiddenSize
         let reusableEncoderStep = try ANEOptimizer.createANEAlignedArray(
             shape: [1, NSNumber(value: encoderHidden), 1],
             dataType: .float32
@@ -191,9 +197,8 @@ internal struct TdtDecoderV3 {
         // Initialize decoder LSTM state for a fresh utterance
         // This ensures clean state when starting transcription
         if decoderState.lastToken == nil && decoderState.predictorOutput == nil {
-            let zero = TdtDecoderState.make()
-            decoderState.hiddenState.copyData(from: zero.hiddenState)
-            decoderState.cellState.copyData(from: zero.cellState)
+            decoderState.hiddenState.resetData(to: 0)
+            decoderState.cellState.resetData(to: 0)
         }
 
         // Prime the decoder with Start-of-Sequence token if needed
@@ -881,7 +886,8 @@ internal struct TdtDecoderV3 {
     ) throws -> MLFeatureProvider {
         let encoderFrames = try EncoderFrameView(
             encoderOutput: encoderOutput,
-            validLength: encoderOutput.count)
+            validLength: encoderOutput.count,
+            expectedHiddenSize: config.encoderHiddenSize)
         let encoderStep = try ANEOptimizer.createANEAlignedArray(
             shape: [1, NSNumber(value: encoderFrames.hiddenSize), 1],
             dataType: .float32)
diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift
index 7a4c71ee6..0fdb82023 100644
--- a/Sources/FluidAudio/ModelNames.swift
+++ b/Sources/FluidAudio/ModelNames.swift
@@ -16,6 +16,7 @@ public enum Repo: String, CaseIterable {
     case qwen3Asr = "FluidInference/qwen3-asr-0.6b-coreml/f32"
     case qwen3AsrInt8 = "FluidInference/qwen3-asr-0.6b-coreml/int8"
     case multilingualG2p = "FluidInference/charsiu-g2p-byt5-coreml"
+    case parakeetTdtCtc110m = "FluidInference/parakeet-tdt-ctc-110m-coreml"
 
     /// Repository slug (without owner)
     public var name: String {
@@ -48,6 +49,8 @@ public enum Repo: String, CaseIterable {
             return "qwen3-asr-0.6b-coreml/int8"
         case .multilingualG2p:
             return "charsiu-g2p-byt5-coreml"
+        case .parakeetTdtCtc110m:
+            return "parakeet-tdt-ctc-110m-coreml"
         }
     }
 
@@ -64,6 +67,8 @@ public enum Repo: String, CaseIterable {
             return "FluidInference/diar-streaming-sortformer-coreml"
         case .qwen3Asr, .qwen3AsrInt8:
             return "FluidInference/qwen3-asr-0.6b-coreml"
+        case .parakeetTdtCtc110m:
+            return "FluidInference/parakeet-tdt-ctc-110m-coreml"
         default:
             return "FluidInference/\(name)"
         }
@@ -100,6 +105,8 @@ public enum Repo: String, CaseIterable {
             return "pocket-tts"
         case .multilingualG2p:
             return "charsiu-g2p-byt5"
+        case .parakeetTdtCtc110m:
+            return "parakeet-tdt-ctc-110m"
         default:
             return name
         }
@@ -170,9 +177,24 @@ public enum ModelNames {
             jointFile,
         ]
 
+        /// Vocabulary filename for the 110m hybrid TDT-CTC model (JSON array format)
+        public static let vocabularyFileArray = "vocab.json"
+
+        /// Required models for fused frontend (110m hybrid: preprocessor contains encoder)
+        public static let requiredModelsFused: Set<String> = [
+            preprocessorFile,
+            decoderFile,
+            jointFile,
+        ]
+
         /// Get vocabulary filename for specific model version
         public static func vocabulary(for repo: Repo) -> String {
-            return vocabularyFile
+            switch repo {
+            case .parakeetTdtCtc110m:
+                return vocabularyFileArray
+            default:
+                return vocabularyFile
+            }
         }
     }
 
@@ -429,6 +451,8 @@ public enum ModelNames {
             return ModelNames.VAD.requiredModels
         case .parakeet, .parakeetV2:
             return ModelNames.ASR.requiredModels
+        case .parakeetTdtCtc110m:
+            return ModelNames.ASR.requiredModelsFused
         case .parakeetCtc110m, .parakeetCtc06b:
             return ModelNames.CTC.requiredModels
         case .parakeetEou160, .parakeetEou320:
diff --git a/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift b/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift
index 4129d2e73..1212df825 100644
--- a/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift
+++ b/Sources/FluidAudioCLI/Commands/ASR/AsrBenchmark.swift
@@ -815,8 +815,11 @@ extension ASRBenchmark {
                     modelVersion = .v2
                 case "v3", "3":
                     modelVersion = .v3
+                case "tdt-ctc-110m", "110m":
+                    modelVersion = .tdtCtc110m
                 default:
-                    logger.error("Invalid model version: \(arguments[i + 1]). Use 'v2' or 'v3'")
+                    logger.error(
+                        "Invalid model version: \(arguments[i + 1]). Use 'v2', 'v3', or 'tdt-ctc-110m'")
                     exit(1)
                 }
                 i += 1
@@ -834,7 +837,13 @@ extension ASRBenchmark {
             logger.info("  Max files: \(maxFiles?.description ?? "all")")
         }
         logger.info("  Output file: \(outputFile)")
-        logger.info("  Model version: \(modelVersion == .v2 ? "v2" : "v3")")
+        let versionLabel: String
+        switch modelVersion {
+        case .v2: versionLabel = "v2"
+        case .v3: versionLabel = "v3"
+        case .tdtCtc110m: versionLabel = "tdt-ctc-110m"
+        }
+        logger.info("  Model version: \(versionLabel)")
         logger.info("  Debug mode: \(debugMode ? "enabled" : "disabled")")
         logger.info("  Auto-download: \(autoDownload ? "enabled" : "disabled")")
         logger.info("  Test streaming: \(testStreaming ? "enabled" : "disabled")")
@@ -856,9 +865,11 @@ extension ASRBenchmark {
 
         let benchmark = ASRBenchmark(config: config)
 
-        // Initialize ASR manager with fast benchmark preset
+        // Initialize ASR manager with model-version-aware config
+        let tdtConfig = TdtConfig(blankId: modelVersion.blankId)
         let asrConfig = ASRConfig(
-            tdtConfig: TdtConfig()
+            tdtConfig: tdtConfig,
+            encoderHiddenSize: modelVersion.encoderHiddenSize
         )
         let asrManager = AsrManager(config: asrConfig)
 
@@ -912,10 +923,7 @@ extension ASRBenchmark {
 
         if ProcessInfo.processInfo.environment["CI"] != nil {
             logger.debug("🔍 CI Debug Information:")
-            let modelsDir = FileManager.default.homeDirectoryForCurrentUser
-                .appendingPathComponent(
-                    "Library/Application Support/FluidAudio/Models/parakeet-tdt-0.6b-\(modelVersion == .v2 ? "v2" : "v3")-coreml"
-                )
+            let modelsDir = AsrModels.defaultCacheDirectory(for: modelVersion)
             logger.debug("Models directory: \(modelsDir.path)")
             logger.debug(
                 "  Directory exists: \(FileManager.default.fileExists(atPath: modelsDir.path))"
@@ -1115,7 +1123,7 @@ extension ASRBenchmark {
             --max-files           Maximum number of files to process (default: all)
             --single-file         Process only a specific file (e.g., 1089-134686-0011)
             --output              Output JSON file path (default: asr_benchmark_results.json)
-            --model-version       ASR model version to use: v2 or v3 (default: v3)
+            --model-version       ASR model version to use: v2, v3, or tdt-ctc-110m (default: v3)
             --debug               Enable debug logging
             --auto-download       Automatically download LibriSpeech dataset (default)
             --no-auto-download    Disable automatic dataset download
diff --git a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift
index 3ce6f5c2f..0d19f83b2 100644
--- a/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift
+++ b/Sources/FluidAudioCLI/Commands/ASR/TranscribeCommand.swift
@@ -212,6 +212,7 @@ enum TranscribeCommand {
         var outputJsonPath: String?
         var modelVersion: AsrModelVersion = .v3  // Default to v3
         var customVocabPath: String?
+        var modelDir: String?
 
         // Parse options
         var i = 1
@@ -238,12 +239,20 @@ enum TranscribeCommand {
                     modelVersion = .v2
                 case "v3", "3":
                     modelVersion = .v3
+                case "tdt-ctc-110m", "110m":
+                    modelVersion = .tdtCtc110m
                 default:
-                    logger.error("Invalid model version: \(arguments[i + 1]). Use 'v2' or 'v3'")
+                    logger.error(
+                        "Invalid model version: \(arguments[i + 1]). Use 'v2', 'v3', or 'tdt-ctc-110m'")
                     exit(1)
                 }
                 i += 1
             }
+        case "--model-dir":
+            if i + 1 < arguments.count {
+                modelDir = arguments[i + 1]
+                i += 1
+            }
         case "--custom-vocab":
             if i + 1 < arguments.count {
                 customVocabPath = arguments[i + 1]
@@ -266,19 +275,31 @@ enum TranscribeCommand {
             logger.info("Using batch mode with direct processing\n")
             await testBatchTranscription(
                 audioFile: audioFile, showMetadata: showMetadata, wordTimestamps: wordTimestamps,
-                outputJsonPath: outputJsonPath, modelVersion: modelVersion, customVocabPath: customVocabPath)
+                outputJsonPath: outputJsonPath, modelVersion: modelVersion, customVocabPath: customVocabPath,
+                modelDir: modelDir)
         }
     }
 
     /// Test batch transcription using AsrManager directly
     private static func testBatchTranscription(
         audioFile: String, showMetadata: Bool, wordTimestamps: Bool, outputJsonPath: String?,
-        modelVersion: AsrModelVersion, customVocabPath: String?
+        modelVersion: AsrModelVersion, customVocabPath: String?, modelDir: String? = nil
     ) async {
         do {
             // Initialize ASR models
-            let models = try await AsrModels.downloadAndLoad(version: modelVersion)
-            let asrManager = AsrManager(config: .default)
+            let models: AsrModels
+            if let modelDir = modelDir {
+                let dir = URL(fileURLWithPath: modelDir)
+                models = try await AsrModels.load(from: dir, version: modelVersion)
+            } else {
+                models = try await AsrModels.downloadAndLoad(version: modelVersion)
+            }
+            let tdtConfig = TdtConfig(blankId: modelVersion.blankId)
+            let asrConfig = ASRConfig(
+                tdtConfig: tdtConfig,
+                encoderHiddenSize: modelVersion.encoderHiddenSize
+            )
+            let asrManager = AsrManager(config: asrConfig)
             try await asrManager.initialize(models: models)
 
             logger.info("ASR Manager initialized successfully")
@@ -385,7 +406,12 @@ enum TranscribeCommand {
 
             if let outputJsonPath = outputJsonPath {
                 let wordTimings = WordTimingMerger.mergeTokensIntoWords(result.tokenTimings ?? [])
-                let modelVersionLabel = modelVersion == .v2 ? "v2" : "v3"
+                let modelVersionLabel: String
+                switch modelVersion {
+                case .v2: modelVersionLabel = "v2"
+                case .v3: modelVersionLabel = "v3"
+                case .tdtCtc110m: modelVersionLabel = "tdt-ctc-110m"
+                }
                 let output = TranscriptionJSONOutput(
                     audioFile: audioFile,
                     mode: "batch",
@@ -634,7 +660,12 @@ enum TranscribeCommand {
         let snapshot = await tracker.metadataSnapshot()
         let wordTimings = WordTimingMerger.mergeTokensIntoWords(snapshot?.timings ?? [])
         let latestUpdate = await tracker.latestUpdateSnapshot()
-        let modelVersionLabel = modelVersion == .v2 ? "v2" : "v3"
+        let modelVersionLabel: String
+        switch modelVersion {
+        case .v2: modelVersionLabel = "v2"
+        case .v3: modelVersionLabel = "v3"
+        case .tdtCtc110m: modelVersionLabel = "tdt-ctc-110m"
+        }
         let output = TranscriptionJSONOutput(
             audioFile: audioFile,
             mode: "streaming",
@@ -733,7 +764,8 @@ enum TranscribeCommand {
             --metadata           Show confidence, start time, and end time in results
             --word-timestamps    Show word-level timestamps for each word in the transcription
             --output-json        Save full transcription result to JSON (includes word timings)
-            --model-version      ASR model version to use: v2 or v3 (default: v2)
+            --model-version      ASR model version: v2, v3, or tdt-ctc-110m (default: v3)
+            --model-dir          Path to local model directory (skips download)
             --custom-vocab       Apply vocabulary boosting using terms from file (batch mode only)
 
         Examples: