diff --git a/.github/workflows/kitten-tts-test.yml b/.github/workflows/kitten-tts-test.yml new file mode 100644 index 000000000..a47764ab5 --- /dev/null +++ b/.github/workflows/kitten-tts-test.yml @@ -0,0 +1,210 @@ +name: KittenTTS Smoke Test + +on: + pull_request: + branches: [main] + workflow_dispatch: + +jobs: + kitten-tts-smoke-test: + name: KittenTTS Smoke Test + runs-on: macos-15 + permissions: + contents: read + pull-requests: write + + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v5 + + - uses: swift-actions/setup-swift@v2 + with: + swift-version: "6.1" + + - name: Cache Dependencies + uses: actions/cache@v4 + with: + path: | + .build + ~/.cache/fluidaudio/Models/kokoro + ~/.cache/fluidaudio/Models/kittentts-coreml + ~/Library/Caches/Homebrew + key: ${{ runner.os }}-kitten-tts-${{ hashFiles('Package.resolved', 'Sources/FluidAudio/TTS/KittenTTS/**', 'Sources/FluidAudio/ModelNames.swift') }} + + - name: Build + run: swift build -c release + + - name: Run KittenTTS Nano Smoke Test + id: nano-test + run: | + echo "=========================================" + echo "KittenTTS Nano smoke test" + echo "=========================================" + echo "" + + TEXT="Hello world" + + if .build/release/fluidaudiocli tts "$TEXT" \ + --backend kitten-nano \ + --voice expr-voice-3-f \ + --output kitten_nano_output.wav 2>&1; then + echo "Nano smoke test PASSED" + echo "NANO_STATUS=PASSED" >> $GITHUB_OUTPUT + else + EXIT_CODE=$? 
+ echo "Nano smoke test FAILED with exit code $EXIT_CODE" + echo "NANO_STATUS=FAILED" >> $GITHUB_OUTPUT + fi + + if [ -f kitten_nano_output.wav ]; then + SIZE=$(stat -f%z kitten_nano_output.wav 2>/dev/null || stat -c%s kitten_nano_output.wav 2>/dev/null) + echo "Nano output file size: $SIZE bytes" + echo "NANO_FILE_SIZE=$SIZE" >> $GITHUB_OUTPUT + else + echo "NANO_FILE_SIZE=0" >> $GITHUB_OUTPUT + fi + + - name: Run KittenTTS Mini Smoke Test + id: mini-test + run: | + echo "=========================================" + echo "KittenTTS Mini smoke test" + echo "=========================================" + echo "" + + TEXT="The quick brown fox jumps over the lazy dog." + + if .build/release/fluidaudiocli tts "$TEXT" \ + --backend kitten-mini \ + --voice expr-voice-3-f \ + --speed 1.0 \ + --output kitten_mini_output.wav 2>&1; then + echo "Mini smoke test PASSED" + echo "MINI_STATUS=PASSED" >> $GITHUB_OUTPUT + else + EXIT_CODE=$? + echo "Mini smoke test FAILED with exit code $EXIT_CODE" + echo "MINI_STATUS=FAILED" >> $GITHUB_OUTPUT + fi + + if [ -f kitten_mini_output.wav ]; then + SIZE=$(stat -f%z kitten_mini_output.wav 2>/dev/null || stat -c%s kitten_mini_output.wav 2>/dev/null) + echo "Mini output file size: $SIZE bytes" + echo "MINI_FILE_SIZE=$SIZE" >> $GITHUB_OUTPUT + else + echo "MINI_FILE_SIZE=0" >> $GITHUB_OUTPUT + fi + + - name: Verify Lexicon Cache Downloaded + id: lexicon-check + run: | + LEXICON_PATH="$HOME/.cache/fluidaudio/Models/kokoro/us_lexicon_cache.json" + if [ -f "$LEXICON_PATH" ]; then + SIZE=$(stat -f%z "$LEXICON_PATH" 2>/dev/null || stat -c%s "$LEXICON_PATH" 2>/dev/null) + echo "✅ Lexicon cache downloaded: $SIZE bytes" + echo "LEXICON_STATUS=DOWNLOADED" >> $GITHUB_OUTPUT + echo "LEXICON_SIZE=$SIZE" >> $GITHUB_OUTPUT + else + echo "❌ Lexicon cache NOT found at $LEXICON_PATH" + echo "LEXICON_STATUS=MISSING" >> $GITHUB_OUTPUT + echo "LEXICON_SIZE=0" >> $GITHUB_OUTPUT + fi + + - name: Comment PR + if: github.event_name == 'pull_request' + 
continue-on-error: true + uses: actions/github-script@v7 + with: + script: | + const nanoStatus = '${{ steps.nano-test.outputs.NANO_STATUS }}'; + const miniStatus = '${{ steps.mini-test.outputs.MINI_STATUS }}'; + const lexiconStatus = '${{ steps.lexicon-check.outputs.LEXICON_STATUS }}'; + + const nanoEmoji = nanoStatus === 'PASSED' ? '✅' : '❌'; + const miniEmoji = miniStatus === 'PASSED' ? '✅' : '❌'; + const lexiconEmoji = lexiconStatus === 'DOWNLOADED' ? '✅' : '❌'; + + const nanoFileSize = '${{ steps.nano-test.outputs.NANO_FILE_SIZE }}'; + const miniFileSize = '${{ steps.mini-test.outputs.MINI_FILE_SIZE }}'; + const lexiconSize = '${{ steps.lexicon-check.outputs.LEXICON_SIZE }}'; + + const nanoSizeKB = (parseInt(nanoFileSize) / 1024).toFixed(1); + const miniSizeKB = (parseInt(miniFileSize) / 1024).toFixed(1); + const lexiconSizeMB = (parseInt(lexiconSize) / 1024 / 1024).toFixed(1); + + const body = `## KittenTTS Smoke Test + + ### Test Results + + | Variant | Status | Output Size | + |---------|--------|-------------| + | **Nano** (15M) | ${nanoEmoji} | ${parseInt(nanoFileSize) > 0 ? nanoSizeKB + ' KB' : 'N/A'} | + | **Mini** (82M) | ${miniEmoji} | ${parseInt(miniFileSize) > 0 ? miniSizeKB + ' KB' : 'N/A'} | + + ### Dependencies + + | Component | Status | Size | + |-----------|--------|------| + | Build | ✅ | - | + | Lexicon cache (us_lexicon_cache.json) | ${lexiconEmoji} | ${parseInt(lexiconSize) > 0 ? lexiconSizeMB + ' MB' : 'N/A'} | + | Kokoro G2P pipeline | ${nanoStatus === 'PASSED' || miniStatus === 'PASSED' ? '✅' : '❌'} | - | + + **Note:** KittenTTS reuses Kokoro's G2P pipeline for phonemization. This test verifies the lexicon cache auto-downloads correctly and both Nano/Mini variants can synthesize audio. 
+
+            <!-- kitten-tts-smoke-test -->`;
+
+            // Hidden marker embedded at the end of the comment body above. It is how
+            // this workflow recognises its own comment on later runs, so that it
+            // updates that comment instead of touching an unrelated one.
+            const marker = '<!-- kitten-tts-smoke-test -->';
+
+            // Paginate so the marker is still found on PRs with many comments.
+            const comments = await github.paginate(github.rest.issues.listComments, {
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              issue_number: context.issue.number,
+            });
+
+            // Guard against comments without a body before searching for the marker.
+            const existing = comments.find(c => c.body && c.body.includes(marker));
+
+            if (existing) {
+              await github.rest.issues.updateComment({
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                comment_id: existing.id,
+                body: body
+              });
+            } else {
+              await github.rest.issues.createComment({
+                owner: context.repo.owner,
+                repo: context.repo.repo,
+                issue_number: context.issue.number,
+                body: body
+              });
+            }
+
+      - name: Upload Nano Output
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: kitten-nano-output
+          path: kitten_nano_output.wav
+          retention-days: 7
+
+      - name: Upload Mini Output
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: kitten-mini-output
+          path: kitten_mini_output.wav
+          retention-days: 7
+
+      - name: Fail if Tests Failed
+        run: |
+          # The smoke-test steps record PASSED/FAILED in their outputs rather than
+          # exiting non-zero, so the job has to be failed explicitly here.
+          NANO_STATUS="${{ steps.nano-test.outputs.NANO_STATUS }}"
+          MINI_STATUS="${{ steps.mini-test.outputs.MINI_STATUS }}"
+          LEXICON_STATUS="${{ steps.lexicon-check.outputs.LEXICON_STATUS }}"
+
+          if [ "$NANO_STATUS" != "PASSED" ] || [ "$MINI_STATUS" != "PASSED" ] || [ "$LEXICON_STATUS" != "DOWNLOADED" ]; then
+            echo "❌ One or more tests failed"
+            exit 1
+          fi
+
+          echo "✅ All tests passed"
diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift
index 05160cbf7..fdf60aa22 100644
--- a/Sources/FluidAudio/ModelNames.swift
+++ b/Sources/FluidAudio/ModelNames.swift
@@ -17,6 +17,8 @@ public enum Repo: String, CaseIterable {
     case pocketTts = "FluidInference/pocket-tts-coreml"
     case qwen3Asr = "FluidInference/qwen3-asr-0.6b-coreml/f32"
     case qwen3AsrInt8 = "FluidInference/qwen3-asr-0.6b-coreml/int8"
+    case kittenTtsNano = "alexwengg/kittentts-coreml/nano"
+    case kittenTtsMini = "alexwengg/kittentts-coreml/mini"
 
     /// Repository slug (without owner)
     public var name: String {
@@ -51,6 +53,10 @@ public enum Repo: 
String, CaseIterable { return "qwen3-asr-0.6b-coreml/f32" case .qwen3AsrInt8: return "qwen3-asr-0.6b-coreml/int8" + case .kittenTtsNano: + return "kittentts-coreml/nano" + case .kittenTtsMini: + return "kittentts-coreml/mini" } } @@ -69,6 +75,8 @@ public enum Repo: String, CaseIterable { return "FluidInference/ls-eend-coreml" case .qwen3Asr, .qwen3AsrInt8: return "FluidInference/qwen3-asr-0.6b-coreml" + case .kittenTtsNano, .kittenTtsMini: + return "alexwengg/kittentts-coreml" default: return "FluidInference/\(name)" } @@ -87,6 +95,10 @@ public enum Repo: String, CaseIterable { return "f32" case .qwen3AsrInt8: return "int8" + case .kittenTtsNano: + return "nano" + case .kittenTtsMini: + return "mini" default: return nil } @@ -109,6 +121,10 @@ public enum Repo: String, CaseIterable { return "ls-eend" case .pocketTts: return "pocket-tts" + case .kittenTtsNano: + return "kittentts-coreml/nano" + case .kittenTtsMini: + return "kittentts-coreml/mini" default: return name } @@ -454,6 +470,77 @@ public enum ModelNames { ] } + /// KittenTTS model names (Nano 15M / Mini 80M StyleTTS2-based TTS) + public enum KittenTTS { + + /// KittenTTS model duration variants. + public enum Variant: CaseIterable, Sendable { + /// 5-second model (70 max tokens). + case fiveSecond + /// 10-second model (140 max tokens). + case tenSecond + + /// Nano model bundle filename for this variant. + public func nanoFileName() -> String { + switch self { + case .fiveSecond: + return "kittentts_5s.mlmodelc" + case .tenSecond: + return "kittentts_10s.mlmodelc" + } + } + + /// Mini model bundle filename for this variant. + public func miniFileName() -> String { + switch self { + case .fiveSecond: + return "kittentts_mini_5s.mlmodelc" + case .tenSecond: + return "kittentts_mini_10s.mlmodelc" + } + } + + /// Maximum number of phoneme tokens for this variant. 
+            public var maxTokens: Int {
+                switch self {
+                case .fiveSecond:
+                    return 70
+                case .tenSecond:
+                    return 140
+                }
+            }
+        }
+
+        /// Preferred variant for general-purpose synthesis.
+        public static let defaultVariant: Variant = .tenSecond
+
+        /// Voice embeddings directory name (relative to the repo folder).
+        public static let voicesDir = "voices"
+
+        /// Available voice identifiers.
+        public static let availableVoices: [String] = [
+            "expr-voice-2-m", "expr-voice-2-f",
+            "expr-voice-3-m", "expr-voice-3-f",
+            "expr-voice-4-m", "expr-voice-4-f",
+            "expr-voice-5-m", "expr-voice-5-f",
+        ]
+
+        /// Default voice for synthesis (same value as `KittenTtsConstants.defaultVoice`).
+        public static let defaultVoice = "expr-voice-3-f"
+
+        /// All Nano entries required by the downloader: every variant's model bundle plus the voices directory.
+        public static var nanoRequiredModels: Set<String> {
+            Set(Variant.allCases.map { $0.nanoFileName() })
+                .union([voicesDir])
+        }
+
+        /// All Mini entries required by the downloader: every variant's model bundle plus the voices directory.
+        public static var miniRequiredModels: Set<String> {
+            Set(Variant.allCases.map { $0.miniFileName() })
+                .union([voicesDir])
+        }
+    }
+
     /// TTS model names
     public enum TTS {
 
@@ -540,6 +627,10 @@ public enum ModelNames {
             return ModelNames.LSEEND.requiredModels
         case .qwen3Asr, .qwen3AsrInt8:
             return ModelNames.Qwen3ASR.requiredModelsFull
+        case .kittenTtsNano:
+            return ModelNames.KittenTTS.nanoRequiredModels
+        case .kittenTtsMini:
+            return ModelNames.KittenTTS.miniRequiredModels
         }
     }
 }
diff --git a/Sources/FluidAudio/TTS/KittenTTS/KittenTTSError.swift b/Sources/FluidAudio/TTS/KittenTTS/KittenTTSError.swift
new file mode 100644
index 000000000..1e1c0785d
--- /dev/null
+++ b/Sources/FluidAudio/TTS/KittenTTS/KittenTTSError.swift
@@ -0,0 +1,22 @@
+import Foundation
+
+/// Errors that can occur during KittenTTS synthesis.
+public enum KittenTTSError: LocalizedError {
+    case downloadFailed(String)
+    case corruptedModel(String)
+    case modelNotFound(String)
+    case processingFailed(String)
+
+    public var errorDescription: String? 
{ + switch self { + case .downloadFailed(let message): + return "Download failed: \(message)" + case .corruptedModel(let name): + return "Model \(name) is corrupted" + case .modelNotFound(let name): + return "Model \(name) not found" + case .processingFailed(let message): + return "Processing failed: \(message)" + } + } +} diff --git a/Sources/FluidAudio/TTS/KittenTTS/KittenTtsConstants.swift b/Sources/FluidAudio/TTS/KittenTTS/KittenTtsConstants.swift new file mode 100644 index 000000000..f40d14923 --- /dev/null +++ b/Sources/FluidAudio/TTS/KittenTTS/KittenTtsConstants.swift @@ -0,0 +1,62 @@ +import Foundation + +/// Constants for the KittenTTS StyleTTS2-based TTS backend. +public enum KittenTtsConstants { + + // MARK: - Audio + + /// Output sample rate in Hz. + public static let audioSampleRate: Int = 24_000 + + // MARK: - Vocabulary + + /// The 178-token IPA vocabulary as Unicode scalars. + /// Index 0 (`$`) is the BOS/EOS/padding token. + /// Each scalar's position in this array is its token ID. + /// + /// Note: stored as `[Unicode.Scalar]` rather than `String` because + /// U+0329 (COMBINING VERTICAL LINE BELOW) at index 175 merges with + /// the preceding U+2018 into a single Swift `Character`, making + /// `String.count` return 177 instead of 178. + // swiftlint:disable:next line_length + public static let vocabScalars: [Unicode.Scalar] = Array( + "$;:,.!?¡¿—…\"«»\u{201C}\u{201D} ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘\u{2018}\u{0329}\u{2019}ᵻ" + .unicodeScalars) + + /// Vocabulary size (number of tokens including padding). + public static let vocabSize: Int = 178 + + /// BOS/EOS/padding token ID. + public static let padTokenId: Int32 = 0 + + // MARK: - Model dimensions + + /// Nano voice embedding dimension (single 256-float vector per voice). 
+ public static let nanoVoiceDim: Int = 256 + + /// Mini voice matrix rows (one row per token count, 0-399). + public static let miniVoiceRows: Int = 400 + + /// Mini voice embedding dimension per row. + public static let miniVoiceDim: Int = 256 + + /// Number of harmonic channels for Nano source noise and random phases. + public static let nanoHarmonics: Int = 9 + + // MARK: - Nano model sizes + + /// Maximum audio samples for 5-second Nano model. + public static let nano5sMaxSamples: Int = 120_000 + + /// Maximum audio samples for 10-second Nano model. + public static let nano10sMaxSamples: Int = 240_000 + + // MARK: - Voices + + /// Default voice identifier. + public static let defaultVoice: String = "expr-voice-3-f" + + // MARK: - Repository + + public static let defaultModelsSubdirectory: String = "Models" +} diff --git a/Sources/FluidAudio/TTS/KittenTTS/KittenTtsManager.swift b/Sources/FluidAudio/TTS/KittenTTS/KittenTtsManager.swift new file mode 100644 index 000000000..52da6af52 --- /dev/null +++ b/Sources/FluidAudio/TTS/KittenTTS/KittenTtsManager.swift @@ -0,0 +1,136 @@ +import Foundation +import OSLog + +/// Manages text-to-speech synthesis using KittenTTS CoreML models. +/// +/// KittenTTS is a single-shot StyleTTS2-based synthesizer that produces +/// complete utterances in one forward pass at 24kHz. Two variants are available: +/// - **Nano** (15M params): Lightweight, no speed control +/// - **Mini** (80M params): Higher quality, speed control +/// +/// Example usage: +/// ```swift +/// let manager = KittenTtsManager(variant: .mini) +/// try await manager.initialize() +/// let audioData = try await manager.synthesize(text: "Hello, world!") +/// ``` +public actor KittenTtsManager { + + private let logger = AppLogger(category: "KittenTtsManager") + private let modelStore: KittenTtsModelStore + private var defaultVoice: String + private var isInitialized = false + + /// Creates a new KittenTTS manager. 
+    ///
+    /// - Parameters:
+    ///   - variant: Model variant to use (.nano or .mini).
+    ///   - defaultVoice: Default voice identifier.
+    ///   - directory: Optional override for the base cache directory.
+    ///     When `nil`, uses the default platform cache location.
+    public init(
+        variant: KittenTtsVariant = .mini,
+        defaultVoice: String = KittenTtsConstants.defaultVoice,
+        directory: URL? = nil
+    ) {
+        self.modelStore = KittenTtsModelStore(variant: variant, directory: directory)
+        self.defaultVoice = defaultVoice
+    }
+
+    /// `true` once `initialize()` has completed and `cleanup()` has not been called since.
+    public var isAvailable: Bool {
+        isInitialized
+    }
+
+    /// Initialize by downloading and loading KittenTTS models.
+    ///
+    /// - Throws: Any error raised by the model store while downloading or loading.
+    public func initialize() async throws {
+        try await modelStore.loadIfNeeded()
+        isInitialized = true
+        logger.notice("KittenTtsManager initialized")
+    }
+
+    /// Synthesize text to WAV audio data.
+    ///
+    /// - Parameters:
+    ///   - text: The text to synthesize.
+    ///   - voice: Voice identifier (default: uses the manager's default voice).
+    ///   - speed: Speech speed multiplier (Mini only, 1.0 = normal).
+    ///   - deEss: Whether to apply de-essing post-processing (default: true).
+    /// - Returns: WAV audio data at 24kHz.
+    /// - Throws: `KittenTTSError.modelNotFound` if the manager is not initialized,
+    ///   or any error raised during synthesis.
+    public func synthesize(
+        text: String,
+        voice: String? = nil,
+        speed: Float = 1.0,
+        deEss: Bool = true
+    ) async throws -> Data {
+        // Single code path: the detailed variant performs the initialization check,
+        // voice selection and synthesis; only the encoded WAV is returned here.
+        let result = try await synthesizeDetailed(
+            text: text,
+            voice: voice,
+            speed: speed,
+            deEss: deEss
+        )
+        return result.audio
+    }
+
+    /// Synthesize text and return detailed results.
+    ///
+    /// - Parameters:
+    ///   - text: The text to synthesize.
+    ///   - voice: Voice identifier; `nil` selects the manager's default voice.
+    ///   - speed: Speech speed multiplier (Mini only, 1.0 = normal).
+    ///   - deEss: Whether to apply de-essing post-processing (default: true).
+    /// - Returns: The synthesis result (WAV data, raw samples and sample count).
+    /// - Throws: `KittenTTSError.modelNotFound` if the manager is not initialized,
+    ///   or any error raised during synthesis.
+    public func synthesizeDetailed(
+        text: String,
+        voice: String? = nil,
+        speed: Float = 1.0,
+        deEss: Bool = true
+    ) async throws -> KittenTtsSynthesizer.SynthesisResult {
+        guard isInitialized else {
+            throw KittenTTSError.modelNotFound("KittenTTS model not initialized")
+        }
+
+        let selectedVoice = voice ?? defaultVoice
+
+        return try await KittenTtsSynthesizer.withModelStore(modelStore) {
+            try await KittenTtsSynthesizer.synthesize(
+                text: text,
+                voice: selectedVoice,
+                speed: speed,
+                deEss: deEss
+            )
+        }
+    }
+
+    /// Synthesize text and write the result directly to a file.
+    ///
+    /// An existing file at `outputURL` is replaced only after synthesis has
+    /// succeeded; if synthesis throws, the existing file is left untouched.
+    ///
+    /// - Parameters:
+    ///   - text: The text to synthesize.
+    ///   - outputURL: Destination file URL for the WAV data.
+    ///   - voice: Voice identifier; `nil` selects the manager's default voice.
+    ///   - speed: Speech speed multiplier (Mini only, 1.0 = normal).
+    ///   - deEss: Whether to apply de-essing post-processing (default: true).
+    /// - Throws: Any synthesis error, or the file-system error from writing.
+    public func synthesizeToFile(
+        text: String,
+        outputURL: URL,
+        voice: String? = nil,
+        speed: Float = 1.0,
+        deEss: Bool = true
+    ) async throws {
+        let audioData = try await synthesize(
+            text: text,
+            voice: voice,
+            speed: speed,
+            deEss: deEss
+        )
+
+        // Atomic write: the data goes to a temporary file that then replaces the
+        // destination, so a previous output is never deleted ahead of a failure.
+        try audioData.write(to: outputURL, options: .atomic)
+        logger.notice("Saved synthesized audio to: \(outputURL.lastPathComponent)")
+    }
+
+    /// Update the default voice.
+    public func setDefaultVoice(_ voice: String) {
+        defaultVoice = voice
+    }
+
+    /// Mark the manager as uninitialized. Only the flag is reset here; the model
+    /// store instance is kept, and `initialize()` must be called again before use.
+    public func cleanup() {
+        isInitialized = false
+    }
+}
diff --git a/Sources/FluidAudio/TTS/KittenTTS/Pipeline/KittenTtsModelStore.swift b/Sources/FluidAudio/TTS/KittenTTS/Pipeline/KittenTtsModelStore.swift
new file mode 100644
index 000000000..6cc175213
--- /dev/null
+++ b/Sources/FluidAudio/TTS/KittenTTS/Pipeline/KittenTtsModelStore.swift
@@ -0,0 +1,187 @@
+@preconcurrency import CoreML
+import Foundation
+import OSLog
+
+/// Actor-based store for KittenTTS CoreML models and voice embeddings.
+///
+/// Manages loading and caching of the CoreML model (5s or 10s variant)
+/// and the voice embedding data (binary float32 files).
+public actor KittenTtsModelStore {
+
+    private let logger = AppLogger(subsystem: "com.fluidaudio.tts", category: "KittenTtsModelStore")
+
+    private let kittenVariant: KittenTtsVariant
+    private var model5s: MLModel? 
+ private var model10s: MLModel? + private var voiceCache: [String: [Float]] = [:] + private var repoDirectory: URL? + private let directory: URL? + + /// - Parameters: + /// - variant: Which KittenTTS variant to use (nano or mini). + /// - directory: Optional override for the base cache directory. + /// When `nil`, uses the default platform cache location. + public init(variant: KittenTtsVariant, directory: URL? = nil) { + self.kittenVariant = variant + self.directory = directory + } + + /// The KittenTTS variant this store manages. + public var variant: KittenTtsVariant { kittenVariant } + + /// Load all models and voices from cache, downloading if needed. + public func loadIfNeeded() async throws { + guard model10s == nil else { return } + + let targetDir = try directory ?? cacheDirectory() + let modelsDirectory = targetDir.appendingPathComponent( + KittenTtsConstants.defaultModelsSubdirectory) + + let repo: Repo = kittenVariant == .nano ? .kittenTtsNano : .kittenTtsMini + let repoDir = modelsDirectory.appendingPathComponent(repo.folderName) + + let requiredModels = ModelNames.getRequiredModelNames(for: repo, variant: nil) + let allPresent = requiredModels.allSatisfy { model in + FileManager.default.fileExists( + atPath: repoDir.appendingPathComponent(model).path) + } + + if !allPresent { + logger.info("Downloading KittenTTS \(self.kittenVariant.rawValue) models from HuggingFace...") + try await DownloadUtils.downloadRepo(repo, to: modelsDirectory) + } else { + logger.info("KittenTTS \(self.kittenVariant.rawValue) models found in cache") + } + + self.repoDirectory = repoDir + + // Use CPU+GPU to maintain float32 precision (avoid ANE float16 artifacts). + let config = MLModelConfiguration() + config.computeUnits = .cpuAndGPU + + let loadStart = Date() + + // Load both 5s and 10s models + let variant5s = ModelNames.KittenTTS.Variant.fiveSecond + let variant10s = ModelNames.KittenTTS.Variant.tenSecond + let fileName5s = + kittenVariant == .nano ? 
variant5s.nanoFileName() : variant5s.miniFileName() + let fileName10s = + kittenVariant == .nano ? variant10s.nanoFileName() : variant10s.miniFileName() + + let modelURL5s = repoDir.appendingPathComponent(fileName5s) + let modelURL10s = repoDir.appendingPathComponent(fileName10s) + + model5s = try MLModel(contentsOf: modelURL5s, configuration: config) + logger.info("Loaded \(fileName5s)") + + model10s = try MLModel(contentsOf: modelURL10s, configuration: config) + logger.info("Loaded \(fileName10s)") + + let elapsed = Date().timeIntervalSince(loadStart) + logger.info( + "KittenTTS \(self.kittenVariant.rawValue) models loaded in \(String(format: "%.2f", elapsed))s" + ) + } + + /// Get the 5-second model. + public func fiveSecondModel() throws -> MLModel { + guard let model = model5s else { + throw KittenTTSError.modelNotFound("KittenTTS 5s model not loaded") + } + return model + } + + /// Get the 10-second model. + public func tenSecondModel() throws -> MLModel { + guard let model = model10s else { + throw KittenTTSError.modelNotFound("KittenTTS 10s model not loaded") + } + return model + } + + /// Select the appropriate model based on token count. + public func model(for tokenCount: Int) throws -> (MLModel, ModelNames.KittenTTS.Variant) { + let variant: ModelNames.KittenTTS.Variant = + tokenCount <= ModelNames.KittenTTS.Variant.fiveSecond.maxTokens + ? .fiveSecond : .tenSecond + let model = + variant == .fiveSecond + ? try fiveSecondModel() + : try tenSecondModel() + return (model, variant) + } + + /// Load and cache voice embedding data for the given voice name. 
+    /// - Parameter voice: Voice identifier; resolved to `<voices dir>/<voice>.bin` inside the loaded repo directory.
+    /// - Returns: The voice embedding as a flat `[Float]` (256 values for Nano, 400 × 256 for Mini).
+    /// - Throws: `KittenTTSError.modelNotFound` if the repo is not loaded or the voice file is missing;
+    ///   `KittenTTSError.corruptedModel` if the file size does not match the variant.
+    public func voiceData(for voice: String) throws -> [Float] {
+        if let cached = voiceCache[voice] {
+            return cached
+        }
+        guard let repoDir = repoDirectory else {
+            throw KittenTTSError.modelNotFound("KittenTTS repository not loaded")
+        }
+
+        let voicesDir = repoDir.appendingPathComponent(ModelNames.KittenTTS.voicesDir)
+        let voiceFile = voicesDir.appendingPathComponent("\(voice).bin")
+
+        guard FileManager.default.fileExists(atPath: voiceFile.path) else {
+            throw KittenTTSError.modelNotFound(
+                "Voice '\(voice)' not found at \(voiceFile.path)")
+        }
+
+        let data = try Data(contentsOf: voiceFile)
+
+        let floatSize = MemoryLayout<Float>.size
+        let expectedSize: Int
+        if kittenVariant == .nano {
+            // Nano: 256 floats = 1024 bytes
+            expectedSize = KittenTtsConstants.nanoVoiceDim * floatSize
+        } else {
+            // Mini: 400 × 256 floats = 409600 bytes
+            expectedSize =
+                KittenTtsConstants.miniVoiceRows * KittenTtsConstants.miniVoiceDim
+                * floatSize
+        }
+
+        guard data.count == expectedSize else {
+            throw KittenTTSError.corruptedModel(
+                "Voice '\(voice)' has unexpected size \(data.count) bytes (expected \(expectedSize))"
+            )
+        }
+
+        // Raw byte copy into Float storage.
+        // NOTE(review): this assumes the .bin files use the host's float32 byte order — confirm.
+        let floatCount = data.count / floatSize
+        var floats = [Float](repeating: 0, count: floatCount)
+        _ = floats.withUnsafeMutableBytes { buffer in
+            data.copyBytes(to: buffer)
+        }
+
+        voiceCache[voice] = floats
+        logger.info("Loaded voice '\(voice)' (\(floatCount) floats)")
+        return floats
+    }
+
+    // MARK: - Private
+
+    /// Resolve (and create if missing) the base `fluidaudio` cache directory:
+    /// `~/.cache/fluidaudio` on macOS, `<user caches dir>/fluidaudio` elsewhere.
+    private func cacheDirectory() throws -> URL {
+        let baseDirectory: URL
+        #if os(macOS)
+        baseDirectory = FileManager.default.homeDirectoryForCurrentUser
+            .appendingPathComponent(".cache")
+        #else
+        guard
+            let first = FileManager.default.urls(
+                for: .cachesDirectory, in: .userDomainMask
+            ).first
+        else {
+            throw KittenTTSError.processingFailed("Failed to locate caches directory")
+        }
+        baseDirectory = first
+        #endif
+
+        let cacheDirectory = baseDirectory.appendingPathComponent("fluidaudio")
+        if !FileManager.default.fileExists(atPath: cacheDirectory.path) {
+            try FileManager.default.createDirectory(
+                at: cacheDirectory, withIntermediateDirectories: true)
+        }
+        return cacheDirectory
+    }
+}
diff --git a/Sources/FluidAudio/TTS/KittenTTS/Pipeline/KittenTtsSynthesizer.swift b/Sources/FluidAudio/TTS/KittenTTS/Pipeline/KittenTtsSynthesizer.swift
new file mode 100644
index 000000000..f44290b1a
--- /dev/null
+++ b/Sources/FluidAudio/TTS/KittenTTS/Pipeline/KittenTtsSynthesizer.swift
@@ -0,0 +1,375 @@
+@preconcurrency import CoreML
+import Foundation
+import OSLog
+
+/// KittenTTS single-shot synthesizer.
+///
+/// Handles phonemization (via Kokoro's G2P pipeline), tokenization to
+/// KittenTTS vocab indices, CoreML inference, and audio extraction.
+///
+/// Pipeline: text → phonemes (Kokoro G2P) → KittenTTS tokens → CoreML → audio → WAV
+public struct KittenTtsSynthesizer {
+
+    static let logger = AppLogger(category: "KittenTtsSynthesizer")
+
+    /// Task-local holder for the model store used by the static synthesis entry points.
+    private enum Context {
+        @TaskLocal static var modelStore: KittenTtsModelStore?
+    }
+
+    /// Run `operation` with `store` bound as the task-local model store.
+    ///
+    /// - Parameters:
+    ///   - store: Model store made visible to `currentModelStore()` inside `operation`.
+    ///   - operation: Work to perform while the store is bound.
+    /// - Returns: Whatever `operation` returns; errors from `operation` are rethrown.
+    static func withModelStore<T>(
+        _ store: KittenTtsModelStore,
+        operation: () async throws -> T
+    ) async rethrows -> T {
+        try await Context.$modelStore.withValue(store) {
+            try await operation()
+        }
+    }
+
+    /// Return the task-local model store.
+    ///
+    /// - Throws: `KittenTTSError.processingFailed` when called outside `withModelStore`.
+    static func currentModelStore() throws -> KittenTtsModelStore {
+        guard let store = Context.modelStore else {
+            throw KittenTTSError.processingFailed(
+                "KittenTtsSynthesizer requires a model store context.")
+        }
+        return store
+    }
+
+    // MARK: - Public Result Type
+
+    /// Result of a KittenTTS synthesis operation.
+    public struct SynthesisResult: Sendable {
+        /// WAV audio data at 24kHz.
+        public let audio: Data
+        /// Raw Float32 audio samples.
+        public let samples: [Float]
+        /// Number of valid audio samples.
+        public let sampleCount: Int
+    }
+
+    // MARK: - Vocabulary Mapping
+
+    /// Pre-built scalar-to-index map for the KittenTTS vocabulary.
+ /// + /// Uses `Unicode.Scalar` keys rather than `Character` because the + /// vocab contains U+0329 (COMBINING VERTICAL LINE BELOW), which + /// Swift merges with the preceding scalar when viewed as Characters. + private static let scalarToIndex: [Unicode.Scalar: Int32] = { + var map: [Unicode.Scalar: Int32] = [:] + for (index, scalar) in KittenTtsConstants.vocabScalars.enumerated() { + map[scalar] = Int32(index) + } + return map + }() + + /// Convert IPA phoneme strings to KittenTTS token IDs. + /// + /// Each Unicode scalar in each phoneme string is individually mapped + /// to its position in the KittenTTS vocabulary. Scalars not in the + /// vocabulary are dropped. The result is wrapped with BOS (0) and EOS (0) tokens. + /// + /// - Parameter ipaPhonemes: Array of IPA phoneme strings from G2P. + /// - Returns: Array of Int32 token IDs including BOS and EOS. + public static func tokenize(_ ipaPhonemes: [String]) -> [Int32] { + var ids: [Int32] = [KittenTtsConstants.padTokenId] // BOS + + for phoneme in ipaPhonemes { + for scalar in phoneme.unicodeScalars { + guard let id = scalarToIndex[scalar] else { continue } + guard id != KittenTtsConstants.padTokenId else { continue } + ids.append(id) + } + } + + ids.append(KittenTtsConstants.padTokenId) // EOS + return ids + } + + // MARK: - Synthesis + + /// Synthesize text to WAV audio data. + /// + /// - Parameters: + /// - text: The text to synthesize. + /// - voice: Voice identifier (e.g., "expr-voice-3-f"). + /// - speed: Speech speed multiplier (Mini only, 1.0 = normal). + /// - deEss: Whether to apply de-essing post-processing. + /// - Returns: A synthesis result containing WAV audio data. + public static func synthesize( + text: String, + voice: String = KittenTtsConstants.defaultVoice, + speed: Float = 1.0, + deEss: Bool = true + ) async throws -> SynthesisResult { + let store = try currentModelStore() + + logger.info("KittenTTS synthesizing: '\(text)'") + + // 1. 
Phonemize using Kokoro's G2P pipeline + let phonemes = try await phonemize(text: text) + logger.info("Phonemized to \(phonemes.count) IPA tokens") + + // 2. Tokenize to KittenTTS vocab indices + let tokenIds = tokenize(phonemes) + let realTokenCount = tokenIds.count + logger.info("Tokenized to \(realTokenCount) token IDs") + + // 3. Select appropriate model based on token count + let (model, modelVariant) = try await store.model(for: realTokenCount) + let maxTokens = modelVariant.maxTokens + logger.info("Using \(modelVariant == .fiveSecond ? "5s" : "10s") model (max \(maxTokens) tokens)") + + // 4. Load voice embedding + let voiceFloats = try await store.voiceData(for: voice) + let variant = await store.variant + + // 5. Run inference + let inferenceStart = Date() + let output: MLFeatureProvider + if variant == .nano { + output = try runNanoInference( + model: model, + tokenIds: tokenIds, + maxTokens: maxTokens, + voiceFloats: voiceFloats, + modelVariant: modelVariant + ) + } else { + output = try runMiniInference( + model: model, + tokenIds: tokenIds, + maxTokens: maxTokens, + voiceFloats: voiceFloats, + realTokenCount: realTokenCount, + speed: speed + ) + } + let inferenceElapsed = Date().timeIntervalSince(inferenceStart) + logger.info("Inference completed in \(String(format: "%.2f", inferenceElapsed))s") + + // 6. Extract audio + var samples = try extractAudio(from: output) + logger.info("Extracted \(samples.count) audio samples") + + // 7. Post-processing + if deEss { + AudioPostProcessor.applyTtsPostProcessing( + &samples, + sampleRate: Float(KittenTtsConstants.audioSampleRate), + deEssAmount: -3.0, + smoothing: false + ) + } + + // 8. 
Encode WAV + let audioData = try AudioWAV.data( + from: samples, + sampleRate: Double(KittenTtsConstants.audioSampleRate) + ) + + let duration = Double(samples.count) / Double(KittenTtsConstants.audioSampleRate) + logger.info("Audio duration: \(String(format: "%.2f", duration))s") + + return SynthesisResult( + audio: audioData, + samples: samples, + sampleCount: samples.count + ) + } + + // MARK: - Phonemization + + /// Phonemize text using Kokoro's G2P pipeline. + /// + /// Reuses the existing lexicon and G2P model infrastructure. + private static func phonemize(text: String) async throws -> [String] { + // Load the Kokoro lexicon/G2P models if not already loaded + try await KokoroSynthesizer.loadSimplePhonemeDictionary() + let lexicons = await KokoroSynthesizer.lexiconCache.lexicons() + let vocabulary = try await KokoroVocabulary.shared.getVocabulary() + let allowedPhonemes = Set(vocabulary.keys) + + // Chunk the text into phonemes using Kokoro's pipeline + // KittenTTS models support max 70 tokens (5s) or 140 tokens (10s) + // Use conservative 70 token limit to fit all variants + let chunks = try await KokoroChunker.chunk( + text: text, + wordToPhonemes: lexicons.word, + caseSensitiveLexicon: lexicons.caseSensitive, + customLexicon: nil, + targetTokens: 70, + hasLanguageToken: false, + allowedPhonemes: allowedPhonemes, + phoneticOverrides: [], + multilingualLanguage: nil + ) + + // Flatten all chunk phonemes + var allPhonemes: [String] = [] + for chunk in chunks { + allPhonemes.append(contentsOf: chunk.phonemes) + } + return allPhonemes + } + + // MARK: - Nano Inference + + /// Run KittenTTS Nano inference. + /// + /// Nano inputs: input_ids, attention_mask, ref_s, random_phases, source_noise + private static func runNanoInference( + model: MLModel, + tokenIds: [Int32], + maxTokens: Int, + voiceFloats: [Float], + modelVariant: ModelNames.KittenTTS.Variant + ) throws -> MLFeatureProvider { + let n = maxTokens + let t = + modelVariant == .fiveSecond + ? 
KittenTtsConstants.nano5sMaxSamples + : KittenTtsConstants.nano10sMaxSamples + let harmonics = KittenTtsConstants.nanoHarmonics + + // input_ids [1, N] + let inputIds = try MLMultiArray(shape: [1, NSNumber(value: n)], dataType: .int32) + let inputIdsPtr = inputIds.dataPointer.bindMemory(to: Int32.self, capacity: n) + for i in 0.. MLFeatureProvider { + let n = maxTokens + let dim = KittenTtsConstants.miniVoiceDim + + // input_ids [1, N] + let inputIds = try MLMultiArray(shape: [1, NSNumber(value: n)], dataType: .int32) + let inputIdsPtr = inputIds.dataPointer.bindMemory(to: Int32.self, capacity: n) + for i in 0.. [Float] { + guard let audioArray = output.featureValue(for: "audio")?.multiArrayValue else { + throw KittenTTSError.processingFailed("Missing 'audio' output from model") + } + guard let lengthArray = output.featureValue(for: "audio_length_samples")?.multiArrayValue + else { + throw KittenTTSError.processingFailed( + "Missing 'audio_length_samples' output from model") + } + + let validLength = Int(truncating: lengthArray[0]) + let totalLength = audioArray.count + let sampleCount = min(validLength, totalLength) + + let audioPtr = audioArray.dataPointer.bindMemory(to: Float.self, capacity: totalLength) + var samples = [Float](repeating: 0, count: sampleCount) + for i in 0.. 
Float { + let u1 = max(Float.random(in: 0..<1), Float.leastNormalMagnitude) + let u2 = Float.random(in: 0..<1) + return sqrt(-2.0 * log(u1)) * cos(2.0 * .pi * u2) + } +} diff --git a/Sources/FluidAudio/TTS/Kokoro/Pipeline/Synthesize/KokoroSynthesizer+ModelUtils.swift b/Sources/FluidAudio/TTS/Kokoro/Pipeline/Synthesize/KokoroSynthesizer+ModelUtils.swift index d9bea5305..707e09f02 100644 --- a/Sources/FluidAudio/TTS/Kokoro/Pipeline/Synthesize/KokoroSynthesizer+ModelUtils.swift +++ b/Sources/FluidAudio/TTS/Kokoro/Pipeline/Synthesize/KokoroSynthesizer+ModelUtils.swift @@ -17,6 +17,9 @@ extension KokoroSynthesizer { } public static func loadSimplePhonemeDictionary() async throws { + // Ensure lexicon cache file is downloaded first + try await TtsResourceDownloader.ensureLexiconFile(named: "us_lexicon_cache.json") + let cacheDir = try TtsModels.cacheDirectoryURL() let kokoroDir = cacheDir.appendingPathComponent("Models/kokoro") let vocabulary = try await KokoroVocabulary.shared.getVocabulary() diff --git a/Sources/FluidAudio/TTS/TtsBackend.swift b/Sources/FluidAudio/TTS/TtsBackend.swift index e230bc4cc..ab7b3b820 100644 --- a/Sources/FluidAudio/TTS/TtsBackend.swift +++ b/Sources/FluidAudio/TTS/TtsBackend.swift @@ -1,9 +1,19 @@ import Foundation +/// KittenTTS model variant selector. +public enum KittenTtsVariant: String, CaseIterable, Sendable { + /// KittenTTS Nano — 15M params, distilled from Kokoro-82M. + case nano + /// KittenTTS Mini — 80M params, StyleTTS2 with speed control. + case mini +} + /// Available TTS synthesis backends. public enum TtsBackend: Sendable { /// Kokoro 82M — phoneme-based, multi-voice, chunk-oriented synthesis. case kokoro /// PocketTTS — flow-matching language model, autoregressive streaming synthesis. case pocketTts + /// KittenTTS — single-shot StyleTTS2-based synthesis (Nano 15M / Mini 80M). 
+ case kittenTts(KittenTtsVariant) } diff --git a/Sources/FluidAudioCLI/Commands/TTSCommand.swift b/Sources/FluidAudioCLI/Commands/TTSCommand.swift index b3dd4b765..5d7d69820 100644 --- a/Sources/FluidAudioCLI/Commands/TTSCommand.swift +++ b/Sources/FluidAudioCLI/Commands/TTSCommand.swift @@ -143,6 +143,7 @@ public struct TTS { var benchmarkMode = false var deEss = true var backend: TtsBackend = .kokoro + var speed: Float = 1.0 var cloneVoicePath: String? = nil var voiceFilePath: String? = nil var saveVoicePath: String? = nil @@ -200,11 +201,22 @@ public struct TTS { backend = .kokoro case "pocket", "pockettts": backend = .pocketTts + case "kitten", "kittentts": + backend = .kittenTts(.mini) // Default to Mini (82M) + case "kitten-nano", "kittennano": + backend = .kittenTts(.nano) + case "kitten-mini", "kittenmini": + backend = .kittenTts(.mini) default: logger.warning("Unknown backend '\(arguments[i + 1])'; using kokoro") } i += 1 } + case "--speed": + if i + 1 < arguments.count, let val = Float(arguments[i + 1]) { + speed = val + i += 1 + } case "--auto-download": // No-op: downloads are always ensured by the CLI () @@ -254,7 +266,7 @@ public struct TTS { return } - if backend == .pocketTts { + if case .pocketTts = backend { await runPocketTts( text: text, output: output, voice: voice, deEss: deEss, metricsPath: metricsPath, cloneVoicePath: cloneVoicePath, @@ -262,6 +274,13 @@ public struct TTS { return } + if case .kittenTts(let variant) = backend { + await runKittenTts( + text: text, output: output, voice: voice, speed: speed, + variant: variant, deEss: deEss) + return + } + do { // Timing buckets let tStart = Date() @@ -640,6 +659,62 @@ public struct TTS { } } + private static func runKittenTts( + text: String, output: String, voice: String, speed: Float, + variant: KittenTtsVariant, deEss: Bool + ) async { + do { + let tStart = Date() + let kittenVoice = + voice == TtsConstants.recommendedVoice + ? 
KittenTtsConstants.defaultVoice : voice + let manager = KittenTtsManager(variant: variant, defaultVoice: kittenVoice) + + let tLoad0 = Date() + try await manager.initialize() + let tLoad1 = Date() + + let tSynth0 = Date() + let wav = try await manager.synthesize( + text: text, voice: kittenVoice, speed: speed, deEss: deEss) + let tSynth1 = Date() + + let outURL = { + let expanded = (output as NSString).expandingTildeInPath + if expanded.hasPrefix("/") { + return URL(fileURLWithPath: expanded) + } + let cwd = URL( + fileURLWithPath: FileManager.default.currentDirectoryPath, + isDirectory: true) + return cwd.appendingPathComponent(expanded) + }() + try FileManager.default.createDirectory( + at: outURL.deletingLastPathComponent(), + withIntermediateDirectories: true) + try wav.write(to: outURL) + + let loadS = tLoad1.timeIntervalSince(tLoad0) + let synthS = tSynth1.timeIntervalSince(tSynth0) + let totalS = tSynth1.timeIntervalSince(tStart) + let sampleRate = Double(KittenTtsConstants.audioSampleRate) + let payload = max(0, wav.count - 44) + let audioSecs = Double(payload) / (sampleRate * 2.0) + let rtfx = synthS > 0 ? 
audioSecs / synthS : 0 + + logger.info("KittenTTS \(variant.rawValue) synthesis complete") + logger.info(" Load: \(String(format: "%.3f", loadS))s") + logger.info(" Synthesis: \(String(format: "%.3f", synthS))s") + logger.info(" Audio: \(String(format: "%.3f", audioSecs))s") + logger.info(" RTFx: \(String(format: "%.2f", rtfx))x") + logger.info(" Total: \(String(format: "%.3f", totalS))s") + logger.info(" Output: \(outURL.path)") + } catch { + logger.error("KittenTTS synthesis failed: \(error)") + exit(1) + } + } + private static func printUsage() { print( """ @@ -647,8 +722,9 @@ public struct TTS { Options: --output, -o Output WAV path (default: output.wav) - --voice, -v Voice name (default: af_heart for Kokoro, alba for PocketTTS) - --backend TTS backend: kokoro (default) or pocket + --voice, -v Voice name (default: af_heart for Kokoro, alba for PocketTTS, expr-voice-3-f for KittenTTS) + --backend TTS backend: kokoro (default), pocket, kitten (Mini 82M), kitten-nano, kitten-mini + --speed Speech speed multiplier (KittenTTS Mini only, default: 1.0) --lexicon, -l Custom pronunciation lexicon file (word=phonemes format, Kokoro only) --benchmark Run a predefined benchmarking suite with multiple sentences --variant Force Kokoro 5s or 15s model (values: 5s,15s) diff --git a/Tests/FluidAudioTests/TTS/KittenTTS/KittenTtsManagerTests.swift b/Tests/FluidAudioTests/TTS/KittenTTS/KittenTtsManagerTests.swift new file mode 100644 index 000000000..f8af2bcaa --- /dev/null +++ b/Tests/FluidAudioTests/TTS/KittenTTS/KittenTtsManagerTests.swift @@ -0,0 +1,88 @@ +import Testing + +@testable import FluidAudio + +@Suite("KittenTTS Manager Tests") +struct KittenTtsManagerTests { + + @Test("Manager initializes with nano variant") + func initNano() async { + let manager = KittenTtsManager(variant: .nano) + let available = await manager.isAvailable + #expect(!available) + } + + @Test("Manager initializes with mini variant") + func initMini() async { + let manager = 
KittenTtsManager(variant: .mini) + let available = await manager.isAvailable + #expect(!available) + } + + @Test("Synthesize throws when not initialized") + func synthesizeBeforeInit() async { + let manager = KittenTtsManager(variant: .nano) + do { + _ = try await manager.synthesize(text: "test") + Issue.record("Expected error but succeeded") + } catch { + // Expected + #expect(error is KittenTTSError) + } + } + + @Test("Default voice is expr-voice-3-f") + func defaultVoice() { + #expect(KittenTtsConstants.defaultVoice == "expr-voice-3-f") + } + + @Test("Available voices list has 8 entries") + func availableVoices() { + #expect(ModelNames.KittenTTS.availableVoices.count == 8) + } + + @Test("KittenTtsVariant cases") + func variantCases() { + #expect(KittenTtsVariant.allCases.count == 2) + #expect(KittenTtsVariant.nano.rawValue == "nano") + #expect(KittenTtsVariant.mini.rawValue == "mini") + } + + @Test("Model variant max tokens") + func modelVariantMaxTokens() { + #expect(ModelNames.KittenTTS.Variant.fiveSecond.maxTokens == 70) + #expect(ModelNames.KittenTTS.Variant.tenSecond.maxTokens == 140) + } + + @Test("Nano model filenames") + func nanoFileNames() { + let fiveS = ModelNames.KittenTTS.Variant.fiveSecond.nanoFileName() + let tenS = ModelNames.KittenTTS.Variant.tenSecond.nanoFileName() + #expect(fiveS == "kittentts_5s.mlmodelc") + #expect(tenS == "kittentts_10s.mlmodelc") + } + + @Test("Mini model filenames") + func miniFileNames() { + let fiveS = ModelNames.KittenTTS.Variant.fiveSecond.miniFileName() + let tenS = ModelNames.KittenTTS.Variant.tenSecond.miniFileName() + #expect(fiveS == "kittentts_mini_5s.mlmodelc") + #expect(tenS == "kittentts_mini_10s.mlmodelc") + } + + @Test("Repo configuration for nano") + func repoNano() { + let repo = Repo.kittenTtsNano + #expect(repo.remotePath == "alexwengg/kittentts-coreml") + #expect(repo.subPath == "nano") + #expect(repo.folderName == "kittentts-coreml/nano") + } + + @Test("Repo configuration for mini") + func 
repoMini() { + let repo = Repo.kittenTtsMini + #expect(repo.remotePath == "alexwengg/kittentts-coreml") + #expect(repo.subPath == "mini") + #expect(repo.folderName == "kittentts-coreml/mini") + } +} diff --git a/Tests/FluidAudioTests/TTS/KittenTTS/KittenTtsTokenizerTests.swift b/Tests/FluidAudioTests/TTS/KittenTTS/KittenTtsTokenizerTests.swift new file mode 100644 index 000000000..3619023bc --- /dev/null +++ b/Tests/FluidAudioTests/TTS/KittenTTS/KittenTtsTokenizerTests.swift @@ -0,0 +1,92 @@ +import Testing + +@testable import FluidAudio + +@Suite("KittenTTS Tokenizer Tests") +struct KittenTtsTokenizerTests { + + @Test("Vocab scalars has 178 entries") + func vocabScalarsLength() { + #expect(KittenTtsConstants.vocabScalars.count == KittenTtsConstants.vocabSize) + } + + @Test("First scalar is the padding token $") + func padToken() { + let first = KittenTtsConstants.vocabScalars.first + #expect(first == "$") + } + + @Test("Empty input produces BOS + EOS only") + func emptyInput() { + let result = KittenTtsSynthesizer.tokenize([]) + #expect(result == [0, 0]) + } + + @Test("Single IPA character is tokenized correctly") + func singleCharacter() { + // 'a' should be in the vocab at a known position (after $ + punctuation + uppercase + lowercase) + let result = KittenTtsSynthesizer.tokenize(["a"]) + #expect(result.count == 3) // BOS + 'a' + EOS + #expect(result.first == 0) // BOS + #expect(result.last == 0) // EOS + #expect(result[1] > 0) // 'a' has a non-zero ID + } + + @Test("Multiple phonemes are tokenized with BOS/EOS") + func multiplePhonemes() { + let result = KittenTtsSynthesizer.tokenize(["h", "ə", "l", "o"]) + #expect(result.first == 0) + #expect(result.last == 0) + // Should have BOS + at least some valid tokens + EOS + #expect(result.count >= 3) + } + + @Test("Unknown characters are dropped") + func unknownCharactersDropped() { + // Use characters unlikely to be in the 178-char IPA vocab + let result = KittenTtsSynthesizer.tokenize(["🎵"]) + #expect(result == 
[0, 0]) // Only BOS + EOS, emoji dropped + } + + @Test("Multi-character phoneme strings are split into individual scalars") + func multiCharPhoneme() { + // A phoneme like "aɪ" should be split into 'a' and 'ɪ' individually + let result = KittenTtsSynthesizer.tokenize(["aɪ"]) + #expect(result.count == 4) // BOS + 'a' + 'ɪ' + EOS + } + + @Test("Pad token is not added from input") + func padTokenNotFromInput() { + // '$' is the pad token (index 0) and should not be added as a real token + let result = KittenTtsSynthesizer.tokenize(["$"]) + #expect(result == [0, 0]) // Only BOS + EOS, '$' mapped to 0 but filtered + } + + @Test("Known IPA characters map to expected indices") + func knownCharacterMapping() { + let vocabScalars = KittenTtsConstants.vocabScalars + + // Check that 'A' maps to its position in the vocab + if let aIndex = vocabScalars.firstIndex(of: "A") { + let result = KittenTtsSynthesizer.tokenize(["A"]) + #expect(result[1] == Int32(aIndex)) + } + + // Check that 'ɑ' (IPA open back unrounded vowel) maps correctly + if let ipaIndex = vocabScalars.firstIndex(of: "\u{0251}") { + let result = KittenTtsSynthesizer.tokenize(["ɑ"]) + #expect(result[1] == Int32(ipaIndex)) + } + } + + @Test("Punctuation characters are tokenized") + func punctuationTokenized() { + let result = KittenTtsSynthesizer.tokenize(["!", ",", "."]) + // BOS + 3 punctuation chars + EOS = 5 + #expect(result.count == 5) + // All punctuation should have valid IDs (>0) + for id in result[1..<4] { + #expect(id > 0) + } + } +}