diff --git a/apps/llm/app/multimodal_llm/index.tsx b/apps/llm/app/multimodal_llm/index.tsx index 0de5004849..003c4fcb7d 100644 --- a/apps/llm/app/multimodal_llm/index.tsx +++ b/apps/llm/app/multimodal_llm/index.tsx @@ -12,6 +12,11 @@ import { View, } from 'react-native'; import { launchImageLibrary } from 'react-native-image-picker'; +import { + AudioManager, + AudioRecorder, + AudioContext, +} from 'react-native-audio-api'; import { useIsFocused } from '@react-navigation/native'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; import { models, useLLM } from 'react-native-executorch'; @@ -23,12 +28,14 @@ import Spinner from '../../components/Spinner'; import { GeneratingContext } from '../../context'; import SuggestedPrompts from '../../components/SuggestedPrompts'; import ErrorBanner from '../../components/ErrorBanner'; +import AudioWaveform from '../../components/AudioWaveform'; const SUGGESTED_PROMPTS = [ "What's in this image?", 'Describe this scene in detail', 'What objects can you see?', 'What text appears in this image?', + 'Transcribe the audio', ]; import { useLLMStats } from '../../hooks/useLLMStats'; import { StatsBar } from '../../components/StatsBar'; @@ -46,12 +53,18 @@ function MultimodalLLMScreen() { const textInputRef = useRef(null); const { setGlobalGenerating } = useContext(GeneratingContext); - // Added error state - const [error, setError] = useState(null); + const [audioBuffer, setAudioBuffer] = useState(null); + const [audioLabel, setAudioLabel] = useState(null); + const [audioUrl, setAudioUrl] = useState(''); + const [isFetchingAudio, setIsFetchingAudio] = useState(false); + const [isRecording, setIsRecording] = useState(false); + const [hasMicPermission, setHasMicPermission] = useState(false); + const recorder = useRef(new AudioRecorder()); + const recordChunks = useRef([]); - const vlm = useLLM({ - model: models.llm.lfm2_5_vl_1_6b(), - }); + const [error, setError] = useState(null); + const model = 
models.llm.gemma4_e2b_multimodal(); + const vlm = useLLM({ model: model }); const tokenCount = vlm.isReady ? vlm.getGeneratedTokenCount() : 0; const { stats, onMessageSend } = useLLMStats( vlm.response, @@ -68,6 +81,109 @@ function MultimodalLLMScreen() { if (vlm.error) setError(String(vlm.error)); }, [vlm.error]); + // Configure the audio session and request microphone permission once on + // mount; on unmount stop the recorder and deactivate the session. + useEffect(() => { + AudioManager.setAudioSessionOptions({ + iosCategory: 'playAndRecord', + iosMode: 'spokenAudio', + iosOptions: ['allowBluetoothHFP', 'defaultToSpeaker'], + }); + (async () => { + const status = await AudioManager.requestRecordingPermissions(); + setHasMicPermission(status === 'Granted'); + })(); + + return () => { + // NOTE(review): with an empty dependency list this closure sees `vlm` from + // the first render, so `vlm.isGenerating` may be stale here; confirm. + if (vlm.isGenerating) vlm.interrupt(); + // eslint-disable-next-line react-hooks/exhaustive-deps + recorder.current.stop(); + AudioManager.setAudioSessionActivity(false); + }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + // Decode the audio at `audioUrl` with a 16 kHz context and attach channel 0 + // as the pending audio input; failures are reported through `setError`. + const loadAudioFromUrl = async () => { + const url = audioUrl.trim(); + if (!url) return; + setIsFetchingAudio(true); + let ctx: AudioContext | undefined; + try { + ctx = new AudioContext({ sampleRate: 16000 }); + const decoded = await ctx.decodeAudioData(url); + const pcm = decoded.getChannelData(0); + const name = url.split('/').pop() || 'audio'; + setAudioBuffer(pcm); + setAudioLabel(`${name} · ${(pcm.length / 16000).toFixed(1)}s`); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } finally { + setIsFetchingAudio(false); + // The context is only needed for decoding; close it so each load does not + // leave a native audio context behind. + try { + await ctx?.close(); + } catch { + // Best-effort cleanup: a failed close must not mask the decode outcome. + } + } + }; + + const startRecording = async () => { + if (!hasMicPermission) { + setError('Microphone permission denied. 
Please enable it in Settings.'); + return; + } + recordChunks.current = []; + const sampleRate = 16000; + recorder.current.onAudioReady( + { sampleRate, bufferLength: 0.1 * sampleRate, channelCount: 1 }, + ({ buffer }) => { + recordChunks.current.push(new Float32Array(buffer.getChannelData(0))); + } + ); + try { + const ok = await AudioManager.setAudioSessionActivity(true); + if (!ok) { + setError('Cannot start audio session'); + return; + } + const result = recorder.current.start(); + if (result.status === 'error') { + setError(`Recording problems: ${result.message}`); + return; + } + setIsRecording(true); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } + }; + + const stopRecording = () => { + recorder.current.stop(); + setIsRecording(false); + const total = recordChunks.current.reduce((n, c) => n + c.length, 0); + if (total === 0) return; + const pcm = new Float32Array(total); + let off = 0; + for (const c of recordChunks.current) { + pcm.set(c, off); + off += c.length; + } + recordChunks.current = []; + setAudioBuffer(pcm); + setAudioLabel(`Recording ยท ${(pcm.length / 16000).toFixed(1)}s`); + }; + + const clearAudio = () => { + setAudioBuffer(null); + setAudioLabel(null); + }; + const pickImage = async () => { try { const result = await launchImageLibrary({ mediaType: 'photo' }); @@ -81,19 +183,27 @@ function MultimodalLLMScreen() { }; const sendMessage = async () => { - if (!userInput.trim() || vlm.isGenerating) return; + if (!(imageUri || audioBuffer || userInput.trim()) || vlm.isGenerating) + return; onMessageSend(); const text = userInput.trim(); setUserInput(''); textInputRef.current?.clear(); Keyboard.dismiss(); const currentImageUri = imageUri; + const currentAudio = audioBuffer; setImageUri(null); + setAudioBuffer(null); + setAudioLabel(null); try { - await vlm.sendMessage( - text, - currentImageUri ? { imagePath: currentImageUri } : undefined - ); + const media = + currentImageUri || currentAudio + ? 
{ + ...(currentImageUri ? { imagePath: currentImageUri } : {}), + ...(currentAudio ? { audioBuffer: currentAudio } : {}), + } + : undefined; + await vlm.sendMessage(text, media); } catch (e) { // Updated to set UI error instead of just console.error setError(e instanceof Error ? e.message : String(e)); @@ -135,7 +245,9 @@ function MultimodalLLMScreen() { Hello! ๐Ÿ‘‹ - Pick an image and ask me anything about it. + {model.capabilities.find((c) => c === 'audio') + ? 'Say hi, or pick an image, and ask me anything about it.' + : 'Pick an image and ask me anything about it.'} )} + {/* Audio URL input */} + + + + + {isFetchingAudio ? 'โ€ฆ' : 'Load'} + + + + + {/* Audio attachment strip */} + {audioLabel && ( + + + ๐ŸŽต {audioLabel} + + โœ• + + + + + )} + ๐Ÿ“ท + {/* Mic record / stop button */} + + + {isRecording ? 'โน๏ธ' : '๐ŸŽค'} + + + - {userInput.trim() && !vlm.isGenerating && ( - - - - )} + {(imageUri || audioBuffer || userInput.trim()) && + !vlm.isGenerating && ( + + + + )} {vlm.isGenerating && ( ; +} + +const NUM_BARS = 32; + +export default function AudioWaveform({ buffer, style }: AudioWaveformProps) { + const bars = useMemo(() => { + if (!buffer || buffer.length === 0) return null; + const chunkSize = Math.max(1, Math.floor(buffer.length / NUM_BARS)); + const peaks: number[] = []; + let max = 0; + for (let i = 0; i < NUM_BARS; i++) { + const start = i * chunkSize; + const end = Math.min(start + chunkSize, buffer.length); + let peak = 0; + for (let j = start; j < end; j++) { + const v = Math.abs(buffer[j] ?? 0); + if (v > peak) peak = v; + } + peaks.push(peak); + if (peak > max) max = peak; + } + return max > 0 ? 
peaks.map((p) => p / max) : peaks; + }, [buffer]); + + if (!bars) return null; + + return ( + + {bars.map((amp, i) => ( + + ))} + + ); +} + +const styles = StyleSheet.create({ + container: { + flexDirection: 'row', + alignItems: 'center', + height: 16, + minWidth: 160, + gap: 2, + }, + bar: { + flex: 1, + borderRadius: 1, + backgroundColor: ColorPalette.blueDark, + opacity: 0.35, + }, +}); diff --git a/apps/llm/components/MessageItem.tsx b/apps/llm/components/MessageItem.tsx index 2c44714ac0..cda8609885 100644 --- a/apps/llm/components/MessageItem.tsx +++ b/apps/llm/components/MessageItem.tsx @@ -11,6 +11,7 @@ import MarkdownComponent from './MarkdownComponent'; import LlamaIcon from '../assets/icons/llama_icon.svg'; import ColorPalette from '../colors'; import { Message } from 'react-native-executorch'; +import AudioWaveform from './AudioWaveform'; interface MessageItemProps { message: Message; @@ -43,6 +44,12 @@ const MessageItem = memo(({ message, deleteMessage }: MessageItemProps) => { resizeMode="contain" /> )} + {message.audioWaveform && ( + + )} @@ -103,6 +110,9 @@ const styles = StyleSheet.create({ borderRadius: 6, marginBottom: 6, }, + userMessageWaveform: { + marginBottom: 6, + }, aiMessageIconContainer: { backgroundColor: ColorPalette.seaBlueLight, height: 32, diff --git a/apps/llm/components/llmModels.ts b/apps/llm/components/llmModels.ts index 1d80d7a395..1991578973 100644 --- a/apps/llm/components/llmModels.ts +++ b/apps/llm/components/llmModels.ts @@ -10,6 +10,8 @@ const llm = models.llm; export type LLMModelSources = LLMProps['model']; export const LLM_MODELS: ModelOption[] = [ + // Gemma4 + { label: 'Gemma4 E2B', value: llm.gemma4_e2b() }, // Llama 3.2 { label: 'Llama 3.2 1B', diff --git a/docs/docs/03-hooks/01-natural-language-processing/useLLM.md b/docs/docs/03-hooks/01-natural-language-processing/useLLM.md index 7b1cb25158..29b1be4d72 100644 --- a/docs/docs/03-hooks/01-natural-language-processing/useLLM.md +++ 
b/docs/docs/03-hooks/01-natural-language-processing/useLLM.md @@ -56,7 +56,7 @@ The code snippet above fetches the model from the specified URL, loads it into m `useLLM` takes [`LLMProps`](../../06-api-reference/interfaces/LLMProps.md) that consists of: -- [model source](../../06-api-reference/interfaces/LLMProps.md#modelsource), [tokenizer source](../../06-api-reference/interfaces/LLMProps.md#tokenizersource), and [tokenizer config source](../../06-api-reference/interfaces/LLMProps.md#tokenizerconfigsource). +- A [`model`](../../06-api-reference/interfaces/LLMModel.md) object that bundles the model, tokenizer and tokenizer config sources (plus `capabilities` for multimodal models). - An optional flag [`preventLoad`](../../06-api-reference/interfaces/SpeechToTextProps.md#preventload) which prevents auto-loading of the model. You need more details? Check the following resources: @@ -494,13 +494,13 @@ Depending on selected model and the user's device generation speed can be above ## Vision-Language Models (VLM) -Some models support multimodal input — text and images together. To use them, pass a `capabilities` array when loading the model. +Some models support multimodal input — text, images and/or audio together. To use them, pass a `capabilities` array when loading the model. ### Loading a VLM ```tsx import { models, useLLM } from 'react-native-executorch'; -const llm = useLLM({ model: models.llm.lfm2_5_vl_1_6b() }); +const llm = useLLM({ model: models.llm.gemma4_e2b_multimodal() }); ``` The `capabilities` field is already set on the model constant. You can also construct the model object explicitly: @@ -511,22 +511,26 @@ const llm = useLLM({ modelSource: '...', tokenizerSource: '...', tokenizerConfigSource: '...', - capabilities: ['vision'], + capabilities: ['vision', 'audio'], }, }); ``` Passing `capabilities` unlocks the typed `media` argument on `sendMessage`. 
-### Sending a message with an image +### Sending a message with an image or audio recording ```tsx -const llm = useLLM({ model: models.llm.lfm2_5_vl_1_6b() }); +const llm = useLLM({ model: models.llm.gemma4_e2b_multimodal() }); const send = () => { llm.sendMessage('What is in this image?', { imagePath: '/path/to/image.jpg', }); + // or + llm.sendMessage('What can you hear?', { + audioBuffer: audioRecording, + }); }; return ( @@ -538,6 +542,7 @@ return ( ``` The `imagePath` should be a local file path on the device. +The `audioBuffer` should be a `Float32Array` holding the waveform sampled at 16 kHz. ### Functional generation with images diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md index 967625160c..48b3395f87 100644 --- a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md +++ b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md @@ -118,12 +118,12 @@ Model presets expose an optional `generationConfig` that `LLMModule.fromModelNam ## Vision-Language Models (VLM) -Some models support multimodal input — text and images together. To use them, pass `capabilities` in the model object when calling [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname): +Some models support multimodal input — text, images and/or audio together. To use them, pass `capabilities` in the model object when calling [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname): ```typescript import { models, LLMModule } from 'react-native-executorch'; const llm = await LLMModule.fromModelName( - models.llm.lfm2_5_vl_1_6b(), + models.llm.gemma4_e2b_multimodal(), undefined, (token) => console.log(token) ); @@ -133,20 +133,24 @@ The `capabilities` field is already set on the model constant. 
You can also cons ```typescript const llm = await LLMModule.fromModelName({ - modelName: 'lfm2.5-vl-1.6b-quantized', + modelName: 'gemma4-e2b-multimodal', modelSource: require('./path/to/model.pte'), tokenizerSource: require('./path/to/tokenizer.json'), tokenizerConfigSource: require('./path/to/tokenizer_config.json'), - capabilities: ['vision'], + capabilities: ['vision', 'audio'], }); ``` -Once loaded, pass `imagePath` to [`sendMessage`](../../06-api-reference/classes/LLMModule.md#sendmessage): +Once loaded, pass `imagePath` or `audioBuffer` to [`sendMessage`](../../06-api-reference/classes/LLMModule.md#sendmessage): ```typescript const response = await llm.sendMessage('What is in this image?', { imagePath: '/path/to/image.jpg', }); +// or +const response = await llm.sendMessage('What can you hear?', { + audioBuffer: audioRecording, //expected as waveform 16kHz +}); ``` Or use [`generate`](../../06-api-reference/classes/LLMModule.md#generate) with `mediaPath` on the message: @@ -159,7 +163,14 @@ const chat: Message[] = [ mediaPath: '/path/to/image.jpg', }, ]; - +// or +const chat: Message[] = [ + { + role: 'user', + content: 'Transcribe the recording.', + audioWaveform: audioRecording, + }, +]; const response = await llm.generate(chat); ``` diff --git a/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useLLM.md b/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useLLM.md index 7b1cb25158..f19920a486 100644 --- a/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useLLM.md +++ b/docs/versioned_docs/version-0.9.x/03-hooks/01-natural-language-processing/useLLM.md @@ -494,13 +494,13 @@ Depending on selected model and the user's device generation speed can be above ## Vision-Language Models (VLM) -Some models support multimodal input — text and images together. To use them, pass a `capabilities` array when loading the model. 
+Some models support multimodal input — text, images and/or audio together. To use them, pass a `capabilities` array when loading the model. ### Loading a VLM ```tsx import { models, useLLM } from 'react-native-executorch'; -const llm = useLLM({ model: models.llm.lfm2_5_vl_1_6b() }); +const llm = useLLM({ model: models.llm.gemma4_e2b_multimodal() }); ``` The `capabilities` field is already set on the model constant. You can also construct the model object explicitly: @@ -511,22 +511,26 @@ const llm = useLLM({ modelSource: '...', tokenizerSource: '...', tokenizerConfigSource: '...', - capabilities: ['vision'], + capabilities: ['vision', 'audio'], }, }); ``` Passing `capabilities` unlocks the typed `media` argument on `sendMessage`. -### Sending a message with an image +### Sending a message with an image or audio recording ```tsx -const llm = useLLM({ model: models.llm.lfm2_5_vl_1_6b() }); +const llm = useLLM({ model: models.llm.gemma4_e2b_multimodal() }); const send = () => { llm.sendMessage('What is in this image?', { imagePath: '/path/to/image.jpg', }); + // or + llm.sendMessage('What can you hear?', { + audioBuffer: audioRecording, + }); }; return ( @@ -538,6 +542,7 @@ return ( ``` The `imagePath` should be a local file path on the device. +The `audioBuffer` should be a `Float32Array` holding the waveform sampled at 16 kHz. 
### Functional generation with images diff --git a/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/LLMModule.md b/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/LLMModule.md index 967625160c..48b3395f87 100644 --- a/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/LLMModule.md +++ b/docs/versioned_docs/version-0.9.x/04-typescript-api/01-natural-language-processing/LLMModule.md @@ -118,12 +118,12 @@ Model presets expose an optional `generationConfig` that `LLMModule.fromModelNam ## Vision-Language Models (VLM) -Some models support multimodal input — text and images together. To use them, pass `capabilities` in the model object when calling [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname): +Some models support multimodal input — text, images and/or audio together. To use them, pass `capabilities` in the model object when calling [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname): ```typescript import { models, LLMModule } from 'react-native-executorch'; const llm = await LLMModule.fromModelName( - models.llm.lfm2_5_vl_1_6b(), + models.llm.gemma4_e2b_multimodal(), undefined, (token) => console.log(token) ); @@ -133,20 +133,24 @@ The `capabilities` field is already set on the model constant. 
You can also cons ```typescript const llm = await LLMModule.fromModelName({ - modelName: 'lfm2.5-vl-1.6b-quantized', + modelName: 'gemma4-e2b-multimodal', modelSource: require('./path/to/model.pte'), tokenizerSource: require('./path/to/tokenizer.json'), tokenizerConfigSource: require('./path/to/tokenizer_config.json'), - capabilities: ['vision'], + capabilities: ['vision', 'audio'], }); ``` -Once loaded, pass `imagePath` to [`sendMessage`](../../06-api-reference/classes/LLMModule.md#sendmessage): +Once loaded, pass `imagePath` or `audioBuffer` to [`sendMessage`](../../06-api-reference/classes/LLMModule.md#sendmessage): ```typescript const response = await llm.sendMessage('What is in this image?', { imagePath: '/path/to/image.jpg', }); +// or +const response = await llm.sendMessage('What can you hear?', { + audioBuffer: audioRecording, //expected as waveform 16kHz +}); ``` Or use [`generate`](../../06-api-reference/classes/LLMModule.md#generate) with `mediaPath` on the message: @@ -159,7 +163,14 @@ const chat: Message[] = [ mediaPath: '/path/to/image.jpg', }, ]; - +// or +const chat: Message[] = [ + { + role: 'user', + content: 'Transcribe the recording.', + audioWaveform: audioRecording, + }, +]; const response = await llm.generate(chat); ``` diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index f94ef918ac..e4209b2f79 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -17,6 +18,7 @@ #include #include +#include #include #include #include @@ -223,6 +225,22 @@ inline std::vector getValue>(const jsi::Value &val, return getArrayAsVector(val, runtime); } +template <> +inline std::vector> +getValue>>(const jsi::Value &val, + 
jsi::Runtime &runtime) { + jsi::Array array = val.asObject(runtime).asArray(runtime); + const size_t length = array.size(runtime); + std::vector> result; + result.reserve(length); + for (size_t i = 0; i < length; ++i) { + jsi::Value element = array.getValueAtIndex(runtime, i); + auto span = getTypedArrayAsSpan(element, runtime); + result.emplace_back(span.begin(), span.end()); + } + return result; +} + template <> inline std::vector getValue>(const jsi::Value &val, jsi::Runtime &runtime) { @@ -302,6 +320,31 @@ getValue>(const jsi::Value &val, jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } +template <> +inline models::llm::MultimodalInputs +getValue(const jsi::Value &val, + jsi::Runtime &runtime) { + models::llm::MultimodalInputs multimodalInputs; + jsi::Object obj = val.asObject(runtime); + + jsi::Value v = obj.getProperty(runtime, "imageToken"); + if (!v.isUndefined() && !v.isNull()) { + auto &images = multimodalInputs.images.emplace(); + images.token = getValue(v, runtime); + v = obj.getProperty(runtime, "imagePaths"); + images.paths = getValue>(v, runtime); + } + v = obj.getProperty(runtime, "audioToken"); + if (!v.isUndefined() && !v.isNull()) { + auto &audios = multimodalInputs.audios.emplace(); + audios.token = getValue(v, runtime); + v = obj.getProperty(runtime, "audioWaveforms"); + audios.waveforms = getValue>>(v, runtime); + } + + return multimodalInputs; +} + // Conversion from C++ types to jsi -------------------------------------------- // Implementation functions might return any type, but in a promise we can only diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 7e0fa4b26e..924bba9f99 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -1,11 +1,12 @@ #include "LLM.h" +#include 
"rnexecutorch/models/llm/Types.h" #include #include #include #include -#include #include +#include #include #include #include @@ -21,7 +22,6 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, std::vector capabilities, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker, Module::LoadMode::Mmap) { - if (capabilities.empty()) { runner_ = std::make_unique(std::move(module_), tokenizerSource); @@ -31,6 +31,9 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, if (cap == "vision") { encoders[llm::MultimodalType::Image] = std::make_unique(*module_); + } else if (cap == "audio") { + encoders[llm::MultimodalType::Audio] = + std::make_unique(*module_); } } runner_ = std::make_unique( @@ -75,62 +78,88 @@ std::string LLM::generate(std::string input, } +// Builds the runner inputs by splitting `prompt` at every image/audio +// placeholder token, pairing each placeholder with the next image path or +// waveform in order. Throws RnExecutorchError(InvalidUserInput) when a token +// is empty or the placeholder and media counts do not match. std::string LLM::generateMultimodal(std::string prompt, - std::vector imagePaths, - std::string imageToken, - std::shared_ptr callback) { + std::shared_ptr callback, + MultimodalInputs multimodalInputs) { if (!runner_ || !runner_->is_loaded()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Runner is not loaded"); } if (!runner_->is_multimodal()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "This model does not support multimodal input. Use generate(prompt, " - "callback) for text-only generation."); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "This model does not support multimodal input."); } - if (imageToken.empty()) { + if (!multimodalInputs.images.has_value() && + !multimodalInputs.audios.has_value()) { throw RnExecutorchError( RnExecutorchErrorCode::InvalidUserInput, - "imageToken must not be empty. Pass the model's image token (e.g. 
" - "from tokenizer_config.json)."); + "At least one of imageToken/audioToken must be non-empty"); } - const size_t kImageTokenLen = imageToken.size(); - + // A provided-but-empty token would match at every position without + // advancing the scan, so reject it up front with a precise message. + if ((multimodalInputs.images.has_value() && + multimodalInputs.images->token.empty()) || + (multimodalInputs.audios.has_value() && + multimodalInputs.audios->token.empty())) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "imageToken/audioToken must not be empty when provided"); + } + + // Scan the prompt once, splitting at the earliest placeholder at each step + // so that image/audio placeholders can be freely interleaved in the prompt. std::vector inputs; - size_t imageIdx = 0; - size_t searchPos = 0; - - while (true) { - size_t found = prompt.find(imageToken, searchPos); - if (found == std::string::npos) { - if (searchPos < prompt.size()) { - inputs.push_back(llm::make_text_input(prompt.substr(searchPos))); - } + size_t imageIdx = 0, audioIdx = 0, pos = 0; + while (pos < prompt.size()) { + size_t imgAt = multimodalInputs.images.has_value() + ? prompt.find(multimodalInputs.images.value().token, pos) + : std::string::npos; + size_t audAt = multimodalInputs.audios.has_value() + ? prompt.find(multimodalInputs.audios.value().token, pos) + : std::string::npos; + if (imgAt == std::string::npos && audAt == std::string::npos) { + inputs.push_back(llm::make_text_input(prompt.substr(pos))); break; } - // Text segment before this placeholder - if (found > searchPos) { - inputs.push_back( - llm::make_text_input(prompt.substr(searchPos, found - searchPos))); + const bool imageFirst = imgAt != std::string::npos && + (audAt == std::string::npos || imgAt < audAt); + size_t at = imageFirst ? 
imgAt : audAt; + if (at > pos) { + inputs.push_back(llm::make_text_input(prompt.substr(pos, at - pos))); } - // Image at this position - if (imageIdx >= imagePaths.size()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "More '" + imageToken + - "' placeholders in prompt than image paths provided"); + if (imageFirst) { + auto &images = multimodalInputs.images.value(); + if (imageIdx >= images.paths.size()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More '" + images.token + + "' placeholders than image paths"); + } + inputs.push_back(llm::make_image_input(images.paths[imageIdx++])); + pos = at + images.token.size(); + } else { + auto &audios = multimodalInputs.audios.value(); + if (audioIdx >= audios.waveforms.size()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More '" + audios.token + + "' placeholders than audio waveforms"); + } + inputs.push_back( + llm::make_audio_input(std::move(audios.waveforms[audioIdx++]))); + pos = at + audios.token.size(); } - inputs.push_back(llm::make_image_input(imagePaths[imageIdx++])); - searchPos = found + kImageTokenLen; } - - if (imageIdx < imagePaths.size()) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "More image paths provided than '" + imageToken + - "' placeholders in prompt"); + if ((multimodalInputs.images.has_value() && + imageIdx < multimodalInputs.images.value().paths.size()) || + (multimodalInputs.audios.has_value() && + audioIdx < multimodalInputs.audios.value().waveforms.size())) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidUserInput, + "More image/audio paths provided than placeholders in prompt"); } - if (inputs.empty()) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "No inputs to generate from"); @@ -150,7 +179,6 @@ std::string LLM::generateMultimodal(std::string prompt, if (error != Error::Ok) { throw RnExecutorchError(error, "Failed to generate multimodal response"); } - 
return output; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index 222b5bc62f..4b7087351b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace rnexecutorch { @@ -22,10 +23,10 @@ class LLM : public BaseModel { std::string generate(std::string prompt, std::shared_ptr callback); + std::string generateMultimodal(std::string prompt, - std::vector imagePaths, - std::string imageToken, - std::shared_ptr callback); + std::shared_ptr callback, + MultimodalInputs multimodalInputs = {}); void interrupt(); void reset(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/Types.h new file mode 100644 index 0000000000..921d4fa8f4 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/Types.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include + +namespace rnexecutorch::models::llm { +struct ImageInputs { + std::vector paths; + std::string token; +}; + +struct AudioInputs { + std::vector> waveforms; + std::string token; +}; + +struct MultimodalInputs { + std::optional images; + std::optional audios; +}; + +} // namespace rnexecutorch::models::llm diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 1f34b3a18e..5f9d7287a5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -293,6 +293,7 @@ add_rn_test(LLMTests integration/LLMTest.cpp ${COMMON_DIR}/runner/sampler.cpp ${COMMON_DIR}/runner/arange_util.cpp 
${COMMON_DIR}/runner/encoders/vision_encoder.cpp + ${COMMON_DIR}/runner/encoders/audio_encoder.cpp ${IMAGE_UTILS_SOURCES} LIBS tokenizers_deps opencv_deps ) diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp index ae0a11e777..4b34f4248e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp @@ -1,11 +1,15 @@ #include "BaseModelTests.h" +#include "utils/TestUtils.h" #include #include #include +#include +#include #include #include #include +#include #include using namespace rnexecutorch; @@ -30,6 +34,12 @@ std::string formatChatML(const std::string &systemPrompt, "<|im_start|>assistant\n"; } +// Helper to format a single-turn prompt in Gemma's chat template. +std::string formatGemma(const std::string &userMessage) { + return "user\n" + userMessage + "\n" + + "model\n"; +} + // ============================================================================ // Common tests via typed test suite // ============================================================================ @@ -227,6 +237,18 @@ TEST(VisionEncoderTest, LoadFailsWithClearErrorWhenMethodMissing) { EXPECT_THROW(encoder->load(), rnexecutorch::RnExecutorchError); } +TEST(AudioEncoderTest, LoadFailsWithClearErrorWhenMethodMissing) { + // smolLm2_135M_8da4w.pte has no audio_encoder method + auto module = std::make_unique<::executorch::extension::Module>( + "smolLm2_135M_8da4w.pte", + ::executorch::extension::Module::LoadMode::File); + + auto encoder = + std::make_unique(*module); + + EXPECT_THROW(encoder->load(), rnexecutorch::RnExecutorchError); +} + // ============================================================================ // VLM-specific tests // ============================================================================ @@ -243,7 +265,11 @@ 
TEST_F(LLMTest, TextModelIsNotMultimodal) { TEST_F(LLMTest, GenerateMultimodalOnTextModelThrows) { LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); - EXPECT_THROW(model.generateMultimodal("hello", {}, "", nullptr), + // A text-only runner reports is_multimodal() == false, so any multimodal + // call must be rejected before the inputs are even inspected. + MultimodalInputs inputs{.images = + ImageInputs{.paths = {}, .token = ""}}; + EXPECT_THROW(model.generateMultimodal("hello", nullptr, std::move(inputs)), RnExecutorchError); } @@ -270,22 +296,120 @@ std::shared_ptr VLMTest::invoker_; std::unique_ptr VLMTest::model_; TEST_F(VLMTest, GenerateMultimodalEmptyImageTokenThrows) { - EXPECT_THROW( - model_->generateMultimodal("hello", {kTestImagePath}, "", nullptr), - RnExecutorchError); + MultimodalInputs inputs{ + .images = ImageInputs{.paths = {kTestImagePath}, .token = ""}}; + EXPECT_THROW(model_->generateMultimodal("hello", nullptr, std::move(inputs)), + RnExecutorchError); } TEST_F(VLMTest, GenerateMultimodalMorePlaceholdersThanImagePaths) { std::string prompt = std::string(kVlmImageToken) + " and " + kVlmImageToken; - EXPECT_THROW(model_->generateMultimodal(prompt, {kTestImagePath}, - kVlmImageToken, nullptr), + MultimodalInputs inputs{.images = ImageInputs{.paths = {kTestImagePath}, + .token = kVlmImageToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), RnExecutorchError); } TEST_F(VLMTest, GenerateMultimodalMoreImagePathsThanPlaceholders) { std::string prompt = std::string(kVlmImageToken) + " describe"; - EXPECT_THROW(model_->generateMultimodal(prompt, - {kTestImagePath, kTestImagePath}, - kVlmImageToken, nullptr), + MultimodalInputs inputs{ + .images = ImageInputs{.paths = {kTestImagePath, kTestImagePath}, + .token = kVlmImageToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), + RnExecutorchError); +} + +// 
============================================================================ +// Audio (Gemma 4) multimodal tests +// ============================================================================ +constexpr auto kGemmaModelPath = "gemma4_e2b_mm_xnnpack.pte"; +constexpr auto kGemmaTokenizerPath = "gemma_tokenizer.json"; +constexpr auto kGemmaAudioToken = ""; +constexpr auto kTestAudioPath = "test_audio_float.raw"; + +// Fixture that loads the audio-capable Gemma model once for all audio tests. +class GemmaAudioTest : public ::testing::Test { +protected: + static void SetUpTestSuite() { + invoker_ = createMockCallInvoker(); + model_ = std::make_unique(kGemmaModelPath, kGemmaTokenizerPath, + std::vector{"vision", "audio"}, + invoker_); + } + + static void TearDownTestSuite() { + model_.reset(); + invoker_.reset(); + } + + static std::vector loadAudio(size_t maxSamples = 32000) { + auto wav = test_utils::loadAudioFromFile(kTestAudioPath); + if (wav.size() > maxSamples) { + wav.resize(maxSamples); + } + return wav; + } + + static std::shared_ptr invoker_; + static std::unique_ptr model_; +}; + +std::shared_ptr GemmaAudioTest::invoker_; +std::unique_ptr GemmaAudioTest::model_; + +TEST_F(GemmaAudioTest, GenerateMultimodalNoInputsThrows) { + EXPECT_THROW(model_->generateMultimodal("hello", nullptr, {}), + RnExecutorchError); +} + +TEST_F(GemmaAudioTest, GenerateMultimodalEmptyAudioTokenThrows) { + MultimodalInputs inputs{ + .audios = AudioInputs{.waveforms = {loadAudio()}, .token = ""}}; + EXPECT_THROW(model_->generateMultimodal("hello", nullptr, std::move(inputs)), + RnExecutorchError); +} + +TEST_F(GemmaAudioTest, GenerateMultimodalMorePlaceholdersThanWaveformsThrows) { + std::string prompt = + std::string(kGemmaAudioToken) + " and " + kGemmaAudioToken; + MultimodalInputs inputs{.audios = AudioInputs{.waveforms = {loadAudio()}, + .token = kGemmaAudioToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), RnExecutorchError); } + 
+TEST_F(GemmaAudioTest, GenerateMultimodalMoreWaveformsThanPlaceholdersThrows) { + std::string prompt = std::string(kGemmaAudioToken) + " describe"; + MultimodalInputs inputs{ + .audios = AudioInputs{.waveforms = {loadAudio(), loadAudio()}, + .token = kGemmaAudioToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), + RnExecutorchError); +} + +TEST_F(GemmaAudioTest, GenerateMultimodalAudioProducesOutput) { + std::vector wav = loadAudio(); + ASSERT_FALSE(wav.empty()) + << "test_audio_float.raw missing on device - check run_tests.sh assets"; + + std::string prompt = + formatGemma(std::string(kGemmaAudioToken) + " Transcribe the audio."); + MultimodalInputs inputs{.audios = AudioInputs{.waveforms = {std::move(wav)}, + .token = kGemmaAudioToken}}; + std::string output = + model_->generateMultimodal(prompt, nullptr, std::move(inputs)); + + EXPECT_FALSE(output.empty()); + EXPECT_GT(model_->getGeneratedTokenCount(), 0); +} + +TEST_F(GemmaAudioTest, GenerateMultimodalInterleavedTextAndAudio) { + std::string prompt = formatGemma("Listen: " + std::string(kGemmaAudioToken) + + " then summarise it."); + MultimodalInputs inputs{.audios = AudioInputs{.waveforms = {loadAudio()}, + .token = kGemmaAudioToken}}; + std::string output = + model_->generateMultimodal(prompt, nullptr, std::move(inputs)); + + EXPECT_FALSE(output.empty()); +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh index 2240a7b667..fda1cce6d6 100755 --- a/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh +++ b/packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh @@ -22,6 +22,7 @@ HF_VERSION_TAG="resolve/v${LIB_VERSION}" TEST_EXECUTABLES=( "NumericalTests" "RunnerTests" + "SamplerTests" "LogTests" "FileUtilsTest" "ImageProcessingTest" @@ -86,6 +87,8 @@ MODELS=( 
"lfm2_5_vl_quantized_xnnpack_v2.pte|https://huggingface.co/software-mansion/react-native-executorch-lfm-2.5/${HF_VERSION_TAG}/vl_1_6b/xnnpack/lfm_2_5_vl_1_6b_xnnpack_8da4w.pte" "lfm2_vl_tokenizer.json|https://huggingface.co/software-mansion/react-native-executorch-lfm-2.5/${HF_VERSION_TAG}/vl_1_6b/tokenizer.json" "lfm2_vl_tokenizer_config.json|https://huggingface.co/software-mansion/react-native-executorch-lfm-2.5/${HF_VERSION_TAG}/vl_1_6b/tokenizer_config.json" + "gemma4_e2b_mm_xnnpack.pte|https://huggingface.co/software-mansion/react-native-executorch-gemma-4-multimodal/${HF_VERSION_TAG}/e2b/xnnpack/gemma_4_e2b_xnnpack_8da4w.pte" + "gemma_tokenizer.json|https://huggingface.co/software-mansion/react-native-executorch-gemma-4/${HF_VERSION_TAG}/e2b/tokenizer.json" "yolo26n-seg.pte|https://huggingface.co/software-mansion/react-native-executorch-yolo26-seg/${HF_VERSION_TAG}/n/xnnpack/yolo26_seg_n_xnnpack_fp32.pte" "segmentation_image.jpg|https://upload.wikimedia.org/wikipedia/commons/thumb/8/85/Collage_audi.jpg/1280px-Collage_audi.jpg" "yolo26n-pose.pte|https://huggingface.co/software-mansion/react-native-executorch-yolo26-pose/${HF_VERSION_TAG}/xnnpack/yolo26_pose_n_xnnpack_fp32.pte" @@ -212,7 +215,7 @@ models_for_test() { TokenizerModuleTests) echo "tokenizer.json" ;; SpeechToTextTests) echo "whisper_tiny_en_xnnpack.pte whisper_tokenizer.json fsmn-vad_xnnpack.pte" ;; TextToSpeechTests) echo "kokoro_duration_predictor.pte kokoro_synthesizer.pte kokoro_af_heart.bin kokoro_us_lexicon.json kokoro_en_tagger.json kokoro_us_phonemizer.pte" ;; - LLMTests) echo "smolLm2_135M_8da4w.pte smollm_tokenizer.json lfm2_5_vl_quantized_xnnpack_v2.pte lfm2_vl_tokenizer.json lfm2_vl_tokenizer_config.json test_image.jpg" ;; + LLMTests) echo "smolLm2_135M_8da4w.pte smollm_tokenizer.json lfm2_5_vl_quantized_xnnpack_v2.pte lfm2_vl_tokenizer.json lfm2_vl_tokenizer_config.json test_image.jpg gemma4_e2b_mm_xnnpack.pte gemma_tokenizer.json" ;; TextToImageTests) echo "t2i_tokenizer.json 
t2i_encoder.pte t2i_unet.pte t2i_decoder.pte" ;; InstanceSegmentationTests) echo "yolo26n-seg.pte segmentation_image.jpg" ;; PoseEstimationTests) echo "yolo26n-pose.pte" ;; diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp index 4295f16232..bf7a1d02d6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp @@ -7,6 +7,7 @@ */ #include +#include #include #include @@ -26,27 +27,39 @@ std::vector sampleMany(Sampler &s, std::vector logits, // 1. Repetition penalty on positive logit: token 0 should be sampled less. TEST(SamplerTest, RepetitionPenaltyReducesPositiveLogit) { - Sampler s(2, 1.0f, 1.0f, 0, 0.0f, 1.3f); + Sampler s(2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.3f}); std::vector logits = {1.0f, 1.0f}; std::vector recent = {0}; auto counts = sampleMany(s, logits, recent, 2000); EXPECT_LT(counts[0], 1200); } -// 2. Repetition penalty on negative logit: penalised token should appear even -// less. +// 2. Repetition penalty on negative logit: multiplying a negative logit by the +// penalty makes it more negative, so the penalised token is sampled strictly +// less often than without the penalty. Compare against an unpenalised baseline +// rather than a fixed threshold: with penalty 1.5 the penalised logit is +// -1.0 * 1.5 = -1.5, giving P(token 1) = e^-1.5 / (1 + e^-1.5) ≈ 0.18 (~365 of +// 2000) versus the baseline e^-1 / (1 + e^-1) ≈ 0.27 (~538). A static "< 200" +// bound would be mathematically unreachable at this penalty.
TEST(SamplerTest, RepetitionPenaltyMultipliesNegativeLogit) { - Sampler s(2, 1.0f, 1.0f, 0, 0.0f, 1.5f); - std::vector logits = {0.0f, -1.0f}; + Sampler baseline( + 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.0f}); + Sampler penalised( + 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.5f}); + std::vector logits_b = {0.0f, -1.0f}; + std::vector logits_p = {0.0f, -1.0f}; std::vector recent = {1}; - auto counts = sampleMany(s, logits, recent, 2000); - EXPECT_LT(counts[1], 200); + auto cb = sampleMany(baseline, logits_b, recent, 2000); + auto cp = sampleMany(penalised, logits_p, recent, 2000); + EXPECT_LT(cp[1], cb[1]); } // 3. No recent tokens — penalty has no effect. TEST(SamplerTest, RepetitionPenaltyNoRecentTokensHasNoEffect) { - Sampler baseline(2, 1.0f, 1.0f, 0, 0.0f, 1.0f); - Sampler penalised(2, 1.0f, 1.0f, 0, 0.0f, 2.0f); + Sampler baseline( + 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.0f}); + Sampler penalised( + 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 2.0f}); std::vector logits_b = {1.0f, 1.0f}; std::vector logits_p = {1.0f, 1.0f}; std::vector recent = {}; @@ -57,7 +70,7 @@ TEST(SamplerTest, RepetitionPenaltyNoRecentTokensHasNoEffect) { // 4. Min-p truncation: token with very low probability is excluded. TEST(SamplerTest, MinPFiltersTailTokens) { - Sampler s(3, 1.0f, 1.0f, 0, 0.1f, 1.0f); + Sampler s(3, {.temperature = 1.0f, .topp = 1.0f, .min_p = 0.1f}); std::vector logits = {5.0f, -5.0f, -5.0f}; std::vector recent = {}; auto counts = sampleMany(s, logits, recent, 1000); @@ -68,7 +81,7 @@ TEST(SamplerTest, MinPFiltersTailTokens) { // 5. Min-p = 0 disables filtering.
TEST(SamplerTest, MinPZeroDisablesFiltering) { - Sampler s(3, 0.0f, 1.0f, 0, 0.0f, 1.0f); + Sampler s(3, {.temperature = 0.0f, .topp = 1.0f}); std::vector logits = {1.0f, -1000.0f, -1000.0f}; std::vector recent = {}; EXPECT_EQ(s.sample(logits.data(), recent), 0); @@ -76,7 +89,7 @@ TEST(SamplerTest, MinPZeroDisablesFiltering) { // 6. Min-p + top-p stacked. TEST(SamplerTest, MinPAndToppStack) { - Sampler s(4, 1.0f, 0.5f, 0, 0.2f, 1.0f); + Sampler s(4, {.temperature = 1.0f, .topp = 0.5f, .min_p = 0.2f}); std::vector logits = {5.0f, 2.0f, -2.0f, -5.0f}; std::vector recent = {}; auto counts = sampleMany(s, logits, recent, 2000); diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.cpp b/packages/react-native-executorch/common/runner/base_llm_runner.cpp index a021040807..7229d64f20 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.cpp +++ b/packages/react-native-executorch/common/runner/base_llm_runner.cpp @@ -56,11 +56,16 @@ Error BaseLLMRunner::load() { ? static_cast(metadata_.at(kMaxContextLen)) : static_cast(metadata_.at(kMaxSeqLen)); } - if (config_.max_new_tokens < 0) - config_.max_new_tokens = - std::min(config_.max_seq_len, config_.max_context_length); config_.enable_dynamic_shape = static_cast(metadata_.at(kEnableDynamicShape)); + if (config_.max_new_tokens < 0) { + // For dynamic-shape PTEs, max_seq_len is the per-call decoder chunk + // size, not the generation budget — use max_context_length instead. + const int32_t seq_cap = config_.enable_dynamic_shape + ?
config_.max_context_length + : config_.max_seq_len; + config_.max_new_tokens = std::min(seq_cap, config_.max_context_length); + } config_.enable_kv_cache = static_cast(metadata_.at(kUseKVCache)); eos_ids_ = std::make_unique>(); @@ -149,6 +154,8 @@ void BaseLLMRunner::set_repetition_penalty(float repetition_penalty) noexcept { config_.repetition_penalty = repetition_penalty; } +void BaseLLMRunner::set_topk(int32_t topk) noexcept { config_.topk = topk; } + void BaseLLMRunner::set_count_interval(size_t count_interval) { config_.output_token_batch_size = count_interval; } diff --git a/packages/react-native-executorch/common/runner/base_llm_runner.h b/packages/react-native-executorch/common/runner/base_llm_runner.h index 9710f5ae70..82de49bea3 100644 --- a/packages/react-native-executorch/common/runner/base_llm_runner.h +++ b/packages/react-native-executorch/common/runner/base_llm_runner.h @@ -55,6 +55,7 @@ class BaseLLMRunner { void set_topp(float topp) noexcept; void set_min_p(float min_p) noexcept; void set_repetition_penalty(float repetition_penalty) noexcept; + void set_topk(int32_t topk) noexcept; void set_count_interval(size_t count_interval); void set_time_interval(size_t time_interval); diff --git a/packages/react-native-executorch/common/runner/constants.h b/packages/react-native-executorch/common/runner/constants.h index f1fee23471..368371688a 100644 --- a/packages/react-native-executorch/common/runner/constants.h +++ b/packages/react-native-executorch/common/runner/constants.h @@ -23,8 +23,22 @@ inline constexpr auto kVisionEncoderMethod = "vision_encoder"; inline constexpr auto kAudioEncoderMethod = "audio_encoder"; inline constexpr auto kTokenEmbeddingMethod = "token_embedding"; inline constexpr auto kTextModelMethod = "text_decoder"; - inline constexpr auto numOfAddedBoSTokens = 0; inline constexpr auto numOfAddedEoSTokens = 0; +// Gemma4 +// PLE models only: token id that marks image placeholder slots in input_ids. 
+// token_embedding run on this id produces the per-layer PLE signal for image +// positions; the inputs_embeds output for those positions is discarded (the +// vision encoder output replaces it). +inline constexpr auto kImagePlaceholderId = "image_placeholder_id"; +// True iff the model exposes a per-layer-embedding (PLE) signal alongside +// inputs_embeds (Gemma4-style). When true, `token_embedding.execute()` +// returns the tuple (inputs_embeds, ple_tok) and the runner must thread +// ple_tok into text_decoder; when false (or absent), token_embedding returns +// inputs_embeds alone. Text-only PTEs that ship a single `forward` method +// omit this key entirely — it is meaningful only for multimodal PTEs that +// expose a separate `token_embedding` method. +inline constexpr auto kHasPLE = "has_ple"; + } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp new file mode 100644 index 0000000000..36227dc966 --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/audio_encoder.cpp @@ -0,0 +1,111 @@ +// common/runner/encoders/audio_encoder.cpp +#include "audio_encoder.h" + +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace executorch::extension::llm { + +using ::executorch::aten::SizesType; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Result; + +namespace { +constexpr int32_t kSamplingRate = 16e3; +constexpr int32_t kMaxLengthSeconds = 30; +constexpr int32_t kSamplesPerBlock = 7680; +constexpr int64_t kAudioBlockKMin = 1; +constexpr int64_t kAudioBlockKMax = + kSamplingRate * kMaxLengthSeconds / kSamplesPerBlock; +} // namespace + +AudioEncoder::AudioEncoder(::executorch::extension::Module &module) + : module_(&module) {} + +Error AudioEncoder::load() { + if (is_loaded()) { + return Error::Ok; +
} + auto method_names_result = module_->method_names(); + if (!method_names_result.ok()) { + return method_names_result.error(); + } + if (method_names_result->count(kAudioEncoderMethod) == 0) { + throw rnexecutorch::RnExecutorchError( + rnexecutorch::RnExecutorchErrorCode::InvalidConfig, + "Model does not support audio: 'audio_encoder' method not found. " + "Check that the .pte file matches the declared capabilities."); + } + return module_->load_method(kAudioEncoderMethod); +} + +bool AudioEncoder::is_loaded() const noexcept { + return module_->is_method_loaded(kAudioEncoderMethod); +} + +int32_t AudioEncoder::encoderTokenCount() const noexcept { + return last_token_count_; +} + +Result AudioEncoder::encode(const MultimodalInput &input) { + if (!is_loaded()) { + return Error::InvalidState; + } + if (!input.is_audio()) { + return Error::InvalidArgument; + } + + const auto &wav = input.get_audio(); + ET_CHECK_OR_RETURN_ERROR(!wav.samples.empty(), InvalidArgument, + "AudioEncoder: empty waveform"); + + const int64_t n_valid = static_cast(wav.samples.size()); + const int64_t k_blocks = (n_valid + kSamplesPerBlock - 1) / kSamplesPerBlock; + ET_CHECK_OR_RETURN_ERROR( + k_blocks >= kAudioBlockKMin && k_blocks <= kAudioBlockKMax, + InvalidArgument, + "AudioEncoder: waveform of %lld samples needs k_blocks=%lld.", + static_cast(n_valid), static_cast(k_blocks)); + const int64_t n_padded = k_blocks * kSamplesPerBlock; + + // Own the padded waveform for the lifetime of this call; from_blob below + // borrows without copying. The current export takes + // forward(waveform[1, 7680*k] fp32, num_blocks: int64 scalar) + // — input 1 is a rank-0 Long telling the encoder how many of the K_MAX + // blocks contain real PCM. Passing a 2-d mask here trips "Attempted to + // change tensor rank: old=0, new=2".
+ padded_wav_.assign(static_cast(n_padded), 0.0f); + std::memcpy(padded_wav_.data(), wav.samples.data(), + static_cast(n_valid) * sizeof(float)); + + /* NOTE(review): the contract comment above calls input 1 "num_blocks", yet the scalar stored here is n_valid (a sample count), not k_blocks; confirm which quantity the exported audio_encoder expects. */ + valid_samples_scalar_ = n_valid; + + auto wav_tensor = ::executorch::extension::from_blob( + padded_wav_.data(), {1, static_cast(n_padded)}, + ::executorch::aten::ScalarType::Float); + + auto num_blocks_tensor = ::executorch::extension::from_blob( + &valid_samples_scalar_, {}, ::executorch::aten::ScalarType::Long); + + std::vector args = {EValue(*wav_tensor), EValue(*num_blocks_tensor)}; + auto exec_result = ET_UNWRAP(module_->execute(kAudioEncoderMethod, args)); + ET_CHECK_OR_RETURN_ERROR(!exec_result.empty(), InvalidState, + "audio_encoder returned no outputs"); + auto audio_tensor = exec_result[0].toTensor(); + ET_CHECK_OR_RETURN_ERROR(audio_tensor.dim() == 3, InvalidState, + "audio_encoder output rank=%zd, expected 3", + audio_tensor.dim()); + last_token_count_ = static_cast(audio_tensor.size(1)); + return exec_result[0]; +} + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/audio_encoder.h b/packages/react-native-executorch/common/runner/encoders/audio_encoder.h new file mode 100644 index 0000000000..9723e4fbd7 --- /dev/null +++ b/packages/react-native-executorch/common/runner/encoders/audio_encoder.h @@ -0,0 +1,40 @@ +// common/runner/encoders/audio_encoder.h +#pragma once + +#include "iencoder.h" +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +// Runs the Gemma4 `audio_encoder` PTE method. +// +// Contract mirrors SpeechToText (Whisper): JS hands in fp32 mono 16 kHz PCM +// via `MultimodalInput::get_audio()`; the PTE owns the log-mel frontend so +// this class just wraps the samples in a `[1, N_samples]` Float tensor and +// executes. Resampling and WAV/MP3 decoding are the caller's responsibility +// (e.g. react-native-audio-api).
+class AudioEncoder : public IEncoder { +public: + explicit AudioEncoder(::executorch::extension::Module &module); + + ::executorch::runtime::Error load() override; + bool is_loaded() const noexcept override; + ::executorch::runtime::Result<::executorch::runtime::EValue> + encode(const MultimodalInput &input) override; + // Number of audio embedding tokens produced per encode() call. 0 until first + // encode, since Gemma4's audio_encoder has a dynamic T dim. + int32_t encoderTokenCount() const noexcept override; + +private: + ::executorch::extension::Module *module_; + int32_t last_token_count_ = 0; + std::vector padded_wav_; + int64_t valid_samples_scalar_ = 0; +}; + +} // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp index de3e196c1f..59fee53e11 100644 --- a/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp +++ b/packages/react-native-executorch/common/runner/encoders/vision_encoder.cpp @@ -2,7 +2,6 @@ #include "vision_encoder.h" #include -#include #include #include diff --git a/packages/react-native-executorch/common/runner/irunner.h b/packages/react-native-executorch/common/runner/irunner.h index 54b14c354f..4e5b14444a 100644 --- a/packages/react-native-executorch/common/runner/irunner.h +++ b/packages/react-native-executorch/common/runner/irunner.h @@ -73,6 +73,11 @@ struct GenerationConfig { size_t output_token_batch_size = 10; size_t batch_time_interval_ms = 120; + // Top-k sampling – keep only the k highest-logit tokens before softmax. + // 0 (default) disables top-k filtering. Stacks with topp: temperature -> + // top-k -> top-p -> softmax -> multinomial. + int32_t topk = 0; + // Enable dynamic input shapes (if implemented) or not // Impacts the prefill phase and causes TextPrefiller to pass all the tokens // at once if set to true.
diff --git a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h index 071b193539..8d83c1fa64 100644 --- a/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_decoder_runner.h @@ -14,19 +14,50 @@ #include "text_decoder_runner.h" namespace executorch::extension::llm { +// Supports two PTE contracts, selected per-call from the kHasPLE metadata +// key (mirrors how kEnableDynamicShape etc. are read — queried on demand, +// not cached in a member). Callers that need it multiple times in a hot +// path should snapshot into a local. +// +// * Legacy (has_ple == false): +// token_embedding(ids) -> inputs_embeds +// text_decoder(inputs_embeds, input_pos) +// +// * Gemma-style PLE (has_ple == true): +// token_embedding(ids) -> (inputs_embeds, ple_tok) +// text_decoder(inputs_embeds, ple_tok, input_pos) +// ple_tok carries Gemma4's per-layer PLE signal keyed on input_ids. It's +// computed once in token_embedding and threaded through every decoder call +// so PLE fires at every position (including multimodal placeholder slots).
class MultimodalDecoderRunner : public TextDecoderRunner { public: explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager, const GenerationConfig &config) : TextDecoderRunner(module, io_manager, config) {} + bool has_ple() const { + auto r = module_->get(kHasPLE); + if (r.error() != ::executorch::runtime::Error::Ok) { + return false; + } + return r->toScalar().to(); + } + inline ::executorch::runtime::Result<::executorch::aten::Tensor> step(TensorPtr &tokens, int64_t start_pos) override { auto embed_result = module_->execute(kTokenEmbeddingMethod, tokens); if (!embed_result.ok()) { return embed_result.error(); } - return decode((*embed_result)[0], start_pos); + auto &embed_outputs = *embed_result; + if (has_ple()) { + ET_CHECK_MSG(embed_outputs.size() == 2, + "Expected 2 outputs (inputs_embeds, ple_tok) from " + "token_embedding, got %zu", + embed_outputs.size()); + return decode(embed_outputs[0], embed_outputs[1], start_pos); + } + return decode(embed_outputs[0], start_pos); } inline ::executorch::runtime::Result<::executorch::aten::Tensor> @@ -46,6 +77,24 @@ class MultimodalDecoderRunner : public TextDecoderRunner { return outputs[0].toTensor(); } + inline ::executorch::runtime::Result<::executorch::aten::Tensor> + decode(const ::executorch::runtime::EValue &embeddings, + const ::executorch::runtime::EValue &ple_tok, int64_t start_pos) { + auto start_pos_tensor = ::executorch::extension::from_blob( + &start_pos, {1}, ::executorch::aten::ScalarType::Long); + auto outputs_result = module_->execute( + kTextModelMethod, {embeddings, ple_tok, start_pos_tensor}); + if (!outputs_result.ok()) { + return outputs_result.error(); + } + auto &outputs = *outputs_result; + ET_CHECK_MSG(outputs.size() == 1, + "Expected 1 output from text_decoder, got %zu", + outputs.size()); + ET_CHECK_MSG(outputs[0].isTensor(), "text_decoder output is not a tensor"); + return outputs[0].toTensor(); + } + inline ::executorch::runtime::Error load() override { if 
(is_method_loaded()) { return ::executorch::runtime::Error::Ok; diff --git a/packages/react-native-executorch/common/runner/multimodal_input.h b/packages/react-native-executorch/common/runner/multimodal_input.h index 6b7de35014..b49da0561f 100644 --- a/packages/react-native-executorch/common/runner/multimodal_input.h +++ b/packages/react-native-executorch/common/runner/multimodal_input.h @@ -20,6 +20,10 @@ struct ImagePath { std::string path; }; +struct AudioWaveform { + std::vector samples; +}; + class MultimodalInput { public: explicit MultimodalInput(std::string text) : data_(std::move(text)) {} @@ -27,6 +31,7 @@ class MultimodalInput { : data_(std::move(tokens)) {} explicit MultimodalInput(ImagePath image_path) : data_(std::move(image_path)) {} + explicit MultimodalInput(AudioWaveform audio) : data_(std::move(audio)) {} MultimodalInput(const MultimodalInput &) = default; MultimodalInput &operator=(const MultimodalInput &) = default; @@ -42,6 +47,9 @@ class MultimodalInput { bool is_image() const noexcept { return std::holds_alternative(data_); } + bool is_audio() const noexcept { + return std::holds_alternative(data_); + } const std::string &get_text() const & { return std::get(data_); } const std::vector &get_tokens() const & { @@ -50,9 +58,13 @@ class MultimodalInput { const std::string &get_image_path() const & { return std::get(data_).path; } + const AudioWaveform &get_audio() const & { + return std::get(data_); + } private: - std::variant, ImagePath> data_; + std::variant, ImagePath, AudioWaveform> + data_; }; inline MultimodalInput make_text_input(const std::string &text) noexcept { @@ -64,5 +76,8 @@ inline MultimodalInput make_text_input(std::string &&text) noexcept { inline MultimodalInput make_image_input(std::string path) noexcept { return MultimodalInput(ImagePath{std::move(path)}); } +inline MultimodalInput make_audio_input(std::vector samples) noexcept { + return MultimodalInput(AudioWaveform{std::move(samples)}); +} } // namespace 
executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp index 83a1a7f79c..8b04dc39bf 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.cpp @@ -13,6 +13,9 @@ #include "constants.h" #include "util.h" #include +#include +#include +#include namespace executorch::extension::llm { @@ -23,91 +26,390 @@ using ::executorch::runtime::Result; MultimodalPrefiller::MultimodalPrefiller( Module &module, MultimodalDecoderRunner &decoder_runner, - tokenizers::HFTokenizer &tokenizer, IEncoder *image_encoder) + tokenizers::HFTokenizer &tokenizer, + std::unordered_map metadata, IEncoder *image_encoder, + IEncoder *audio_encoder) : module_(&module), decoder_runner_(&decoder_runner), - tokenizer_(&tokenizer), image_encoder_(image_encoder) {} + tokenizer_(&tokenizer), metadata_(metadata), + image_encoder_(image_encoder), audio_encoder_(audio_encoder) {} -Result MultimodalPrefiller::prefill(const MultimodalInput &input, - int64_t &start_pos) { - EValue encoder_output; - std::vector padded_tokens_storage; - TensorPtr sliced_embed_storage; +int64_t MultimodalPrefiller::get_max_seq_len() const { + auto r = module_->get(kMaxSeqLen); + if (r.error() != ::executorch::runtime::Error::Ok) { + return metadata_.at(kMaxSeqLen); + } + return r->toScalar().to(); +} + +int64_t MultimodalPrefiller::get_max_context_len() const { + auto r = module_->get(kMaxContextLen); + if (r.error() != ::executorch::runtime::Error::Ok) { + return metadata_.at(kMaxContextLen) || get_max_seq_len(); + } + return r->toScalar().to(); +} + +bool MultimodalPrefiller::get_enable_dynamic_shape() const { + auto r = module_->get(kEnableDynamicShape); + if (r.error() != ::executorch::runtime::Error::Ok) { + return metadata_.at(kEnableDynamicShape); + } + return r->toScalar().to(); +} 
+[[nodiscard]] auto MultimodalPrefiller::processMultimodalInput( + const MultimodalInput &input, std::vector &ids, + std::vector &image_slots, + std::vector &audio_slots) { if (input.is_image()) { ET_CHECK_OR_RETURN_ERROR(image_encoder_ != nullptr, InvalidState, "No image encoder registered"); - auto encode_result = image_encoder_->encode(input); - ET_CHECK_OK_OR_RETURN_ERROR(encode_result.error(), "Image encoding failed"); - encoder_output = *encode_result; - - } else if (input.is_text() || input.is_tokens()) { - std::vector tokens; - if (input.is_text()) { - auto encode_result = tokenizer_->encode(input.get_text()); - if (!encode_result.ok()) { - ET_LOG(Error, "Tokenizer encode error %d", - static_cast(encode_result.error())); - return Error::InvalidArgument; - } - tokens = std::move(*encode_result); - } else { - tokens = input.get_tokens(); + const int32_t num_visual = image_encoder_->encoderTokenCount(); + ET_CHECK_OR_RETURN_ERROR(num_visual > 0, InvalidState, + "Image encoder reports 0 visual tokens"); + image_slots.push_back(Types::ImageSlot{&input, + static_cast(ids.size()), + static_cast(num_visual)}); + ids.insert(ids.end(), static_cast(num_visual), 0); + } else if (input.is_audio()) { + ET_CHECK_OR_RETURN_ERROR(audio_encoder_ != nullptr, InvalidState, + "No audio encoder registered"); + auto enc = audio_encoder_->encode(input); + ET_CHECK_OK_OR_RETURN_ERROR(enc.error(), "Audio encoding failed"); + // Snapshot the encoder output NOW — see AudioSlot comment above for + // why the returned EValue's tensor metadata can't survive past the + // next module_->execute(). num_audio and audio_hidden are read from + // the tensor directly rather than from encoderTokenCount() so they + // are guaranteed to reflect THIS encode call.
+ auto audio_tensor = enc->toTensor(); + ET_CHECK_OR_RETURN_ERROR(audio_tensor.dim() == 3, InvalidState, + "audio_encoder output rank=%zd, expected 3", + audio_tensor.dim()); + const int64_t num_audio = static_cast(audio_tensor.size(1)); + const int64_t audio_hidden = static_cast(audio_tensor.size(2)); + ET_CHECK_OR_RETURN_ERROR(num_audio > 0, InvalidState, + "Audio encoder produced 0 tokens"); + std::vector bytes(audio_tensor.nbytes()); + std::memcpy(bytes.data(), audio_tensor.const_data_ptr(), + audio_tensor.nbytes()); + audio_slots.push_back(Types::AudioSlot{ + std::move(bytes), audio_tensor.scalar_type(), + static_cast(ids.size()), num_audio, audio_hidden}); + ids.insert(ids.end(), static_cast(num_audio), 0); + } else if (input.is_text()) { + auto encode_result = tokenizer_->encode(input.get_text()); + if (!encode_result.ok()) { + ET_LOG(Error, "Tokenizer encode error %d", + static_cast(encode_result.error())); + return Error::InvalidArgument; } + std::vector tokens = std::move(*encode_result); + for (auto t : tokens) { + ids.push_back(static_cast(t)); + } + } else if (input.is_tokens()) { + std::vector tokens = input.get_tokens(); + for (auto t : tokens) { + ids.push_back(static_cast(t)); + } + } else { + ET_LOG(Error, "Unsupported MultimodalInput type"); + return Error::NotSupported; + } + return ::executorch::runtime::Error::Ok; +} + +[[nodiscard]] auto MultimodalPrefiller::encodeAudio( + const Types::AudioSlot &slot, const auto hidden, + std::vector &embeds_buf, const size_t embeds_elem_size, + const ::executorch::aten::ScalarType &embeds_dtype) { + ET_CHECK_OR_RETURN_ERROR( + slot.audio_hidden == static_cast(hidden), InvalidState, + "audio encoder hidden %lld != text_embed hidden %lld", + static_cast(slot.audio_hidden), + static_cast(hidden)); - const auto actual_seq_len = static_cast(tokens.size()); + const auto audio_dtype = slot.dtype; + const size_t audio_elems = + static_cast(slot.num_audio) * static_cast(hidden); + const size_t audio_elem_size = + 
audio_elems > 0 ? slot.bytes.size() / audio_elems : 0; + ET_CHECK_OR_RETURN_ERROR( + audio_elem_size > 0 && audio_elem_size * audio_elems == slot.bytes.size(), + InvalidState, + "audio slot bytes %zu inconsistent with num_audio=%lld hidden=%lld", + slot.bytes.size(), static_cast(slot.num_audio), + static_cast(hidden)); - // The token_embedding PTE has a fixed MAX_SEQ_LEN input buffer. - // Pad with zeros, run embedding, then slice output back to actual length. - int64_t max_seq_len = actual_seq_len; // fallback: no padding needed - auto max_seq_len_result = module_->get(kMaxSeqLen); - if (max_seq_len_result.error() == Error::Ok) { - max_seq_len = max_seq_len_result->toScalar().to(); + uint8_t *dst = embeds_buf.data() + static_cast(slot.slot_start) * + static_cast(hidden) * + embeds_elem_size; + + if (audio_dtype == embeds_dtype) { + std::memcpy(dst, slot.bytes.data(), audio_elems * embeds_elem_size); + } else if (audio_dtype == ::executorch::aten::ScalarType::Float && + embeds_dtype == ::executorch::aten::ScalarType::Half) { + const float *src = reinterpret_cast(slot.bytes.data()); + auto *dst_h = reinterpret_cast<::executorch::aten::Half *>(dst); + for (size_t i = 0; i < audio_elems; ++i) { + dst_h[i] = ::executorch::aten::Half(src[i]); + } + } else if (audio_dtype == ::executorch::aten::ScalarType::Half && + embeds_dtype == ::executorch::aten::ScalarType::Float) { + const auto *src = + reinterpret_cast(slot.bytes.data()); + auto *dst_f = reinterpret_cast(dst); + for (size_t i = 0; i < audio_elems; ++i) { + dst_f[i] = static_cast(src[i]); } + } else { + ET_CHECK_OR_RETURN_ERROR( + false, InvalidState, + "unsupported audio/text dtype pair: audio=%hhd text=%hhd", + static_cast(audio_dtype), static_cast(embeds_dtype)); + } + return ::executorch::runtime::Error::Ok; +} - padded_tokens_storage.assign(max_seq_len, 0); - std::ranges::copy(tokens, padded_tokens_storage.begin()); +[[nodiscard]] auto MultimodalPrefiller::encodeImages( + const Types::ImageSlot &slot, const 
auto hidden, + std::vector &embeds_buf, const size_t embeds_elem_size, + const ::executorch::aten::ScalarType &embeds_dtype) { + auto encode_result = image_encoder_->encode(*slot.input); + ET_CHECK_OK_OR_RETURN_ERROR(encode_result.error(), "Image encoding failed"); + auto encoder_output = *encode_result; + auto vision_tensor = encoder_output.toTensor(); - auto text_tensor = ::executorch::extension::from_blob( - padded_tokens_storage.data(), {1, static_cast(max_seq_len)}, - ::executorch::aten::ScalarType::Long); + const auto vision_dtype = vision_tensor.scalar_type(); + const size_t visual_elems = + static_cast(slot.num_visual) * static_cast(hidden); + uint8_t *dst = embeds_buf.data() + static_cast(slot.slot_start) * + static_cast(hidden) * + embeds_elem_size; + if (vision_dtype == embeds_dtype) { + const uint8_t *src = + static_cast(vision_tensor.const_data_ptr()); + std::memcpy(dst, src, visual_elems * embeds_elem_size); + } else if (vision_dtype == ::executorch::aten::ScalarType::Float && + embeds_dtype == ::executorch::aten::ScalarType::Half) { + const float *src = vision_tensor.const_data_ptr(); + auto *dst_h = reinterpret_cast<::executorch::aten::Half *>(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_h[i] = ::executorch::aten::Half(src[i]); + } + } else if (vision_dtype == ::executorch::aten::ScalarType::Half && + embeds_dtype == ::executorch::aten::ScalarType::Float) { + const auto *src = vision_tensor.const_data_ptr<::executorch::aten::Half>(); + auto *dst_f = reinterpret_cast(dst); + for (size_t i = 0; i < visual_elems; ++i) { + dst_f[i] = static_cast(src[i]); + } + } else { + ET_CHECK_OR_RETURN_ERROR( + false, InvalidState, + "unsupported vision/text dtype pair: vision=%hhd text=%hhd", + static_cast(vision_dtype), static_cast(embeds_dtype)); + } + return ::executorch::runtime::Error::Ok; +} - auto embed_result = module_->execute(kTokenEmbeddingMethod, text_tensor); - ET_CHECK_OK_OR_RETURN_ERROR(embed_result.error()); +[[nodiscard]] auto 
+MultimodalPrefiller::initializePLE(auto &embed_outputs, auto total_len, + Types::PLEEmbeddings &ple_embeddings) { + auto full_ple_tok = embed_outputs[1].toTensor(); + ple_embeddings.num_layers = static_cast(full_ple_tok.size(2)); + ple_embeddings.ple_dim = static_cast(full_ple_tok.size(3)); + ple_embeddings.ple_tok_dtype = full_ple_tok.scalar_type(); + const size_t total_numel = static_cast(full_ple_tok.numel()); + const size_t total_bytes = full_ple_tok.nbytes(); + ET_CHECK_OR_RETURN_ERROR(total_numel > 0, InvalidState, + "ple_tok has zero elements"); + ple_embeddings.ple_elem_size = total_bytes / total_numel; + const size_t prefix_bytes = static_cast(total_len) * + static_cast(ple_embeddings.num_layers) * + static_cast(ple_embeddings.ple_dim) * + ple_embeddings.ple_elem_size; + ple_embeddings.ple_tok_buf.resize(prefix_bytes); + std::memcpy(ple_embeddings.ple_tok_buf.data(), + full_ple_tok.mutable_data_ptr(), prefix_bytes); + return ::executorch::runtime::Error::Ok; +} - auto full_embed = (*embed_result)[0].toTensor(); - const auto embed_dim = static_cast(full_embed.size(2)); - sliced_embed_storage = ::executorch::extension::from_blob( - full_embed.mutable_data_ptr(), {1, actual_seq_len, embed_dim}, - ::executorch::aten::ScalarType::Float); - encoder_output = EValue(*sliced_embed_storage); +[[nodiscard]] auto MultimodalPrefiller::prefillChunk( + std::vector &last_outs, std::vector &embeds_buf, + auto chunk_start, auto chunk_len, auto hidden, auto embeds_elem_size, + auto embeds_dtype, Types::PLEEmbeddings &ple_embeddings, + std::vector &cache_positions) { + uint8_t *embeds_chunk_ptr = + embeds_buf.data() + static_cast(chunk_start) * + static_cast(hidden) * embeds_elem_size; + auto embeds_chunk = ::executorch::extension::from_blob( + embeds_chunk_ptr, {1, static_cast(chunk_len), hidden}, + embeds_dtype); - } else { - ET_LOG(Error, "Unsupported MultimodalInput type"); - return Error::NotSupported; + TensorPtr ple_chunk; + if (decoder_runner_->has_ple()) { + uint8_t 
*ple_chunk_ptr = + ple_embeddings.ple_tok_buf.data() + + static_cast(chunk_start) * + static_cast(ple_embeddings.num_layers) * + static_cast(ple_embeddings.ple_dim) * + ple_embeddings.ple_elem_size; + ple_chunk = ::executorch::extension::from_blob( + ple_chunk_ptr, + {1, static_cast(chunk_len), ple_embeddings.num_layers, + ple_embeddings.ple_dim}, + ple_embeddings.ple_tok_dtype); } - // Run text_decoder for prefill. - int64_t seq_len = encoder_output.toTensor().size(1); - if (seq_len == 0) { - ET_LOG(Error, "Encoder returned empty output"); - return Error::InvalidState; + auto pos_chunk = ::executorch::extension::from_blob( + cache_positions.data() + chunk_start, {static_cast(chunk_len)}, + ::executorch::aten::ScalarType::Long); + + auto res = decoder_runner_->has_ple() + ? module_->execute(kTextModelMethod, + {EValue(*embeds_chunk), EValue(*ple_chunk), + EValue(*pos_chunk)}) + : module_->execute(kTextModelMethod, {EValue(*embeds_chunk), + EValue(*pos_chunk)}); + ET_CHECK_OK_OR_RETURN_ERROR(res.error()); + last_outs = std::move(*res); + return ::executorch::runtime::Error::Ok; +} + +Result +MultimodalPrefiller::prefill(const std::vector &inputs, + int64_t &start_pos) { + const bool has_ple = decoder_runner_->has_ple(); + + ET_CHECK_OR_RETURN_ERROR(!inputs.empty(), InvalidArgument, + "prefill: empty input list"); + + // ------------------------------------------------------------ + // * get_max_seq_len โ€” text_decoder S cap. Max prefill chunk length + // (<=get_max_conetxt_len) + // * get_max_context_len โ€” total KV budget. Caps max context length for + // multi-turn conversation. + // ------------------------------------------------------------ + int64_t max_seq_len = get_max_seq_len(); + int64_t max_context_len = get_max_context_len(); + bool enable_dynamic_shape = get_enable_dynamic_shape(); + const int64_t prefill_total_cap = + enable_dynamic_shape ? 
max_context_len : max_seq_len; + const int64_t decoder_chunk_size = max_seq_len; + + std::vector ids; + ids.reserve(static_cast(prefill_total_cap)); + std::vector image_slots; + std::vector audio_slots; + + for (const auto &input : inputs) { + auto res = processMultimodalInput(input, ids, image_slots, audio_slots); + if (res != ::executorch::runtime::Error::Ok) { + return res; + } + } + + const int64_t total_len = static_cast(ids.size()); + ET_CHECK_OR_RETURN_ERROR(total_len > 0, InvalidArgument, + "prefill produced zero tokens"); + + ET_CHECK_OR_RETURN_ERROR(total_len <= prefill_total_cap, InvalidArgument, + "Prefill length %lld exceeds %s (%lld)", + static_cast(total_len), + enable_dynamic_shape ? "get_max_context_len" + : "get_max_seq_len", + static_cast(prefill_total_cap)); + if (!enable_dynamic_shape) { + ids.resize(static_cast(max_seq_len), 0); + } + + // ------------------------------------------------------------ + // Single token_embedding call over the fused id buffer. + // ------------------------------------------------------------ + const int64_t tok_buf_len = static_cast(ids.size()); + auto token_tensor = ::executorch::extension::from_blob( + ids.data(), {1, static_cast(tok_buf_len)}, + ::executorch::aten::ScalarType::Long); + + auto embed_result = module_->execute(kTokenEmbeddingMethod, token_tensor); + ET_CHECK_OK_OR_RETURN_ERROR(embed_result.error()); + auto &embed_outputs = *embed_result; + + auto full_embed = embed_outputs[0].toTensor(); + const auto hidden = static_cast(full_embed.size(2)); + + // Own the embeds for the live prefix โ€” subsequent vision_encoder.execute + // calls may reuse the token_embedding output buffer in the runtime. 
+ const ::executorch::aten::ScalarType embeds_dtype = full_embed.scalar_type(); + const size_t embeds_total_numel = static_cast(full_embed.numel()); + ET_CHECK_OR_RETURN_ERROR(embeds_total_numel > 0, InvalidState, + "token_embedding returned zero elements"); + const size_t embeds_elem_size = full_embed.nbytes() / embeds_total_numel; + const size_t embeds_prefix_bytes = static_cast(total_len) * + static_cast(hidden) * + embeds_elem_size; + std::vector embeds_buf(embeds_prefix_bytes); + std::memcpy(embeds_buf.data(), full_embed.mutable_data_ptr(), + embeds_prefix_bytes); + + // ------------------------------------------------------------ + // Pass 2: encode images and splice their outputs into embeds_buf. + // ------------------------------------------------------------ + for (const auto &slot : image_slots) { + auto res = + encodeImages(slot, hidden, embeds_buf, embeds_elem_size, embeds_dtype); + if (res != ::executorch::runtime::Error::Ok) { + return res; + } } - std::vector cache_positions; - auto cache_pos_result = populate_start_pos_or_cache_position( - module_, start_pos, cache_positions, seq_len, kTextModelMethod); - ET_CHECK_OK_OR_RETURN_ERROR(cache_pos_result.error()); + // ------------------------------------------------------------ + // Pass 2b: splice encoded audio tokens into embeds_buf. Reads from the + // byte snapshot taken at encode time so post-encode execute() calls can't + // invalidate slot state. Same dtype-conversion matrix as vision. 
+ // ------------------------------------------------------------ + for (auto &slot : audio_slots) { + auto res = + encodeAudio(slot, hidden, embeds_buf, embeds_elem_size, embeds_dtype); + if (res != ::executorch::runtime::Error::Ok) { + return res; + } + } - auto prefill_result = - module_->execute(kTextModelMethod, {encoder_output, *cache_pos_result}); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_result.error()); + Types::PLEEmbeddings ple_embeddings; + if (has_ple) { + auto res = initializePLE(embed_outputs, total_len, ple_embeddings); + if (res != ::executorch::runtime::Error::Ok) { + return res; + } + } - auto &prefill_outputs = *prefill_result; - ET_CHECK_OR_RETURN_ERROR(!prefill_outputs.empty(), InvalidState, + std::vector last_outs; + const int64_t chunk_cap = + decoder_chunk_size > 0 ? decoder_chunk_size : total_len; + std::vector cache_positions(static_cast(total_len)); + for (int64_t i = 0; i < total_len; ++i) { + cache_positions[static_cast(i)] = start_pos + i; + } + const int64_t num_chunks = (total_len + chunk_cap - 1) / chunk_cap; + for (int64_t ci = 0; ci < num_chunks; ++ci) { + const int64_t chunk_start = ci * chunk_cap; + const int64_t chunk_end = std::min(chunk_start + chunk_cap, total_len); + const int64_t chunk_len = chunk_end - chunk_start; + auto res = prefillChunk(last_outs, embeds_buf, chunk_start, chunk_len, + hidden, embeds_elem_size, embeds_dtype, + ple_embeddings, cache_positions); + if (res != ::executorch::runtime::Error::Ok) { + return res; + } + } + + ET_CHECK_OR_RETURN_ERROR(!last_outs.empty(), InvalidState, "text_decoder returned no outputs during prefill"); - auto logits = prefill_outputs[0].toTensor(); - start_pos += seq_len; + auto logits = last_outs[0].toTensor(); + start_pos += total_len; return static_cast(decoder_runner_->logits_to_token(logits)); } @@ -127,6 +429,9 @@ Error MultimodalPrefiller::load() { if (methods.find(kVisionEncoderMethod) != methods.end()) { 
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kVisionEncoderMethod)); } + if (methods.find(kAudioEncoderMethod) != methods.end()) { + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod)); + } return Error::Ok; } @@ -140,8 +445,13 @@ bool MultimodalPrefiller::is_method_loaded() { return false; } const auto &methods = *methods_res; - if (methods.find(kVisionEncoderMethod) != methods.end()) { - return module_->is_method_loaded(kVisionEncoderMethod); + if (methods.find(kVisionEncoderMethod) != methods.end() && + !module_->is_method_loaded(kVisionEncoderMethod)) { + return false; + } + if (methods.find(kAudioEncoderMethod) != methods.end() && + !module_->is_method_loaded(kAudioEncoderMethod)) { + return false; } return true; } diff --git a/packages/react-native-executorch/common/runner/multimodal_prefiller.h b/packages/react-native-executorch/common/runner/multimodal_prefiller.h index d9b5a9bf5c..05037d88c8 100644 --- a/packages/react-native-executorch/common/runner/multimodal_prefiller.h +++ b/packages/react-native-executorch/common/runner/multimodal_prefiller.h @@ -18,26 +18,77 @@ namespace executorch::extension::llm { +namespace Types { +struct ImageSlot { + const MultimodalInput *input; // non-owning, valid for duration of call + int64_t slot_start; + int64_t num_visual; +}; + +struct AudioSlot { + std::vector bytes; + ::executorch::aten::ScalarType dtype; + int64_t slot_start; + int64_t num_audio; + int64_t audio_hidden; +}; + +struct PLEEmbeddings { + std::vector ple_tok_buf; + aten::SizesType num_layers = 0; + aten::SizesType ple_dim = 0; + size_t ple_elem_size = 0; + ::executorch::aten::ScalarType ple_tok_dtype = + ::executorch::aten::ScalarType::Half; +}; +} // namespace Types + class MultimodalPrefiller { public: - explicit MultimodalPrefiller(Module &module, - MultimodalDecoderRunner &decoder_runner, - tokenizers::HFTokenizer &tokenizer, - IEncoder *image_encoder = nullptr); + explicit MultimodalPrefiller( + Module &module, 
MultimodalDecoderRunner &decoder_runner, + tokenizers::HFTokenizer &tokenizer, + std::unordered_map metadata, + IEncoder *image_encoder = nullptr, IEncoder *audio_encoder = nullptr); // Prefill one input segment. Updates start_pos in-place. // Returns the first predicted token after this segment. - ::executorch::runtime::Result prefill(const MultimodalInput &input, - int64_t &start_pos); + ::executorch::runtime::Result + prefill(const std::vector &inputs, int64_t &start_pos); + auto processMultimodalInput(const MultimodalInput &input, + std::vector &ids, + std::vector &image_slots, + std::vector &audio_slots); ::executorch::runtime::Error load(); bool is_method_loaded(); + int64_t get_max_seq_len() const; + int64_t get_max_context_len() const; + bool get_enable_dynamic_shape() const; private: + auto encodeImages(const Types::ImageSlot &slot, const auto hidden, + std::vector &embeds_buf, + const size_t embeds_elem_size, + const ::executorch::aten::ScalarType &embeds_dtype); + auto encodeAudio(const Types::AudioSlot &slot, const auto hidden, + std::vector &embeds_buf, + const size_t embeds_elem_size, + const ::executorch::aten::ScalarType &embeds_dtype); + auto prefillChunk(std::vector<::executorch::runtime::EValue> &last_outs, + std::vector &embeds_buf, auto chunk_start, + auto chunk_len, auto hidden, auto embeds_elem_size, + auto embeds_dtype, Types::PLEEmbeddings &ple_embeddings, + std::vector &cache_positions); + auto initializePLE(auto &embed_outputs, auto total_len, + Types::PLEEmbeddings &ple_embeddings); + Module *module_; MultimodalDecoderRunner *decoder_runner_; tokenizers::HFTokenizer *tokenizer_; + std::unordered_map metadata_; IEncoder *image_encoder_; + IEncoder *audio_encoder_; }; } // namespace executorch::extension::llm diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.cpp b/packages/react-native-executorch/common/runner/multimodal_runner.cpp index 767fef9f38..084a7ef191 100644 --- 
a/packages/react-native-executorch/common/runner/multimodal_runner.cpp +++ b/packages/react-native-executorch/common/runner/multimodal_runner.cpp @@ -3,7 +3,6 @@ #include "constants.h" #include "util.h" #include -#include namespace executorch::extension::llm { @@ -54,8 +53,14 @@ Error MultimodalRunner::load_subcomponents() { if (enc_it != encoders_.end()) { image_encoder = enc_it->second.get(); } + IEncoder *audio_encoder = nullptr; + auto aud_it = encoders_.find(MultimodalType::Audio); + if (aud_it != encoders_.end()) { + audio_encoder = aud_it->second.get(); + } mm_prefiller_ = std::make_unique( - *module_, *mm_decoder_runner_, *tokenizer_, image_encoder); + *module_, *mm_decoder_runner_, *tokenizer_, metadata_, image_encoder, + audio_encoder); mm_token_generator_ = std::make_unique( tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true, std::move(eos_ids_), stats_ptr, config_); @@ -78,22 +83,24 @@ Error MultimodalRunner::generate_internal( } stats_.inference_start_ms = time_in_ms(); - - uint64_t prefill_next_token = 0; - for (const auto &input : inputs) { - auto prefill_result = mm_prefiller_->prefill(input, pos_); - if (!prefill_result.ok()) - return prefill_result.error(); - prefill_next_token = prefill_result.get(); - } + auto prefill_result = mm_prefiller_->prefill(inputs, pos_); + if (!prefill_result.ok()) + return prefill_result.error(); + uint64_t prefill_next_token = prefill_result.get(); stats_.first_token_ms = time_in_ms(); stats_.prompt_eval_end_ms = time_in_ms(); stats_.num_prompt_tokens = pos_; + // For dynamic-shape PTEs (Gemma4 iter*), get_max_seq_len is the per-call + // decoder chunk size (e.g. 128) and the true generation budget lives in + // get_max_context_len. Mirrors text_runner.cpp:95-97. + const int32_t seq_cap = config_.enable_dynamic_shape + ? 
config_.max_context_length + : config_.max_seq_len; int32_t resolved_max_new = resolve_max_new_tokens( - static_cast(pos_), config_.max_seq_len, - config_.max_context_length, config_.max_new_tokens); + static_cast(pos_), seq_cap, config_.max_context_length, + config_.max_new_tokens); std::vector seed_tokens = {prefill_next_token}; auto wrapped_callback = [&](const std::string &piece) { diff --git a/packages/react-native-executorch/common/runner/multimodal_runner.h b/packages/react-native-executorch/common/runner/multimodal_runner.h index d24e0b40c2..c6180c54f0 100644 --- a/packages/react-native-executorch/common/runner/multimodal_runner.h +++ b/packages/react-native-executorch/common/runner/multimodal_runner.h @@ -10,7 +10,7 @@ namespace executorch::extension::llm { -enum class MultimodalType { Image }; +enum class MultimodalType { Image, Audio }; class MultimodalRunner : public BaseLLMRunner { public: diff --git a/packages/react-native-executorch/common/runner/sampler.cpp b/packages/react-native-executorch/common/runner/sampler.cpp index 26c75d4dd5..250d6a83ef 100644 --- a/packages/react-native-executorch/common/runner/sampler.cpp +++ b/packages/react-native-executorch/common/runner/sampler.cpp @@ -35,6 +35,7 @@ #include "sampler.h" #include #include +#include #include namespace executorch { @@ -46,7 +47,7 @@ template int32_t Sampler::sample_argmax(T *probabilities) { // return the index that has the highest probability int max_i = 0; T max_p = probabilities[0]; - for (int i = 1; i < vocab_size_; i++) { + for (size_t i = 1; i < vocab_size_; i++) { if (probabilities[i] > max_p) { max_i = i; max_p = probabilities[i]; @@ -60,7 +61,7 @@ int32_t Sampler::sample_mult(T *probabilities, float coin) { // sample index from probabilities (they must sum to 1!) 
// coin is a random number in [0, 1), usually from random_f32() T cdf = 0.0; - for (int i = 0; i < vocab_size_; i++) { + for (size_t i = 0; i < vocab_size_; i++) { cdf += probabilities[i]; if (coin < cdf) { return i; @@ -84,7 +85,7 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) { std::make_unique[]>(vocab_size_); const float cutoff = (1.0f - topp_) / (n - 1); - for (int i = 0; i < n; i++) { + for (size_t i = 0; i < n; i++) { if (probabilities[i] >= cutoff) { probindex[n0].index = i; probindex[n0].prob = probabilities[i]; @@ -92,61 +93,138 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) { } } - auto compare = [](const ProbIndex &a, const ProbIndex &b) { - return a.prob > b.prob; - }; - std::sort(probindex.get(), probindex.get() + n0, compare); + std::sort(probindex.get(), probindex.get() + n0, + [](const ProbIndex &a, const ProbIndex &b) { + return a.prob > b.prob; + }); // truncate the list where cumulative probability exceeds topp T cumulative_prob = 0; - int last_idx = n0 - 1; // in case of rounding errors consider all elements - for (int i = 0; i < n0; i++) { + int last_idx = n0 - 1; + for (size_t i = 0; i < n0; i++) { cumulative_prob += probindex[i].prob; - if (cumulative_prob > topp_) { + if (static_cast(cumulative_prob) > topp_) { last_idx = i; - break; // we've exceeded topp by including last_idx + break; } } // sample from the truncated list - const T &r = coin * cumulative_prob; + float r = coin * static_cast(cumulative_prob); T cdf = 0; - for (int i = 0; i <= last_idx; i++) { + for (size_t i = 0; i <= last_idx; i++) { cdf += probindex[i].prob; - if (r < cdf) { + if (r < static_cast(cdf)) { return probindex[i].index; } } - return probindex[last_idx].index; // in case of rounding errors + return probindex[last_idx].index; } -Sampler::Sampler(int32_t vocab_size, float temperature, float topp, - unsigned long long rng_seed, float min_p, - float repetition_penalty) +// Mask logits outside the top-k by rank to -inf. 
Ties at the k-th boundary +// are kept (matches HuggingFace TopKLogitsWarper). +template void Sampler::mask_topk(T *logits) { + if (topk_ <= 0 || topk_ >= vocab_size_) { + return; + } + // Partial-select the (topk_-th largest) threshold using nth_element on a + // copy of logits; O(n) average. + std::vector scratch(logits, logits + vocab_size_); + std::nth_element(scratch.begin(), scratch.begin() + (topk_ - 1), + scratch.end(), std::greater()); + const T threshold = scratch[topk_ - 1]; + constexpr T neg_inf = std::numeric_limits::lowest(); + for (size_t i = 0; i < vocab_size_; i++) { + if (logits[i] < threshold) { + logits[i] = neg_inf; + } + } +} + +// Mask logits whose softmax-prob falls outside the top-p nucleus to -inf. +// Keeps the token that crosses the threshold (HuggingFace convention). +template void Sampler::mask_topp(T *logits) { + if (topp_ <= 0.0f || topp_ >= 1.0f) { + return; + } + // Softmax into a scratch probs[] (do not mutate logits yet). + T max_val = logits[0]; + for (size_t i = 1; i < vocab_size_; i++) { + if (logits[i] > max_val) { + max_val = logits[i]; + } + } + std::unique_ptr[]> probindex = + std::make_unique[]>(vocab_size_); + T sum = 0; + for (size_t i = 0; i < vocab_size_; i++) { + T e = static_cast(std::expf(static_cast(logits[i] - max_val))); + probindex[i].prob = e; + probindex[i].index = i; + sum += e; + } + if (sum <= T(0)) { + return; + } + for (size_t i = 0; i < vocab_size_; i++) { + probindex[i].prob /= sum; + } + std::sort(probindex.get(), probindex.get() + vocab_size_, + [](const ProbIndex &a, const ProbIndex &b) { + return a.prob > b.prob; + }); + + // Find the smallest prefix whose cumulative probability >= topp_. + T cumulative = 0; + int last_idx = vocab_size_ - 1; + for (size_t i = 0; i < vocab_size_; i++) { + cumulative += probindex[i].prob; + if (static_cast(cumulative) >= topp_) { + last_idx = i; + break; + } + } + // Mark kept indices, then -inf the rest. 
+ std::vector keep(vocab_size_, false); + for (size_t i = 0; i <= last_idx; i++) { + keep[probindex[i].index] = true; + } + constexpr T neg_inf = std::numeric_limits::lowest(); + for (size_t i = 0; i < vocab_size_; i++) { + if (!keep[i]) { + logits[i] = neg_inf; + } + } +} + +Sampler::Sampler(int32_t vocab_size, GenerationConfig config, + unsigned long long rng_seed) : vocab_size_(vocab_size), - inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f), - topp_(topp), min_p_(min_p), repetition_penalty_(repetition_penalty), + inv_temperature_( + (config.temperature != 0.0f) ? (1.0f / config.temperature) : 0.0f), + topp_(config.topp), min_p_(config.min_p), + repetition_penalty_(config.repetition_penalty), topk_(config.topk), rng_state_(rng_seed) {} -Sampler::Sampler(int vocab_size, float temperature, float topp) - : Sampler(vocab_size, temperature, topp, std::time(nullptr), 0.0f, 1.0f) {} +Sampler::Sampler(int32_t vocab_size, GenerationConfig config) + : Sampler(vocab_size, config, std::time(nullptr)) {} template static void softmax(T *x, int size) { // find max value (for numerical stability) T max_val = x[0]; - for (int i = 1; i < size; i++) { + for (size_t i = 1; i < size; i++) { if (x[i] > max_val) { max_val = x[i]; } } // exp and sum T sum = 0; - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { x[i] = expf(x[i] - max_val); sum += x[i]; } // normalize - for (int i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { x[i] /= sum; } } @@ -175,20 +253,18 @@ int32_t Sampler::sample(T *logits, const std::vector &recent_tokens) { apply_repetition_penalty(logits, vocab_size_, recent_tokens); // 2. apply the temperature to the logits apply_temperature(logits, vocab_size_); - // 3. apply softmax to the logits to get the probabilities for next token + // 3. mask out logits outside top-k by rank (pre-softmax, becomes 0 mass) + mask_topk(logits); + // 4. mask out logits outside top-p by rank (pre-softmax) + mask_topp(logits); + // 5. 
apply softmax to the logits to get the probabilities for next token softmax(logits, vocab_size_); - // 4. apply min_p truncation + // 6. apply min_p truncation apply_min_p(logits, vocab_size_); // flip a (float) coin (this is our source of entropy for sampling) float coin = random_f32(&rng_state_); - // 5. we sample from this distribution to get the next token - if (topp_ <= 0 || topp_ >= 1) { - // simply sample from the predicted probability distribution - next = sample_mult(logits, coin); - } else { - // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp(logits, coin); - } + // 7. we sample from this distribution to get the next token + next = sample_mult(logits, coin); } return next; } diff --git a/packages/react-native-executorch/common/runner/sampler.h b/packages/react-native-executorch/common/runner/sampler.h index 16811297ef..6af1a5a487 100644 --- a/packages/react-native-executorch/common/runner/sampler.h +++ b/packages/react-native-executorch/common/runner/sampler.h @@ -8,6 +8,7 @@ #pragma once +#include "runner/irunner.h" #include #include #include @@ -28,6 +29,7 @@ namespace executorch { namespace extension { namespace llm { // A simple llama2 sampler. +struct GenerationConfig; inline constexpr auto kTopp = 0.9f; @@ -38,11 +40,13 @@ template struct ProbIndex { class Sampler { public: - Sampler(int32_t vocab_size, float temperature, float topp, - unsigned long long rng_seed, float min_p = 0.0f, - float repetition_penalty = 1.0f); - - Sampler(int32_t vocab_size, float temperature, float topp); + // topk <= 0 disables top-k filtering. topp <= 0 || topp >= 1 disables top-p. + // Pipeline when temperature != 0: temperature -> top-k mask -> top-p mask + // -> softmax -> multinomial. Note: topk == 1 with temperature != 0 collapses + // to greedy; pass topk = 0 to keep full-vocab temperature sampling. 
+ Sampler(int32_t vocab_size, GenerationConfig config, + unsigned long long rng_seed); + Sampler(int32_t vocab_size, GenerationConfig config); template int32_t sample(T *logits); @@ -53,6 +57,9 @@ class Sampler { template int32_t sample_topp(T *probabilities, float coin); template int32_t sample_mult(T *probabilities, float coin); template int32_t sample_argmax(T *probabilities); + // In-place logit warpers: set excluded indices to -inf. + template void mask_topk(T *logits); + template void mask_topp(T *logits); template inline void apply_temperature(T *logits, int32_t vocab_size) { @@ -110,6 +117,7 @@ class Sampler { float topp_; float min_p_; float repetition_penalty_; + int32_t topk_; unsigned long long rng_state_; }; diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp index e67d3e41fb..77770e3418 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.cpp @@ -31,7 +31,6 @@ TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager, // outer loop (call site) is responsible for managing state. 
::executorch::runtime::Result TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) { - // ET_LOG(Info, "Input token %" PRIu64, input_token); auto method_meta_result = module_->method_meta("forward"); if (!method_meta_result.ok()) { return method_meta_result.error(); @@ -102,9 +101,7 @@ int32_t TextDecoderRunner::logits_to_token( auto num_tokens = logits_tensor.size(1); logits += (num_tokens - 1) * vocab_size; } - Sampler sampler(vocab_size, config_.temperature, config_.topp, - static_cast(std::time(nullptr)), - config_.min_p, config_.repetition_penalty); + Sampler sampler(vocab_size, config_); result = sampler.sample(logits, recent_tokens); }); return result; diff --git a/packages/react-native-executorch/common/runner/text_decoder_runner.h b/packages/react-native-executorch/common/runner/text_decoder_runner.h index bffc254bd6..d3aa229cd0 100644 --- a/packages/react-native-executorch/common/runner/text_decoder_runner.h +++ b/packages/react-native-executorch/common/runner/text_decoder_runner.h @@ -10,6 +10,7 @@ #pragma once +#include "constants.h" #include "io_manager.h" #include "sampler.h" @@ -40,8 +41,8 @@ class TextDecoderRunner { step(TensorPtr &input, int64_t start_pos); /** - * Load the Module for text decode purpose. - * @return The error code. + * Load the Module for text decode purpose. Loads the dynamic-shape `forward` + * method used for both prefill and decode. 
*/ virtual ::executorch::runtime::Error load() { return module_->load_method("forward"); diff --git a/packages/react-native-executorch/common/runner/text_prefiller.cpp b/packages/react-native-executorch/common/runner/text_prefiller.cpp index dc961158b7..8325efae8f 100644 --- a/packages/react-native-executorch/common/runner/text_prefiller.cpp +++ b/packages/react-native-executorch/common/runner/text_prefiller.cpp @@ -33,15 +33,16 @@ TextPrefiller::prefill(std::vector &prompt_tokens, // Check if we need to chunk the prompt tokens int32_t num_prompt_tokens = prompt_tokens.size(); + const int32_t chunk_size = static_cast(max_seq_len_); - // If prompt tokens exceed max_seq_len_, we need to chunk them - if (num_prompt_tokens > max_seq_len_) { + // If prompt tokens exceed chunk_size, we need to chunk them + if (num_prompt_tokens > chunk_size) { uint64_t cur_token = 0; int num_tokens_to_process = 0; while (num_tokens_to_process < num_prompt_tokens) { - auto num_tokens_to_prefill_with = std::min( - num_prompt_tokens - num_tokens_to_process, max_seq_len_); + auto num_tokens_to_prefill_with = + std::min(num_prompt_tokens - num_tokens_to_process, chunk_size); std::vector prompt_tokens_to_process( num_tokens_to_prefill_with); @@ -75,7 +76,6 @@ TextPrefiller::prefill_chunk(std::vector &prompt_tokens, // store the token uint64_t cur_token; if (enable_parallel_prefill_ || !use_kv_cache_) { - // initialize tensor wrappers auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens}, executorch::aten::ScalarType::Long); diff --git a/packages/react-native-executorch/common/runner/text_runner.cpp b/packages/react-native-executorch/common/runner/text_runner.cpp index 5a75e00b4a..348f0bfb82 100644 --- a/packages/react-native-executorch/common/runner/text_runner.cpp +++ b/packages/react-native-executorch/common/runner/text_runner.cpp @@ -65,6 +65,10 @@ Error TextRunner::generate_internal( stats_.inference_start_ms = time_in_ms(); + // Multi-turn: JS re-renders the full chat 
history each call, so reset KV + // position to 0 and re-prefill from scratch. + pos_ = 0; + int64_t context_len_left = static_cast(config_.max_context_length) - pos_; @@ -79,16 +83,23 @@ Error TextRunner::generate_internal( std::vector prompt_tokens = encodeResult.get(); int num_prompt_tokens = prompt_tokens.size(); + // For dynamic-shape PTEs (Gemma4 iter*), get_max_seq_len is the per-call + // decoder chunk size (e.g. 128) and the true generation budget lives in + // get_max_context_len. Static-shape PTEs set both equal, so this collapses + // to the old behavior. Mirrors multimodal_prefiller.cpp:96. + const int32_t seq_cap = config_.enable_dynamic_shape + ? config_.max_context_length + : config_.max_seq_len; + ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument, "Expected at least 1 prompt token"); - ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < config_.max_seq_len, - InvalidArgument, - "num_prompt_tokens %d >= max_seq_len %" PRId32, - num_prompt_tokens, config_.max_seq_len); + ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < seq_cap, InvalidArgument, + "num_prompt_tokens %d >= seq cap %" PRId32, + num_prompt_tokens, seq_cap); int32_t max_new_tokens = resolve_max_new_tokens( - num_prompt_tokens, config_.max_seq_len, - static_cast(context_len_left), config_.max_new_tokens); + num_prompt_tokens, seq_cap, static_cast(context_len_left), + config_.max_new_tokens); ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument, "Max new tokens %d is <= 0", max_new_tokens); diff --git a/packages/react-native-executorch/common/runner/text_token_generator.h b/packages/react-native-executorch/common/runner/text_token_generator.h index 7ecf6177a9..13f53bd2e4 100644 --- a/packages/react-native-executorch/common/runner/text_token_generator.h +++ b/packages/react-native-executorch/common/runner/text_token_generator.h @@ -100,8 +100,8 @@ class TextTokenGenerator { prev_token = cur_token; stats_->on_sampling_begin(); - cur_token = - 
text_decoder_runner_->logits_to_token(logits_tensor, generated_tokens); + cur_token = text_decoder_runner_->logits_to_token(logits_tensor, + generated_tokens); stats_->on_sampling_end(); pos++; @@ -142,8 +142,7 @@ class TextTokenGenerator { const auto eos_reached = eos_ids_->contains(cur_token); if (!cache_decoded.ends_with("๏ฟฝ") && - (countIntervalElapsed || timeIntervalElapsed || should_stop_ || - eos_reached)) { + (countIntervalElapsed || timeIntervalElapsed || should_stop_ || eos_reached)) { token_callback(cache_decoded); token_cache.clear(); timestamp_ = std::chrono::high_resolution_clock::now(); @@ -152,7 +151,6 @@ class TextTokenGenerator { if (should_stop_) { break; } - // data-dependent terminating condition: we have n_eos_ number of EOS if (eos_ids_->find(cur_token) != eos_ids_->end()) { printf("\n"); diff --git a/packages/react-native-executorch/common/runner/util.h b/packages/react-native-executorch/common/runner/util.h index 640b96319f..b1e707034b 100644 --- a/packages/react-native-executorch/common/runner/util.h +++ b/packages/react-native-executorch/common/runner/util.h @@ -8,7 +8,6 @@ #pragma once #include "constants.h" -#include "text_prefiller.h" #include #include #include diff --git a/packages/react-native-executorch/src/constants/llmDefaults.ts b/packages/react-native-executorch/src/constants/llmDefaults.ts index a27a2f7a4f..77a60fe311 100644 --- a/packages/react-native-executorch/src/constants/llmDefaults.ts +++ b/packages/react-native-executorch/src/constants/llmDefaults.ts @@ -6,7 +6,7 @@ import { SlidingWindowContextStrategy } from '../utils/llms/context_strategy'; * @category Utilities - LLM */ export const DEFAULT_SYSTEM_PROMPT = - "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. 
Don't return too much text."; + "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text. If provided with audio samples treat it with at most importance"; /** * Generates a default structured output prompt based on the provided JSON schema. diff --git a/packages/react-native-executorch/src/constants/modelRegistry.ts b/packages/react-native-executorch/src/constants/modelRegistry.ts index 9c9da9c420..bd04a37883 100644 --- a/packages/react-native-executorch/src/constants/modelRegistry.ts +++ b/packages/react-native-executorch/src/constants/modelRegistry.ts @@ -496,10 +496,12 @@ export const models = { M.LFM2_5_1_2B_INSTRUCT_QUANTIZED ), bielik_v3_0_1_5b: pair(M.BIELIK_V3_0_1_5B, M.BIELIK_V3_0_1_5B_QUANTIZED), + gemma4_e2b: base(M.GEMMA4_E2B), // Multimodal LLMs — same hook/module as plain LLMs, listed here so users // pick a model by capability ("LLM") rather than by modality. 
lfm2_5_vl_1_6b: base(M.LFM2_5_VL_1_6B_QUANTIZED), lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED), + gemma4_e2b_multimodal: base(M.GEMMA4_E2B_MM), }, classification: { efficientnet_v2_s: variant(EFFICIENTNET_V2_S_VARIANTS), diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index f02abd8e32..7ed7f7d11a 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -124,6 +124,41 @@ export const QWEN3_0_6B_QUANTIZED = { generationConfig: QWEN3_GENERATION_CONFIG, } as const; +// GEMMA 4 +const GEMMA4_E2B_XNNPACK = `${URL_PREFIX}-gemma-4/${PREVIOUS_VERSION_TAG}/e2b/xnnpack/gemma_4_e2b_xnnpack_8da4w.pte`; +const GEMMA4_E2B_VULKAN = `${URL_PREFIX}-gemma-4/${PREVIOUS_VERSION_TAG}/e2b/vulkan/gemma_4_e2b_vulkan_8da4w.pte`; +const GEMMA4_E2B_XNNPACK_MM = `${URL_PREFIX}-gemma-4-multimodal/${PREVIOUS_VERSION_TAG}/e2b/xnnpack/gemma_4_e2b_xnnpack_8da4w.pte`; +const GEMMA4_E2B_VULKAN_MM = `${URL_PREFIX}-gemma-4-multimodal/${PREVIOUS_VERSION_TAG}/e2b/vulkan/gemma_4_e2b_vulkan_8da4w.pte`; +const GEMMA4_TOKENIZER = `${URL_PREFIX}-gemma-4/${PREVIOUS_VERSION_TAG}/e2b/tokenizer.json`; +const GEMMA4_TOKENIZER_CONFIG = `${URL_PREFIX}-gemma-4/${PREVIOUS_VERSION_TAG}/e2b/tokenizer_config.json`; + +/** + * @category Models - LLM + */ +export const GEMMA4_E2B = { + modelName: 'gemma4-e2b', + modelSource: + Platform.OS === `android` ? GEMMA4_E2B_VULKAN : GEMMA4_E2B_XNNPACK, + tokenizerSource: GEMMA4_TOKENIZER, + tokenizerConfigSource: GEMMA4_TOKENIZER_CONFIG, +} as const; + +/** + * @category Models - VLM + */ +export const GEMMA4_E2B_MM = { + modelName: 'gemma4-e2b-multimodal', + modelSource: + Platform.OS === `android` ? 
GEMMA4_E2B_VULKAN_MM : GEMMA4_E2B_XNNPACK_MM, + tokenizerSource: GEMMA4_TOKENIZER, + tokenizerConfigSource: GEMMA4_TOKENIZER_CONFIG, + capabilities: ['vision', 'audio'], + audioConfig: { + samplesPerBlock: 7680, + tokensPerBlock: 12, + }, +} as const; + /** * @category Models - LLM */ diff --git a/packages/react-native-executorch/src/controllers/LLMController.ts b/packages/react-native-executorch/src/controllers/LLMController.ts index bceca47a56..4385ca909b 100644 --- a/packages/react-native-executorch/src/controllers/LLMController.ts +++ b/packages/react-native-executorch/src/controllers/LLMController.ts @@ -1,11 +1,11 @@ -import { ResourceSource } from '../types/common'; import { ResourceFetcher } from '../utils/ResourceFetcher'; import { Template } from '@huggingface/jinja'; import { DEFAULT_CHAT_CONFIG } from '../constants/llmDefaults'; import { + AudioConfig, ChatConfig, GenerationConfig, - LLMCapability, + LLMModel, LLMTool, Message, SPECIAL_TOKENS, @@ -30,6 +30,7 @@ export class LLMController { private messageHistoryCallback: (messageHistory: Message[]) => void; private isReadyCallback: (isReady: boolean) => void; private isGeneratingCallback: (isGenerating: boolean) => void; + private audioConfig: AudioConfig | undefined; constructor({ tokenCallback, @@ -72,18 +73,10 @@ export class LLMController { } public async load({ - modelSource, - tokenizerSource, - tokenizerConfigSource, - capabilities, - defaultGenerationConfig, + model, onDownloadProgressCallback, }: { - modelSource: ResourceSource; - tokenizerSource: ResourceSource; - tokenizerConfigSource: ResourceSource; - capabilities?: readonly LLMCapability[]; - defaultGenerationConfig?: GenerationConfig; + model: LLMModel; onDownloadProgressCallback?: (downloadProgress: number) => void; }) { // reset inner state when loading new model @@ -94,13 +87,13 @@ export class LLMController { try { const tokenizersPromise = ResourceFetcher.fetch( undefined, - tokenizerSource, - tokenizerConfigSource + 
model.tokenizerSource, + model.tokenizerConfigSource ); const modelPromise = ResourceFetcher.fetch( onDownloadProgressCallback, - modelSource + model.modelSource ); const [tokenizersResults, modelResult] = await Promise.all([ @@ -124,16 +117,18 @@ export class LLMController { this.nativeModule.unload(); } + this.audioConfig = model.audioConfig; + this.nativeModule = await global.loadLLM( modelPath, tokenizerPath, - capabilities ?? [] + model.capabilities ?? [] ); - if (defaultGenerationConfig) { + if (model.generationConfig) { // Apply model-specific recommended sampling defaults before flipping // isReady so callers that react to it see the right config on first // send. User-provided `configure()` calls still override these. - this.applyGenerationConfig(defaultGenerationConfig); + this.applyGenerationConfig(model.generationConfig); } this.isReadyCallback(true); this.onToken = (data: string) => { @@ -236,6 +231,17 @@ export class LLMController { return token; } + private getAudioToken(): string { + const token = this.tokenizerConfig.audio_token; + if (!token) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidConfig, + "Tokenizer config is missing 'audio_token'. Audio-capable models require tokenizerConfigSource with an 'audio_token' field." 
+ ); + } + return token; + } + private filterSpecialTokens(text: string): string { let filtered = text; if ( @@ -244,6 +250,12 @@ export class LLMController { ) { filtered = filtered.replaceAll(this.tokenizerConfig.eos_token, ''); } + if ( + SPECIAL_TOKENS.EOT_TOKEN in this.tokenizerConfig && + this.tokenizerConfig.eot_token + ) { + filtered = filtered.replaceAll(this.tokenizerConfig.eot_token, ''); + } if ( SPECIAL_TOKENS.PAD_TOKEN in this.tokenizerConfig && this.tokenizerConfig.pad_token @@ -269,25 +281,37 @@ export class LLMController { this.isGeneratingCallback(false); } - public async forward(input: string, imagePaths?: string[]): Promise { + public async forward( + input: string, + imagePaths?: string[], + audioWaveforms?: Float32Array[] + ): Promise { if (!this._isReady) { throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded); } if (this._isGenerating) { throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating); } + const hasImages = !!imagePaths && imagePaths.length > 0; + const hasAudio = !!audioWaveforms && audioWaveforms.length > 0; try { this.isGeneratingCallback(true); this.nativeModule.reset(); - const response = - imagePaths && imagePaths.length > 0 - ? await this.nativeModule.generateMultimodal( - input, - imagePaths.map(normalizeImagePath), - this.getImageToken(), - this.onToken - ) - : await this.nativeModule.generate(input, this.onToken); + let response: string; + if (hasImages || hasAudio) { + response = await this.nativeModule.generateMultimodal( + input, + this.onToken, + { + imagePaths: hasImages ? imagePaths!.map(normalizeImagePath) : null, + imageToken: hasImages ? this.getImageToken() : null, + audioWaveforms: hasAudio ? audioWaveforms! : null, + audioToken: hasAudio ? 
this.getAudioToken() : null, + } + ); + } else { + response = await this.nativeModule.generate(input, this.onToken); + } return this.filterSpecialTokens(response); } catch (e) { throw parseUnknownError(e); @@ -355,7 +379,9 @@ export class LLMController { const imagePaths = messages .filter((m) => m.mediaPath) .map((m) => m.mediaPath!); - + const audioWaveforms = messages + .filter((m) => m.audioWaveform) + .map((m) => m.audioWaveform!); const renderedChat: string = this.applyChatTemplate( messages, this.tokenizerConfig, @@ -365,19 +391,22 @@ export class LLMController { return await this.forward( renderedChat, - imagePaths.length > 0 ? imagePaths : undefined + imagePaths.length > 0 ? imagePaths : undefined, + audioWaveforms.length > 0 ? audioWaveforms : undefined ); } public async sendMessage( message: string, - media?: { imagePath?: string } + media?: { imagePath?: string; audioBuffer?: Float32Array } ): Promise { const mediaPath = media?.imagePath; + const audioBuffer = media?.audioBuffer; const newMessage: Message = { content: message, role: 'user', ...(mediaPath ? { mediaPath } : {}), + ...(audioBuffer ? { audioWaveform: audioBuffer } : {}), }; const updatedHistory = [...this._messageHistory, newMessage]; this.messageHistoryCallback(updatedHistory); @@ -392,7 +421,22 @@ export class LLMController { ); const textTokens = this.nativeModule.countTextTokens(rendered); const imageCount = messages.filter((m) => m.mediaPath).length; - return textTokens + imageCount * (visualTokenCount - 1); + // Audio soft-token expansion: audio_encoder pads samples to + // multiples of this.audioConfig.samplesPerBlock (7680 @ 16 kHz) and emits + // this.audioConfig.tokensPerBlock (~12) soft tokens per padded block. The + // rendered template only contributes 1 token for the audio placeholder, + // so add (expansion - 1) per audio message to match prefill consumption. 
+ const audioTokenExpansion = messages.reduce((acc, m) => { + if (!m.audioWaveform) return acc; + const kBlocks = Math.max( + 1, + Math.ceil(m.audioWaveform.length / this.audioConfig!.samplesPerBlock) + ); + return acc + (this.audioConfig!.tokensPerBlock * kBlocks - 1); + }, 0); + return ( + textTokens + imageCount * (visualTokenCount - 1) + audioTokenExpansion + ); }; const maxContextLength = this.nativeModule.getMaxContextLength(); const messageHistoryWithPrompt = @@ -497,12 +541,17 @@ function normalizeImagePath(path: string): string { * @returns Messages with image-bearing turns rewritten to structured content. */ function messagesForChatTemplate(messages: Message[]): any[] { - return messages.map((m) => - m.mediaPath && typeof m.content === 'string' - ? { - ...m, - content: [{ type: 'image' }, { type: 'text', text: m.content }], - } - : m - ); + return messages.map((m) => { + if (typeof m.content !== 'string') return m; + const hasImage = !!m.mediaPath; + const hasAudio = !!m.audioWaveform; + if (!hasImage && !hasAudio) return m; + const parts: any[] = []; + if (hasImage) parts.push({ type: 'image' }); + if (hasAudio) parts.push({ type: 'audio' }); + parts.push({ type: 'text', text: m.content }); + // Drop the Float32Array on the clone only — passing it into the Jinja + // template engine slows render past 3s. 
Don't mutate m; + return { ...m, content: parts, audioWaveform: undefined }; + }); } diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index 027e237997..a8daef8d91 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -58,11 +58,7 @@ export function useLLM({ (async () => { try { await controllerInstance.load({ - modelSource: model.modelSource, - tokenizerSource: model.tokenizerSource, - tokenizerConfigSource: model.tokenizerConfigSource!, - capabilities: model.capabilities, - defaultGenerationConfig: model.generationConfig, + model: model, onDownloadProgressCallback: setDownloadProgress, }); } catch (e) { @@ -106,7 +102,10 @@ export function useLLM({ ); const sendMessage = useCallback( - (message: string, media?: { imagePath?: string }) => { + ( + message: string, + media?: { imagePath?: string; audioBuffer?: Float32Array } + ) => { setResponse(''); return controllerInstance.sendMessage(message, media); }, diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts index bdb5ada699..be6ecb229b 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts @@ -3,6 +3,7 @@ import { Logger } from '../../common/Logger'; import { parseUnknownError } from '../../errors/errorUtils'; import { ResourceSource } from '../../types/common'; import { + AudioConfig, LLMCapability, LLMConfig, LLMModelName, @@ -51,6 +52,7 @@ export class LLMModule { tokenizerSource: ResourceSource; tokenizerConfigSource: ResourceSource; capabilities?: readonly LLMCapability[]; + audioConfig?: 
AudioConfig; }, onDownloadProgress: (progress: number) => void = () => {}, tokenCallback?: (token: string) => void, @@ -59,10 +61,14 @@ export class LLMModule { const instance = new LLMModule({ tokenCallback, messageHistoryCallback }); try { await instance.controller.load({ - modelSource: namedSources.modelSource, - tokenizerSource: namedSources.tokenizerSource, - tokenizerConfigSource: namedSources.tokenizerConfigSource, - capabilities: namedSources.capabilities, + model: { + modelName: namedSources.modelName, + modelSource: namedSources.modelSource, + tokenizerSource: namedSources.tokenizerSource, + tokenizerConfigSource: namedSources.tokenizerConfigSource, + capabilities: namedSources.capabilities, + audioConfig: namedSources.audioConfig, + }, onDownloadProgressCallback: onDownloadProgress, }); return instance; @@ -140,10 +146,15 @@ export class LLMModule { * If you want a simple chat with model then consider using `sendMessage` * @param input - Raw input string containing the prompt and conversation history. * @param imagePaths - Optional array of local image paths for multimodal inference. Each entry may be either `file:///absolute/path` or `/absolute/path` — the controller normalizes the path before passing it to native code. + * @param audioWaveforms - Optional array of 16kHz waveforms of audio recordings for multimodal inference. * @returns The generated response as a string. */ - async forward(input: string, imagePaths?: string[]): Promise { - return await this.controller.forward(input, imagePaths); + async forward( + input: string, + imagePaths?: string[], + audioWaveforms?: Float32Array[] + ): Promise { + return await this.controller.forward(input, imagePaths, audioWaveforms); } /** @@ -162,12 +173,12 @@ export class LLMModule { * After model responds it will call `messageHistoryCallback()` containing both user message and model response. * It also returns them. * @param message - The message string to send. 
- * @param media - Optional media object containing a local image path for multimodal models. + * @param media - Optional media object containing a local image path or 16kHz waveform of an audio recording for multimodal models. * @returns - Updated message history including the new user message and model response. */ async sendMessage( message: string, - media?: { imagePath?: string } + media?: { imagePath?: string; audioBuffer?: Float32Array } ): Promise { await this.controller.sendMessage(message, media); return this.controller.messageHistory; diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index 6254775c15..1d8da7bd70 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -5,20 +5,23 @@ import { ResourceSource } from './common'; * Capabilities a multimodal LLM can have. * @category Types */ -export type LLMCapability = 'vision'; +export type LLMCapability = 'vision' | 'audio'; /** * Derives the media argument shape for `sendMessage` from a capabilities tuple. * @category Types */ export type MediaArg = - 'vision' extends C[number] ? { imagePath?: string } : object; + ('vision' extends C[number] ? { imagePath?: string } : object) & + ('audio' extends C[number] ? { audioBuffer?: Float32Array } : object); /** * Union of all built-in LLM model names. * @category Types */ export type LLMModelName = + | 'gemma4-e2b' + | 'gemma4-e2b-multimodal' | 'llama-3.2-3b' | 'llama-3.2-3b-qlora' | 'llama-3.2-3b-spinquant' @@ -62,43 +65,63 @@ export type LLMModelName = | 'bielik-v3.0-1.5b' | 'bielik-v3.0-1.5b-quantized'; +/** + * Audio soft-token expansion constants for audio_encoder. + * @category Types + */ +export interface AudioConfig { + samplesPerBlock: number; + tokensPerBlock: number; +} + +/** + * Properties defining LLMModel. + * @category Types + */ +export interface LLMModel { + /** + * The built-in model name (e.g. `'llama-3.2-3b'`). 
Used for telemetry and hook reload triggers. + * Pass one of the pre-built LLM constants (e.g. `LLAMA3_2_3B`) to populate all required fields. + */ + modelName: LLMModelName; + /** + * `ResourceSource` that specifies the location of the model binary. + */ + modelSource: ResourceSource; + /** + * `ResourceSource` pointing to the JSON file which contains the tokenizer. + */ + tokenizerSource: ResourceSource; + /** + * `ResourceSource` pointing to the JSON file which contains the tokenizer config. + */ + tokenizerConfigSource: ResourceSource; + /** + * Optional list of modality capabilities the model supports. + * Determines the type of the `media` argument in `sendMessage`. + * Example: `['vision']` enables `sendMessage(text, { imagePath })`. + */ + capabilities?: readonly LLMCapability[]; + /** + * Recommended default generation settings, typically copied from the + * upstream `generation_config.json` or the model card. Applied automatically + * after the native module loads and before any user `configure()` call, + * so callers only need to override the values they want to change. + */ + generationConfig?: GenerationConfig; + /** + * Defines config for audio input modality for multimodal LLMs. + * `capabilities` must include 'audio'. + */ + audioConfig?: AudioConfig; +} + /** * Properties for initializing and configuring a Large Language Model (LLM) instance. * @category Types */ export interface LLMProps { - model: { - /** - * The built-in model name (e.g. `'llama-3.2-3b'`). Used for telemetry and hook reload triggers. - * Pass one of the pre-built LLM constants (e.g. `LLAMA3_2_3B`) to populate all required fields. - */ - modelName: LLMModelName; - /** - * `ResourceSource` that specifies the location of the model binary. - */ - modelSource: ResourceSource; - /** - * `ResourceSource` pointing to the JSON file which contains the tokenizer. 
- */ - tokenizerSource: ResourceSource; - /** - * `ResourceSource` pointing to the JSON file which contains the tokenizer config. - */ - tokenizerConfigSource: ResourceSource; - /** - * Optional list of modality capabilities the model supports. - * Determines the type of the `media` argument in `sendMessage`. - * Example: `['vision']` enables `sendMessage(text, { imagePath })`. - */ - capabilities?: readonly LLMCapability[]; - /** - * Recommended default generation settings, typically copied from the - * upstream `generation_config.json` or the model card. Applied automatically - * after the native module loads and before any user `configure()` call, - * so callers only need to override the values they want to change. - */ - generationConfig?: GenerationConfig; - }; + model: LLMModel; /** * Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ @@ -289,6 +312,12 @@ export interface Message { * controller normalizes the path before passing it to native code. */ mediaPath?: string; + /** + * Optional fp32 mono 16 kHz PCM buffer. Only valid on `user` messages for + * models with the `'audio'` capability. The controller forwards it to the + * native `generateMultimodal` path. 
+ */ + audioWaveform?: Float32Array; } /** @@ -386,6 +415,7 @@ export interface ContextStrategy { export const SPECIAL_TOKENS = { BOS_TOKEN: 'bos_token', EOS_TOKEN: 'eos_token', + EOT_TOKEN: 'eot_token', UNK_TOKEN: 'unk_token', SEP_TOKEN: 'sep_token', PAD_TOKEN: 'pad_token', diff --git a/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so b/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so index 8c65aa5d85..1e882b92fa 100644 Binary files a/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so and b/packages/react-native-executorch/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so differ diff --git a/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so b/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so index a56a5d20ac..45efcf585d 100644 Binary files a/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so and b/packages/react-native-executorch/third-party/android/libs/executorch/x86_64/libexecutorch.so differ