Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions Sources/FluidAudio/ASR/AsrManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public final class AsrManager {
internal var jointModel: MLModel?

/// The AsrModels instance if initialized with models
private var asrModels: AsrModels?
internal var asrModels: AsrModels?

internal let progressEmitter = ProgressEmitter()

Expand Down Expand Up @@ -88,14 +88,16 @@ public final class AsrManager {
}

public var isAvailable: Bool {
let baseModelsReady = encoderModel != nil && decoderModel != nil && jointModel != nil
guard baseModelsReady else { return false }
let decoderReady = decoderModel != nil && jointModel != nil
guard decoderReady else { return false }

if asrModels?.usesSplitFrontend == true {
// Split frontend: need both preprocessor and encoder
return preprocessorModel != nil && encoderModel != nil
} else {
// Fused frontend: preprocessor contains encoder
return preprocessorModel != nil
}

return true
}

/// Initialize ASR Manager with pre-loaded models
Expand All @@ -110,7 +112,10 @@ public final class AsrManager {
self.jointModel = models.joint
self.vocabulary = models.vocabulary

logger.info("Token duration optimization model loaded successfully")
// Recreate decoder states with the correct layer count for this model version
let layers = models.version.decoderLayers
self.microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
self.systemDecoderState = TdtDecoderState.make(decoderLayers: layers)

logger.info("AsrManager initialized successfully with provided models")
}
Expand Down Expand Up @@ -277,19 +282,22 @@ public final class AsrManager {
}

public func resetState() {
microphoneDecoderState = TdtDecoderState.make()
systemDecoderState = TdtDecoderState.make()
let layers = asrModels?.version.decoderLayers ?? 2
microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
Task { await sharedMLArrayCache.clear() }
}

public func cleanup() {
let layers = asrModels?.version.decoderLayers ?? 2
asrModels = nil
preprocessorModel = nil
encoderModel = nil
decoderModel = nil
jointModel = nil
// Reset decoder states using fresh allocations for deterministic behavior
microphoneDecoderState = TdtDecoderState.make()
systemDecoderState = TdtDecoderState.make()
microphoneDecoderState = TdtDecoderState.make(decoderLayers: layers)
systemDecoderState = TdtDecoderState.make(decoderLayers: layers)
// Release vocabulary boosting resources
disableVocabularyBoosting()
Task { await sharedMLArrayCache.clear() }
Expand All @@ -310,9 +318,25 @@ public final class AsrManager {
guard let models = asrModels, let decoder_ = decoderModel, let joint = jointModel else {
throw ASRError.notInitialized
}

// Adapt config's encoderHiddenSize to match the loaded model version
// (e.g. default config uses 1024 but tdtCtc110m needs 512)
let adaptedConfig: ASRConfig
if config.encoderHiddenSize != models.version.encoderHiddenSize {
adaptedConfig = ASRConfig(
sampleRate: config.sampleRate,
tdtConfig: config.tdtConfig,
encoderHiddenSize: models.version.encoderHiddenSize,
streamingEnabled: config.streamingEnabled,
streamingThreshold: config.streamingThreshold
)
} else {
adaptedConfig = config
}

switch models.version {
case .v2:
let decoder = TdtDecoderV2(config: config)
case .v2, .tdtCtc110m:
let decoder = TdtDecoderV2(config: adaptedConfig)
return try await decoder.decodeWithTimings(
encoderOutput: encoderOutput,
encoderSequenceLength: encoderSequenceLength,
Expand All @@ -325,7 +349,7 @@ public final class AsrManager {
globalFrameOffset: globalFrameOffset
)
case .v3:
let decoder = TdtDecoderV3(config: config)
let decoder = TdtDecoderV3(config: adaptedConfig)
return try await decoder.decodeWithTimings(
encoderOutput: encoderOutput,
encoderSequenceLength: encoderSequenceLength,
Expand Down
121 changes: 97 additions & 24 deletions Sources/FluidAudio/ASR/AsrModels.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,46 @@ import OSLog
public enum AsrModelVersion: Sendable {
case v2
case v3
/// 110M parameter hybrid TDT-CTC model with fused preprocessor+encoder
case tdtCtc110m

var repo: Repo {
switch self {
case .v2: return .parakeetV2
case .v3: return .parakeet
case .tdtCtc110m: return .parakeetTdtCtc110m
}
}

/// Whether this model version uses a fused preprocessor+encoder (no separate Encoder model)
public var hasFusedEncoder: Bool {
switch self {
case .tdtCtc110m: return true
default: return false
}
}

/// Encoder hidden dimension for this model version
public var encoderHiddenSize: Int {
switch self {
case .tdtCtc110m: return 512
default: return 1024
}
}

/// Blank token ID for this model version
public var blankId: Int {
switch self {
case .v2, .tdtCtc110m: return 1024
case .v3: return 8192
}
}

/// Number of LSTM layers in the decoder prediction network
public var decoderLayers: Int {
switch self {
case .tdtCtc110m: return 1
default: return 2
}
}
}
Expand All @@ -20,7 +55,8 @@ public struct AsrModels: Sendable {
/// Required model names for ASR
public static let requiredModelNames = ModelNames.ASR.requiredModels

public let encoder: MLModel
/// Separate encoder model (nil for fused models like tdtCtc110m where preprocessor includes encoder)
public let encoder: MLModel?
public let preprocessor: MLModel
public let decoder: MLModel
public let joint: MLModel
Expand All @@ -31,7 +67,7 @@ public struct AsrModels: Sendable {
private static let logger = AppLogger(category: "AsrModels")

public init(
encoder: MLModel,
encoder: MLModel?,
preprocessor: MLModel,
decoder: MLModel,
joint: MLModel,
Expand All @@ -48,8 +84,9 @@ public struct AsrModels: Sendable {
self.version = version
}

/// Whether this model uses a separate preprocessor and encoder (true for 0.6B, false for 110m fused)
public var usesSplitFrontend: Bool {
true
!version.hasFusedEncoder
}
}

Expand All @@ -60,7 +97,15 @@ extension AsrModels {
let computeUnits: MLComputeUnits
}

private static func createModelSpecs(using config: MLModelConfiguration) -> [ModelSpec] {
private static func createModelSpecs(
using config: MLModelConfiguration, version: AsrModelVersion
) -> [ModelSpec] {
if version.hasFusedEncoder {
// Fused preprocessor+encoder runs on ANE (it contains the conformer encoder)
return [
ModelSpec(fileName: Names.preprocessorFile, computeUnits: config.computeUnits)
]
}
return [
// Preprocessor ops map to CPU-only across all platforms. Xcode profiling shows
// that 100% of the operations map to the CPU anyway.
Expand All @@ -78,7 +123,7 @@ extension AsrModels {

private static func inferredVersion(from directory: URL) -> AsrModelVersion? {
let directoryPath = directory.path.lowercased()
let knownVersions: [AsrModelVersion] = [.v2, .v3]
let knownVersions: [AsrModelVersion] = [.tdtCtc110m, .v2, .v3]

for version in knownVersions {
if directoryPath.contains(version.repo.folderName.lowercased()) {
Expand Down Expand Up @@ -118,7 +163,7 @@ extension AsrModels {

let parentDirectory = directory.deletingLastPathComponent()
// Load preprocessor and encoder first; decoder and joint are loaded below as well.
let specs = createModelSpecs(using: config)
let specs = createModelSpecs(using: config, version: version)

var loadedModels: [String: MLModel] = [:]

Expand All @@ -138,10 +183,13 @@ extension AsrModels {
}
}

guard let preprocessorModel = loadedModels[Names.preprocessorFile],
let encoderModel = loadedModels[Names.encoderFile]
else {
throw AsrModelsError.loadingFailed("Failed to load preprocessor or encoder model")
guard let preprocessorModel = loadedModels[Names.preprocessorFile] else {
throw AsrModelsError.loadingFailed("Failed to load preprocessor model")
}
let encoderModel = loadedModels[Names.encoderFile] // nil for fused models

if !version.hasFusedEncoder && encoderModel == nil {
throw AsrModelsError.loadingFailed("Failed to load encoder model (required for split frontend)")
}

// Load decoder and joint as well
Expand Down Expand Up @@ -185,18 +233,30 @@ extension AsrModels {

do {
let data = try Data(contentsOf: vocabPath)
let jsonDict = try JSONSerialization.jsonObject(with: data) as? [String: String] ?? [:]
let json = try JSONSerialization.jsonObject(with: data)

var vocabulary: [Int: String] = [:]

for (key, value) in jsonDict {
if let tokenId = Int(key) {
vocabulary[tokenId] = value
if let jsonArray = json as? [String] {
// Array format (110m hybrid): index = token ID
for (index, token) in jsonArray.enumerated() {
vocabulary[index] = token
}
} else if let jsonDict = json as? [String: String] {
// Dictionary format (0.6B v2/v3): key = token ID string
for (key, value) in jsonDict {
if let tokenId = Int(key) {
vocabulary[tokenId] = value
}
}
} else {
throw AsrModelsError.loadingFailed("Vocabulary file has unexpected format")
}

logger.info("Loaded vocabulary with \(vocabulary.count) tokens from \(vocabPath.path)")
return vocabulary
} catch let error as AsrModelsError {
throw error
} catch {
logger.error(
"Failed to load or parse vocabulary file at \(vocabPath.path): \(error.localizedDescription)"
Expand Down Expand Up @@ -324,13 +384,23 @@ extension AsrModels {

let defaultUnits = defaultConfiguration().computeUnits

let specs: [DownloadSpec] = [
// Preprocessor ops map to CPU-only across all platforms.
DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
]
let specs: [DownloadSpec]
if version.hasFusedEncoder {
specs = [
// Fused preprocessor+encoder runs on ANE
DownloadSpec(fileName: Names.preprocessorFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
]
} else {
specs = [
// Preprocessor ops map to CPU-only across all platforms.
DownloadSpec(fileName: Names.preprocessorFile, computeUnits: .cpuOnly),
DownloadSpec(fileName: Names.encoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.decoderFile, computeUnits: defaultUnits),
DownloadSpec(fileName: Names.jointFile, computeUnits: defaultUnits),
]
}

for spec in specs {
_ = try await DownloadUtils.loadModels(
Expand Down Expand Up @@ -365,7 +435,8 @@ extension AsrModels {

public static func modelsExist(at directory: URL, version: AsrModelVersion) -> Bool {
let fileManager = FileManager.default
let requiredFiles = ModelNames.ASR.requiredModels
let requiredFiles =
version.hasFusedEncoder ? ModelNames.ASR.requiredModelsFused : ModelNames.ASR.requiredModels

// Check in the DownloadUtils repo structure
let repoPath = repoPath(from: directory, version: version)
Expand Down Expand Up @@ -397,12 +468,14 @@ extension AsrModels {
let config = MLModelConfiguration()
config.computeUnits = .cpuOnly

let modelsToValidate = [
var modelsToValidate = [
("Preprocessor", ModelNames.ASR.preprocessorFile),
("Encoder", ModelNames.ASR.encoderFile),
("Decoder", ModelNames.ASR.decoderFile),
("Joint", ModelNames.ASR.jointFile),
]
if !version.hasFusedEncoder {
modelsToValidate.insert(("Encoder", ModelNames.ASR.encoderFile), at: 1)
}

for (name, fileName) in modelsToValidate {
let modelPath = repoPath.appendingPathComponent(fileName)
Expand Down
31 changes: 19 additions & 12 deletions Sources/FluidAudio/ASR/AsrTranscription.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ extension AsrManager {
let preprocessorAudioArray = preprocessorInput.featureValue(for: "audio_signal")?.multiArrayValue

do {
guard let preprocessorModel = preprocessorModel, let encoderModel = encoderModel else {
guard let preprocessorModel = preprocessorModel else {
throw ASRError.notInitialized
}

Expand All @@ -98,17 +98,24 @@ extension AsrManager {
options: predictionOptions
)

let encoderInput = try prepareEncoderInput(
encoder: encoderModel,
preprocessorOutput: preprocessorOutput,
originalInput: preprocessorInput
)

try Task.checkCancellation()
let encoderOutputProvider = try await encoderModel.compatPrediction(
from: encoderInput,
options: predictionOptions
)
let encoderOutputProvider: MLFeatureProvider
if let encoderModel = encoderModel {
// Split frontend: run separate encoder
let encoderInput = try prepareEncoderInput(
encoder: encoderModel,
preprocessorOutput: preprocessorOutput,
originalInput: preprocessorInput
)

try Task.checkCancellation()
encoderOutputProvider = try await encoderModel.compatPrediction(
from: encoderInput,
options: predictionOptions
)
} else {
// Fused frontend: preprocessor output already contains encoder features
encoderOutputProvider = preprocessorOutput
}

let rawEncoderOutput = try extractFeatureValue(
from: encoderOutputProvider, key: "encoder", errorMessage: "Invalid encoder output")
Expand Down
5 changes: 5 additions & 0 deletions Sources/FluidAudio/ASR/AsrTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ public struct ASRConfig: Sendable {
public let sampleRate: Int
public let tdtConfig: TdtConfig

/// Encoder hidden dimension (1024 for 0.6B, 512 for 110m)
public let encoderHiddenSize: Int

/// Enable streaming mode for large files to reduce memory usage.
/// When enabled, files larger than `streamingThreshold` samples will be processed
/// using streaming to maintain constant memory usage.
Expand All @@ -21,11 +24,13 @@ public struct ASRConfig: Sendable {
public init(
sampleRate: Int = 16000,
tdtConfig: TdtConfig = .default,
encoderHiddenSize: Int = ASRConstants.encoderHiddenSize,
streamingEnabled: Bool = true,
streamingThreshold: Int = 480_000
) {
self.sampleRate = sampleRate
self.tdtConfig = tdtConfig
self.encoderHiddenSize = encoderHiddenSize
self.streamingEnabled = streamingEnabled
self.streamingThreshold = streamingThreshold
}
Expand Down
Loading
Loading