import Foundation

/// ARPA n-gram language model for CTC beam search rescoring.
///
/// Loads unigrams and bigrams from an ARPA text file. Log10 probabilities
/// are converted to natural log for direct combination with CTC log-softmax output.
///
/// Usage:
/// ```swift
/// let lm = try ARPALanguageModel.load(from: arpaFileURL)
/// let score = lm.score(word: "runway", prev: "cleared")
/// ```
///
/// - Note: Only plain-text ARPA files are supported (not gzipped or binary KenLM).
///   Trigrams and higher-order n-grams are ignored.
public struct ARPALanguageModel: Sendable {

    public struct Entry: Sendable {
        /// Log-probability in natural log (converted from log10)
        public let logProb: Float
        /// Backoff weight in natural log
        public let backoff: Float
    }

    /// Conversion factor from log10 to natural log (ln 10)
    public static let log10ToNat: Float = 2.302585

    /// Fallback log-probability for out-of-vocabulary words (≈ log(1e-10))
    public static let unkLogProb: Float = -23.026

    /// Unigram entries keyed by word
    public var unigrams: [String: Entry] = [:]
    /// Bigram entries keyed by context word → target word
    public var bigrams: [String: [String: Entry]] = [:]

    public init() {}

    /// Load an ARPA language model from a text file.
    ///
    /// Parses unigram and bigram sections. Trigrams and higher are skipped.
    /// Log10 probabilities are converted to natural log.
    ///
    /// Tolerates CRLF line endings and either tab- or space-delimited fields,
    /// both of which occur in ARPA files in the wild.
    ///
    /// - Parameter url: Path to an ARPA-format text file.
    /// - Returns: A populated `ARPALanguageModel`.
    /// - Throws: `ARPAError.cannotOpen` if the file cannot be opened.
    public static func load(from url: URL) throws -> ARPALanguageModel {
        guard let reader = ARPALineReader(url: url) else {
            throw ARPAError.cannotOpen(url.path)
        }
        var lm = ARPALanguageModel()
        var section = ""
        // Field separators: ARPA is tab-delimited by convention, but
        // space-delimited files exist; accept both.
        let separators = CharacterSet(charactersIn: " \t")
        while let rawLine = reader.readLine() {
            // Trim trailing CR/whitespace so section headers from CRLF files
            // (e.g. "\1-grams:\r") still compare equal below.
            let line = rawLine.trimmingCharacters(in: .whitespaces)
            if line.isEmpty || line.hasPrefix("\\data\\") { continue }
            if line.hasPrefix("\\") {
                if line == "\\end\\" { break }
                section = line
                continue
            }
            // Drop empty fields produced by runs of separators.
            let parts = line.components(separatedBy: separators).filter { !$0.isEmpty }
            guard let first = parts.first, let log10prob = Float(first) else { continue }
            let prob = log10prob * log10ToNat
            if section == "\\1-grams:", parts.count >= 2 {
                let word = parts[1]
                let backoff = parts.count >= 3 ? (Float(parts[2]) ?? 0.0) * log10ToNat : 0.0
                lm.unigrams[word] = Entry(logProb: prob, backoff: backoff)
            } else if section == "\\2-grams:", parts.count >= 3 {
                let ctx = parts[1]
                let word = parts[2]
                let backoff = parts.count >= 4 ? (Float(parts[3]) ?? 0.0) * log10ToNat : 0.0
                lm.bigrams[ctx, default: [:]][word] = Entry(logProb: prob, backoff: backoff)
            }
            // Trigram and higher sections fall through and are ignored.
        }
        return lm
    }

    /// Score P(word | prev) in natural log, backing off to unigram if bigram is absent.
    ///
    /// - Parameters:
    ///   - word: The target word.
    ///   - prev: The preceding context word, or nil for unigram-only scoring.
    /// - Returns: Natural log probability; OOV words receive `unkLogProb`.
    public func score(word: String, prev: String?) -> Float {
        // Exact bigram hit: no backoff needed.
        if let p = prev, let bi = bigrams[p]?[word] { return bi.logProb }
        // Standard ARPA backoff: backoff weight of the context plus the
        // unigram probability of the target (or the OOV penalty).
        let backoff = prev.flatMap { unigrams[$0]?.backoff } ?? 0.0
        return backoff + (unigrams[word]?.logProb ?? ARPALanguageModel.unkLogProb)
    }
}

// MARK: - Errors

public enum ARPAError: Error, LocalizedError {
    case cannotOpen(String)

    public var errorDescription: String? {
        switch self {
        case .cannotOpen(let path):
            return "Cannot open ARPA file: \(path)"
        }
    }
}
// MARK: - Line Reader

/// Streaming line reader for efficient ARPA file parsing.
///
/// Reads the file in fixed-size chunks so large ARPA models never need to be
/// fully resident in memory.
final class ARPALineReader {
    private let fileHandle: FileHandle
    /// Bytes read from the file that have not yet been returned as lines.
    private var buffer = Data()
    private let chunkSize = 65_536
    private var eof = false

    init?(url: URL) {
        guard let fh = FileHandle(forReadingAtPath: url.path) else { return nil }
        fileHandle = fh
    }

    deinit { fileHandle.closeFile() }

    /// Return the next line without its terminator, or nil at end of file.
    ///
    /// A trailing `\r` is stripped so CRLF (Windows-produced) files yield the
    /// same lines as LF files. A final unterminated line is still returned.
    func readLine() -> String? {
        while true {
            if let nl = buffer.firstIndex(of: UInt8(ascii: "\n")) {
                var slice = buffer[buffer.startIndex..<nl]
                // Tolerate CRLF line endings.
                if slice.last == UInt8(ascii: "\r") { slice = slice.dropLast() }
                let line = String(data: Data(slice), encoding: .utf8) ?? ""
                buffer.removeSubrange(buffer.startIndex...nl)
                return line
            }
            if eof {
                // Flush the last unterminated line, if any.
                guard !buffer.isEmpty else { return nil }
                let line = String(data: buffer, encoding: .utf8) ?? ""
                buffer.removeAll()
                return line
            }
            let chunk = fileHandle.readData(ofLength: chunkSize)
            if chunk.isEmpty { eof = true } else { buffer.append(chunk) }
        }
    }
}

// MARK: - CTC Greedy Decode

/// Greedy (argmax) CTC decode from per-frame log-probabilities.
///
/// Standard CTC collapse: argmax token per frame, drop repeats and blanks,
/// then detokenize.
///
/// - Parameters:
///   - logProbs: Per-frame log-probabilities, shape [T][V].
///   - vocabulary: Token vocabulary mapping token ID → string.
///   - blankId: CTC blank token index (default 1024).
/// - Returns: Decoded text string.
public func ctcGreedyDecode(
    logProbs: [[Float]],
    vocabulary: [Int: String],
    blankId: Int = 1024
) -> String {
    var ids: [Int] = []
    var prev = -1
    for frame in logProbs {
        guard !frame.isEmpty else { continue }
        var bestIdx = 0
        var bestVal = frame[0]
        for v in 1..<frame.count where frame[v] > bestVal {
            bestVal = frame[v]
            bestIdx = v
        }
        // Emit only on a change away from blank / the previous token.
        if bestIdx != blankId && bestIdx != prev { ids.append(bestIdx) }
        prev = bestIdx
    }
    return decodeCtcTokenIds(ids, vocabulary: vocabulary)
}

/// Greedy CTC decode from an MLMultiArray of shape [1, T, V].
///
/// - Parameters:
///   - logProbs: MLMultiArray of shape [1, T, V] containing log-probabilities.
///   - vocabulary: Token vocabulary mapping token ID → string.
///   - blankId: CTC blank token index (default 1024).
/// - Returns: Decoded text string; empty if the array has unexpected rank.
public func ctcGreedyDecode(
    logProbs: MLMultiArray,
    vocabulary: [Int: String],
    blankId: Int = 1024
) -> String {
    // Robustness: an array of rank < 3 would otherwise crash on shape[2].
    guard logProbs.shape.count >= 3 else { return "" }
    let timeSteps = logProbs.shape[1].intValue
    let vocabSize = logProbs.shape[2].intValue
    guard timeSteps > 0, vocabSize > 0 else { return "" }
    let ptr = logProbs.dataPointer.assumingMemoryBound(to: Float32.self)
    var ids: [Int] = []
    var prev = -1
    for t in 0..<timeSteps {
        var bestIdx = 0
        var bestVal = ptr[t * vocabSize]
        for v in 1..<vocabSize {
            let x = ptr[t * vocabSize + v]
            if x > bestVal {
                bestVal = x
                bestIdx = v
            }
        }
        if bestIdx != blankId && bestIdx != prev { ids.append(bestIdx) }
        prev = bestIdx
    }
    return decodeCtcTokenIds(ids, vocabulary: vocabulary)
}

// MARK: - CTC Beam Search

/// A single hypothesis in the CTC beam search.
public struct CtcBeam {
    /// Emitted (collapsed) token IDs so far.
    public var prefix: [Int]
    /// Log-probability of all alignments of `prefix` ending in blank.
    public var pBlank: Float
    /// Log-probability of all alignments of `prefix` ending in a non-blank.
    public var pNonBlank: Float
    /// Accumulated (weighted) language-model score in nats.
    public var lmScore: Float
    /// SentencePiece pieces of the current, not-yet-completed word.
    public var wordPieces: [String]
    /// Most recently completed word, used as LM bigram context.
    public var prevWord: String?

    /// Combined acoustic log-probability (blank + non-blank paths).
    public var totalAcoustic: Float { logAddExp(pBlank, pNonBlank) }
    /// Acoustic score plus accumulated LM score.
    public var total: Float { totalAcoustic + lmScore }
    public var lastToken: Int? { prefix.last }

    public init(
        prefix: [Int], pBlank: Float, pNonBlank: Float,
        lmScore: Float, wordPieces: [String], prevWord: String?
    ) {
        self.prefix = prefix
        self.pBlank = pBlank
        self.pNonBlank = pNonBlank
        self.lmScore = lmScore
        self.wordPieces = wordPieces
        self.prevWord = prevWord
    }
}

/// CTC prefix beam search with optional ARPA language model rescoring.
///
/// Uses corrected repeat-token handling (Graves 2006 fix):
/// - Non-blank path continues: `p_nb(l) += p_nb(l) * P(c)`
/// - Blank path creates new token: `p_nb(l+c) += p_b(l) * P(c)`
///
/// Word-level LM scores are applied at SentencePiece word boundaries (`▁` prefix).
///
/// - Parameters:
///   - logProbs: Per-frame log-probabilities, shape [T][V].
///   - vocabulary: Token vocabulary mapping token ID → string.
///   - lm: Optional ARPA language model for rescoring.
///   - beamWidth: Number of hypotheses to maintain (default 100).
///   - lmWeight: LM score scaling factor (alpha, default 0.3).
///   - wordBonus: Per-word bonus in nats (beta, default 0.0).
///   - blankId: CTC blank token index (default 1024).
///   - tokenCandidates: Number of top tokens to consider per frame (default 40).
/// - Returns: Decoded text string.
public func ctcBeamSearch(
    logProbs: [[Float]],
    vocabulary: [Int: String],
    lm: ARPALanguageModel? = nil,
    beamWidth: Int = 100,
    lmWeight: Float = 0.3,
    wordBonus: Float = 0.0,
    blankId: Int = 1024,
    tokenCandidates: Int = 40
) -> String {
    guard !logProbs.isEmpty else { return "" }
    let vocabSize = logProbs[0].count
    guard vocabSize > 0 else { return "" }

    // Start from the empty prefix, which trivially "ends in blank" with p = 1.
    var beams: [[Int]: CtcBeam] = [
        []: CtcBeam(
            prefix: [], pBlank: 0.0, pNonBlank: -.infinity,
            lmScore: 0.0, wordPieces: [], prevWord: nil)
    ]

    for frame in logProbs {
        let blankLp = blankId < frame.count ? frame[blankId] : -.infinity

        // Top-K token candidates for this frame (excluding blank).
        let topTokens = (0..<vocabSize)
            .filter { $0 != blankId }
            .sorted { frame[$0] > frame[$1] }
            .prefix(tokenCandidates)

        var newBeams: [[Int]: CtcBeam] = [:]

        // Merge hypotheses with identical prefixes by log-summing their
        // acoustic paths; LM state is identical for identical prefixes.
        func merge(_ beam: CtcBeam) {
            let k = beam.prefix
            if var existing = newBeams[k] {
                existing.pBlank = logAddExp(existing.pBlank, beam.pBlank)
                existing.pNonBlank = logAddExp(existing.pNonBlank, beam.pNonBlank)
                newBeams[k] = existing
            } else {
                newBeams[k] = beam
            }
        }

        for (_, beam) in beams {
            let prevTotal = beam.totalAcoustic

            // Blank extension: prefix unchanged, all mass moves to the blank path.
            var blankBeam = beam
            blankBeam.pBlank = prevTotal + blankLp
            blankBeam.pNonBlank = -.infinity
            merge(blankBeam)

            for v in topTokens {
                let tokenLp = frame[v]
                let isRepeat = (beam.lastToken == v)
                let piece = vocabulary[v] ?? ""

                // Word tracking + LM delta: a "▁"-prefixed piece starts a new
                // word, so the previously accumulated pieces form a completed
                // word that can be scored against the LM.
                var newWordPieces = beam.wordPieces
                var newPrevWord = beam.prevWord
                var lmDelta: Float = 0.0
                if let lm = lm, piece.hasPrefix("▁") {
                    let completedWord = newWordPieces.joined()
                    let hasCompletedWord = !completedWord.isEmpty
                    lmDelta =
                        hasCompletedWord
                        ? lmWeight * lm.score(word: completedWord, prev: newPrevWord) + wordBonus
                        : 0.0
                    newPrevWord = hasCompletedWord ? completedWord : newPrevWord
                    let stripped = String(piece.dropFirst())
                    newWordPieces = stripped.isEmpty ? [] : [stripped]
                } else if lm != nil {
                    newWordPieces.append(piece)
                }

                if isRepeat {
                    // Corrected repeat handling: continuing the non-blank path
                    // extends the SAME prefix; only the blank path may spawn a
                    // genuinely doubled token.
                    var sameBeam = beam
                    sameBeam.pBlank = -.infinity
                    sameBeam.pNonBlank = beam.pNonBlank + tokenLp
                    merge(sameBeam)

                    merge(
                        CtcBeam(
                            prefix: beam.prefix + [v],
                            pBlank: -.infinity,
                            pNonBlank: beam.pBlank + tokenLp,
                            lmScore: beam.lmScore + lmDelta,
                            wordPieces: newWordPieces,
                            prevWord: newPrevWord
                        ))
                } else {
                    merge(
                        CtcBeam(
                            prefix: beam.prefix + [v],
                            pBlank: -.infinity,
                            pNonBlank: prevTotal + tokenLp,
                            lmScore: beam.lmScore + lmDelta,
                            wordPieces: newWordPieces,
                            prevWord: newPrevWord
                        ))
                }
            }
        }

        // Prune to beam width by combined acoustic + LM score.
        let sorted = newBeams.values.sorted { $0.total > $1.total }
        beams = [:]
        for beam in sorted.prefix(beamWidth) {
            beams[beam.prefix] = beam
        }
    }

    // Finalize: score the trailing partial word that never hit a "▁" boundary.
    let finalBeams = beams.values.map { beam -> CtcBeam in
        guard let lm = lm else { return beam }
        let lastWord = beam.wordPieces.joined()
        guard !lastWord.isEmpty else { return beam }
        var b = beam
        b.lmScore += lmWeight * lm.score(word: lastWord, prev: beam.prevWord) + wordBonus
        return b
    }
    guard let best = finalBeams.max(by: { $0.total < $1.total }) else { return "" }
    return decodeCtcTokenIds(best.prefix, vocabulary: vocabulary)
}

/// CTC beam search from an MLMultiArray of shape [1, T, V].
///
/// Convenience overload that extracts per-frame log-probabilities from the
/// MLMultiArray and delegates to the `[[Float]]` version.
public func ctcBeamSearch(
    logProbs: MLMultiArray,
    vocabulary: [Int: String],
    lm: ARPALanguageModel? = nil,
    beamWidth: Int = 100,
    lmWeight: Float = 0.3,
    wordBonus: Float = 0.0,
    blankId: Int = 1024,
    tokenCandidates: Int = 40
) -> String {
    // Robustness: an array of rank < 3 would otherwise crash on shape[2].
    guard logProbs.shape.count >= 3 else { return "" }
    let timeSteps = logProbs.shape[1].intValue
    let vocabSize = logProbs.shape[2].intValue
    guard timeSteps > 0, vocabSize > 0 else { return "" }
    let ptr = logProbs.dataPointer.assumingMemoryBound(to: Float32.self)

    var frames: [[Float]] = []
    frames.reserveCapacity(timeSteps)
    for t in 0..<timeSteps {
        var frame = [Float](repeating: 0, count: vocabSize)
        for v in 0..<vocabSize {
            frame[v] = ptr[t * vocabSize + v]
        }
        frames.append(frame)
    }
    return ctcBeamSearch(
        logProbs: frames, vocabulary: vocabulary, lm: lm,
        beamWidth: beamWidth, lmWeight: lmWeight, wordBonus: wordBonus,
        blankId: blankId, tokenCandidates: tokenCandidates
    )
}

// MARK: - Helpers

/// Numerically stable log(exp(a) + exp(b)).
///
/// Handles -infinity operands explicitly so an "impossible" path does not
/// poison the sum with NaN.
public func logAddExp(_ a: Float, _ b: Float) -> Float {
    if a == -.infinity { return b }
    if b == -.infinity { return a }
    let m = max(a, b)
    return m + log(exp(a - m) + exp(b - m))
}

/// Decode a sequence of token IDs using a vocabulary mapping.
///
/// Replaces SentencePiece `▁` markers with spaces and trims whitespace.
/// Unknown IDs are skipped.
public func decodeCtcTokenIds(_ ids: [Int], vocabulary: [Int: String]) -> String {
    ids.compactMap { vocabulary[$0] }
        .joined()
        .replacingOccurrences(of: "▁", with: " ")
        .trimmingCharacters(in: .whitespaces)
}
import Foundation
import XCTest

@testable import FluidAudio

final class ARPALanguageModelTests: XCTestCase {

    // MARK: - Helpers

    /// Create a temporary ARPA file with the given content and return its URL.
    /// The file is removed automatically when the current test finishes.
    private func writeTemporaryARPA(_ content: String) throws -> URL {
        let tempDir = FileManager.default.temporaryDirectory
        let url = tempDir.appendingPathComponent("test_\(UUID().uuidString).arpa")
        try content.write(to: url, atomically: true, encoding: .utf8)
        addTeardownBlock { try? FileManager.default.removeItem(at: url) }
        return url
    }

    /// Four unigrams (including <unk>) and two bigrams.
    private let sampleARPA = """
        \\data\\
        ngram 1=4
        ngram 2=2

        \\1-grams:
        -1.0\tthe\t-0.5
        -1.2\tcat\t-0.3
        -1.5\tsat\t0.0
        -2.0\t<unk>\t0.0

        \\2-grams:
        -0.5\tthe\tcat
        -0.8\tcat\tsat

        \\end\\
        """

    // MARK: - Loading

    func testLoadARPAFile() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        XCTAssertEqual(lm.unigrams.count, 4, "Should load 4 unigrams")
        XCTAssertEqual(lm.bigrams.count, 2, "Should load 2 bigram contexts")
    }

    func testLoadARPAUnigramValues() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        // Prefer XCTUnwrap over force-unwrapping the optional entry.
        let theEntry = try XCTUnwrap(lm.unigrams["the"])

        // log10 value -1.0 converted to natural log: -1.0 * ln(10) ≈ -2.302585
        let expectedLogProb = -1.0 * ARPALanguageModel.log10ToNat
        XCTAssertEqual(theEntry.logProb, expectedLogProb, accuracy: 0.001)

        // backoff: -0.5 * ln(10) ≈ -1.151293
        let expectedBackoff = -0.5 * ARPALanguageModel.log10ToNat
        XCTAssertEqual(theEntry.backoff, expectedBackoff, accuracy: 0.001)
    }

    func testLoadARPABigrams() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        XCTAssertNotNil(lm.bigrams["the"]?["cat"])
        XCTAssertNotNil(lm.bigrams["cat"]?["sat"])
        XCTAssertNil(lm.bigrams["sat"]?["the"])
    }

    func testLoadNonexistentFileThrows() {
        let bogusURL = URL(fileURLWithPath: "/nonexistent/path/to/model.arpa")
        XCTAssertThrowsError(try ARPALanguageModel.load(from: bogusURL))
    }

    func testLoadEmptyARPA() throws {
        let content = """
            \\data\\
            ngram 1=0

            \\1-grams:

            \\end\\
            """
        let url = try writeTemporaryARPA(content)
        let lm = try ARPALanguageModel.load(from: url)

        XCTAssertTrue(lm.unigrams.isEmpty)
        XCTAssertTrue(lm.bigrams.isEmpty)
    }

    // MARK: - Scoring

    func testScoreBigramAvailable() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        let score = lm.score(word: "cat", prev: "the")
        let expected = -0.5 * ARPALanguageModel.log10ToNat
        XCTAssertEqual(score, expected, accuracy: 0.001)
    }

    func testScoreFallsBackToUnigram() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        // "sat" given "the" — no bigram exists, falls back to
        // unigram("sat") + backoff("the").
        let score = lm.score(word: "sat", prev: "the")
        let unigramLogProb = -1.5 * ARPALanguageModel.log10ToNat
        let backoff = -0.5 * ARPALanguageModel.log10ToNat
        let expected = backoff + unigramLogProb
        XCTAssertEqual(score, expected, accuracy: 0.001)
    }

    func testScoreNoPrevContextSkipsBackoff() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        let score = lm.score(word: "cat", prev: nil)
        let expected = -1.2 * ARPALanguageModel.log10ToNat
        XCTAssertEqual(score, expected, accuracy: 0.001)
    }

    func testScoreOOVWordGetsUnkPenalty() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        let score = lm.score(word: "xyzzy", prev: nil)
        XCTAssertEqual(score, ARPALanguageModel.unkLogProb, accuracy: 0.001)
    }

    func testScoreOOVWithBackoff() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        // OOV word with prev context "the" — backoff("the") + unkLogProb
        let score = lm.score(word: "xyzzy", prev: "the")
        let backoff = -0.5 * ARPALanguageModel.log10ToNat
        let expected = backoff + ARPALanguageModel.unkLogProb
        XCTAssertEqual(score, expected, accuracy: 0.001)
    }

    // MARK: - Beam Search with LM

    func testBeamSearchWithLMInfluencesResult() throws {
        let url = try writeTemporaryARPA(sampleARPA)
        let lm = try ARPALanguageModel.load(from: url)

        // Vocabulary with word-start markers matching ARPA words
        let vocab: [Int: String] = [0: "▁the", 1: "▁cat", 2: "▁dog"]
        let blankId = 3

        // Make "dog" slightly better acoustically than "cat"
        let logProbs: [[Float]] = [
            [0.0, -100.0, -100.0, -100.0],  // "the" clearly best
            [-100.0, -1.0, -0.9, -100.0],  // "cat" vs "dog" — dog slightly better acoustically
        ]

        // Without LM — should pick "dog" (better acoustic score)
        let noLM = ctcBeamSearch(
            logProbs: logProbs, vocabulary: vocab, lm: nil,
            beamWidth: 10, lmWeight: 0.0, blankId: blankId
        )
        XCTAssertEqual(noLM, "the dog")

        // With strong LM — "the cat" has a bigram entry, "the dog" doesn't
        let withLM = ctcBeamSearch(
            logProbs: logProbs, vocabulary: vocab, lm: lm,
            beamWidth: 10, lmWeight: 5.0, blankId: blankId
        )
        XCTAssertEqual(withLM, "the cat")
    }
}
import CoreML
import Foundation
import XCTest

@testable import FluidAudio

final class CtcDecoderTests: XCTestCase {

    // MARK: - Fixtures

    /// Build a [1, T, V] float32 MLMultiArray filled with -100, then set the
    /// given (frame, token) cells to 0 so each listed token dominates its frame.
    private func makeLogProbArray(
        timeSteps: Int, vocabSize: Int, hot: [(t: Int, v: Int)]
    ) throws -> MLMultiArray {
        let array = try MLMultiArray(
            shape: [1, NSNumber(value: timeSteps), NSNumber(value: vocabSize)],
            dataType: .float32
        )
        let values = array.dataPointer.assumingMemoryBound(to: Float32.self)
        for i in 0..<(timeSteps * vocabSize) { values[i] = -100.0 }
        for cell in hot { values[cell.t * vocabSize + cell.v] = 0.0 }
        return array
    }

    // MARK: - logAddExp

    func testLogAddExpEqualValues() {
        // log(e^0 + e^0) = log 2
        XCTAssertEqual(logAddExp(0.0, 0.0), Float(log(2.0)), accuracy: 0.001)
    }

    func testLogAddExpWithNegInfinity() {
        // -inf acts as the identity element.
        XCTAssertEqual(logAddExp(-.infinity, 5.0), 5.0)
        XCTAssertEqual(logAddExp(5.0, -.infinity), 5.0)
    }

    func testLogAddExpBothNegInfinity() {
        XCTAssertEqual(logAddExp(-.infinity, -.infinity), -.infinity)
    }

    func testLogAddExpLargeDifference() {
        // The smaller term is negligible at this spread.
        XCTAssertEqual(logAddExp(100.0, 0.0), 100.0, accuracy: 0.001)
    }

    func testLogAddExpIsCommutative() {
        XCTAssertEqual(logAddExp(1.0, 2.0), logAddExp(2.0, 1.0), accuracy: 1e-6)
    }

    // MARK: - decodeCtcTokenIds

    func testDecodeTokenIdsBasic() {
        let vocab: [Int: String] = [0: "▁hello", 1: "▁world"]
        XCTAssertEqual(decodeCtcTokenIds([0, 1], vocabulary: vocab), "hello world")
    }

    func testDecodeTokenIdsEmpty() {
        let vocab: [Int: String] = [0: "▁hello"]
        XCTAssertEqual(decodeCtcTokenIds([], vocabulary: vocab), "")
    }

    func testDecodeTokenIdsSkipsUnknown() {
        // ID 999 has no vocabulary entry and must be dropped silently.
        let vocab: [Int: String] = [0: "▁hello", 1: "▁world"]
        XCTAssertEqual(decodeCtcTokenIds([0, 999, 1], vocabulary: vocab), "hello world")
    }

    func testDecodeTokenIdsSubwordJoin() {
        // Pieces without a "▁" marker join onto the preceding word.
        let vocab: [Int: String] = [0: "he", 1: "llo", 2: "▁world"]
        XCTAssertEqual(decodeCtcTokenIds([0, 1, 2], vocabulary: vocab), "hello world")
    }

    // MARK: - CTC Greedy Decode ([[Float]])

    func testGreedyDecodeSimple() {
        let vocab: [Int: String] = [0: "▁hello", 1: "▁world"]
        // Frame 0: token 0 dominant, frame 1: blank, frame 2: token 1 dominant.
        let frames: [[Float]] = [
            [0.0, -100.0, -100.0],
            [-100.0, -100.0, 0.0],
            [-100.0, 0.0, -100.0],
        ]
        XCTAssertEqual(
            ctcGreedyDecode(logProbs: frames, vocabulary: vocab, blankId: 2),
            "hello world")
    }

    func testGreedyDecodeCollapsesRepeats() {
        let vocab: [Int: String] = [0: "▁hello", 1: "▁world"]
        // Token 0 wins twice in a row: the repeat must collapse to one emission.
        let frames: [[Float]] = [
            [0.0, -100.0, -100.0],
            [0.0, -100.0, -100.0],
            [-100.0, 0.0, -100.0],
        ]
        XCTAssertEqual(
            ctcGreedyDecode(logProbs: frames, vocabulary: vocab, blankId: 2),
            "hello world")
    }

    func testGreedyDecodeBlankAllowsRepeats() {
        let vocab: [Int: String] = [0: "▁hello"]
        // An intervening blank separates two genuine emissions of token 0.
        let frames: [[Float]] = [
            [0.0, -100.0],
            [-100.0, 0.0],
            [0.0, -100.0],
        ]
        XCTAssertEqual(
            ctcGreedyDecode(logProbs: frames, vocabulary: vocab, blankId: 1),
            "hello hello")
    }

    func testGreedyDecodeAllBlanks() {
        let vocab: [Int: String] = [0: "▁hello"]
        let frames: [[Float]] = [
            [-100.0, 0.0],
            [-100.0, 0.0],
            [-100.0, 0.0],
        ]
        XCTAssertEqual(
            ctcGreedyDecode(logProbs: frames, vocabulary: vocab, blankId: 1), "")
    }

    func testGreedyDecodeEmptyInput() {
        let vocab: [Int: String] = [0: "▁hello"]
        XCTAssertEqual(ctcGreedyDecode(logProbs: [], vocabulary: vocab, blankId: 1), "")
    }

    // MARK: - CTC Greedy Decode (MLMultiArray)

    func testGreedyDecodeMLMultiArray() throws {
        let vocab: [Int: String] = [0: "▁hello", 1: "▁world"]
        let blankId = 2
        // Frame 0: token 0, frame 1: blank, frame 2: token 1.
        let array = try makeLogProbArray(
            timeSteps: 3, vocabSize: 3,
            hot: [(0, 0), (1, blankId), (2, 1)]
        )
        XCTAssertEqual(
            ctcGreedyDecode(logProbs: array, vocabulary: vocab, blankId: blankId),
            "hello world")
    }

    // MARK: - CTC Beam Search

    func testBeamSearchNoLMMatchesGreedy() {
        let vocab: [Int: String] = [0: "▁hello", 1: "▁world"]
        let blankId = 2
        let frames: [[Float]] = [
            [0.0, -100.0, -100.0],
            [-100.0, -100.0, 0.0],
            [-100.0, 0.0, -100.0],
        ]
        // With no LM and a dominant token per frame, beam search must agree
        // with the greedy path.
        let greedy = ctcGreedyDecode(logProbs: frames, vocabulary: vocab, blankId: blankId)
        let beam = ctcBeamSearch(
            logProbs: frames, vocabulary: vocab, lm: nil,
            beamWidth: 5, lmWeight: 0.0, blankId: blankId
        )
        XCTAssertEqual(greedy, beam)
    }

    func testBeamSearchAllBlanks() {
        let vocab: [Int: String] = [0: "▁hello"]
        let frames: [[Float]] = [
            [-100.0, 0.0],
            [-100.0, 0.0],
            [-100.0, 0.0],
        ]
        let decoded = ctcBeamSearch(
            logProbs: frames, vocabulary: vocab, lm: nil,
            beamWidth: 5, lmWeight: 0.0, blankId: 1
        )
        XCTAssertEqual(decoded, "")
    }

    func testBeamSearchEmptyInput() {
        let vocab: [Int: String] = [0: "▁hello"]
        let decoded = ctcBeamSearch(
            logProbs: [], vocabulary: vocab, lm: nil,
            beamWidth: 5, lmWeight: 0.0, blankId: 1
        )
        XCTAssertEqual(decoded, "")
    }

    func testBeamSearchSingleToken() {
        let vocab: [Int: String] = [0: "▁hello"]
        let frames: [[Float]] = [
            [0.0, -100.0]
        ]
        let decoded = ctcBeamSearch(
            logProbs: frames, vocabulary: vocab, lm: nil,
            beamWidth: 5, lmWeight: 0.0, blankId: 1
        )
        XCTAssertEqual(decoded, "hello")
    }

    // MARK: - CtcBeam

    func testCtcBeamTotalAcoustic() {
        let beam = CtcBeam(
            prefix: [1], pBlank: -1.0, pNonBlank: -2.0,
            lmScore: 0.0, wordPieces: [], prevWord: nil
        )
        // totalAcoustic log-sums the blank and non-blank paths.
        XCTAssertEqual(beam.totalAcoustic, logAddExp(-1.0, -2.0), accuracy: 1e-6)
    }

    func testCtcBeamTotalIncludesLM() {
        let beam = CtcBeam(
            prefix: [1], pBlank: 0.0, pNonBlank: -.infinity,
            lmScore: 5.0, wordPieces: [], prevWord: nil
        )
        // Acoustic part is 0 here, so total equals the LM score.
        XCTAssertEqual(beam.total, 5.0, accuracy: 1e-6)
    }

    func testCtcBeamLastToken() {
        let beam = CtcBeam(
            prefix: [1, 2, 3], pBlank: 0.0, pNonBlank: 0.0,
            lmScore: 0.0, wordPieces: [], prevWord: nil
        )
        XCTAssertEqual(beam.lastToken, 3)
    }

    func testCtcBeamLastTokenEmpty() {
        let beam = CtcBeam(
            prefix: [], pBlank: 0.0, pNonBlank: 0.0,
            lmScore: 0.0, wordPieces: [], prevWord: nil
        )
        XCTAssertNil(beam.lastToken)
    }

    // MARK: - CTC Beam Search (MLMultiArray)

    func testBeamSearchMLMultiArrayMatchesGreedy() throws {
        let vocab: [Int: String] = [0: "▁hello", 1: "▁world"]
        let blankId = 2
        let array = try makeLogProbArray(
            timeSteps: 3, vocabSize: 3,
            hot: [(0, 0), (1, blankId), (2, 1)]
        )
        let greedy = ctcGreedyDecode(logProbs: array, vocabulary: vocab, blankId: blankId)
        let beam = ctcBeamSearch(
            logProbs: array, vocabulary: vocab, lm: nil,
            beamWidth: 5, lmWeight: 0.0, blankId: blankId
        )
        XCTAssertEqual(greedy, beam)
    }
}