diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 3837e96955..a23957a7f2 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -215,3 +215,11 @@ MATEUSZ BLAZEFACE Blazeface blazeface +webfetch +prebuild +embedders +upsamples +artefacts +categorisation +chipmunked +autoregressive diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 96fd27ad65..9f0fb0d4b7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,3 +37,18 @@ jobs: - name: Build all packages run: yarn workspaces foreach --all --topological-dev run prepare + + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Setup + uses: ./.github/actions/setup + + - name: Typecheck test files + run: yarn workspace react-native-executorch typecheck:tests + + - name: Run API contract tests + run: yarn workspace react-native-executorch test --ci diff --git a/apps/computer-vision/app/text_to_image/index.tsx b/apps/computer-vision/app/text_to_image/index.tsx index bb54c661be..09fe61428c 100644 --- a/apps/computer-vision/app/text_to_image/index.tsx +++ b/apps/computer-vision/app/text_to_image/index.tsx @@ -92,7 +92,7 @@ export default function TextToImageScreen() { try { const start = Date.now(); - const output = await model.generate(input, imageSize, steps); + const output = await model.forward(input, imageSize, steps); if (output.length) { setImage(output); diff --git a/apps/speech/screens/Quiz.tsx b/apps/speech/screens/Quiz.tsx index ffd574d96d..e33309b4f1 100644 --- a/apps/speech/screens/Quiz.tsx +++ b/apps/speech/screens/Quiz.tsx @@ -56,7 +56,9 @@ const createAudioBufferFromVector = ( export const Quiz = ({ onBack }: { onBack: () => void }) => { // --- Hooks & State --- - const model = useTextToSpeech(models.text_to_speech.kokoro.en_us.santa()); + const model = useTextToSpeech({ + model: models.text_to_speech.kokoro.en_us.santa(), + }); const [shuffledQuestions] = useState(() => shuffleArray(QUESTIONS)); const [currentIndex, setCurrentIndex] = useState(0); diff --git a/apps/speech/screens/TextToSpeechLLMScreen.tsx b/apps/speech/screens/TextToSpeechLLMScreen.tsx index b3a3dad913..6ae5bfb383 100644 --- a/apps/speech/screens/TextToSpeechLLMScreen.tsx +++ b/apps/speech/screens/TextToSpeechLLMScreen.tsx @@ -48,7 +48,9 @@ export const TextToSpeechLLMScreen = ({ onBack }: TextToSpeechLLMProps) => { const [displayText, setDisplayText] = useState(''); const [isTtsStreaming, setIsTtsStreaming] = useState(false); const llm = useLLM({ model: models.llm.llama3_2_1b() }); - const tts = useTextToSpeech(models.text_to_speech.kokoro.en_us.heart()); + const tts = useTextToSpeech({ + model: models.text_to_speech.kokoro.en_us.heart(), + }); const processedLengthRef = useRef(0); const audioContextRef = useRef(null); diff --git a/apps/speech/screens/TextToSpeechScreen.tsx b/apps/speech/screens/TextToSpeechScreen.tsx index 919076dc35..23cf1c8936 100644 --- a/apps/speech/screens/TextToSpeechScreen.tsx +++ b/apps/speech/screens/TextToSpeechScreen.tsx @@ -81,7 +81,7 @@ export const TextToSpeechScreen = ({ onBack }: { onBack: () => void }) => { const [selectedSpeaker, setSelectedSpeaker] = useState(tts.en_us.heart()); - const model = useTextToSpeech(selectedSpeaker); + const model = useTextToSpeech({ model: selectedSpeaker }); const [inputText, setInputText] = useState(''); const [isPlaying, setIsPlaying] = useState(false); diff --git a/docs/docs/01-fundamentals/02-loading-models.md b/docs/docs/01-fundamentals/02-loading-models.md index 4a2d490415..7dcdc95d7c 100644 --- a/docs/docs/01-fundamentals/02-loading-models.md +++ b/docs/docs/01-fundamentals/02-loading-models.md @@ -90,7 +90,7 @@ initExecutorch({ ### Load from React Native assets folder (for files < 512MB) ```typescript -useExecutorchModule({ +useExecutorch({ modelSource: require('../assets/lfm2_5.pte'), }); ``` @@ -100,7 +100,7 @@ useExecutorchModule({ For files larger than 512MB or when you want to keep size of the app smaller, you can load the model from a remote URL (e.g. HuggingFace). ```typescript -useExecutorchModule({ +useExecutorch({ modelSource: 'https://.../lfm2_5.pte', }); ``` @@ -110,7 +110,7 @@ useExecutorchModule({ If you prefer to delegate the process of obtaining and loading model and tokenizer files to the user, you can use the following method: ```typescript -useExecutorchModule({ +useExecutorch({ modelSource: 'file:///var/mobile/.../lfm2_5.pte', }); ``` diff --git a/docs/docs/01-fundamentals/03-frequently-asked-questions.md b/docs/docs/01-fundamentals/03-frequently-asked-questions.md index 69e3792d41..25afb7706b 100644 --- a/docs/docs/01-fundamentals/03-frequently-asked-questions.md +++ b/docs/docs/01-fundamentals/03-frequently-asked-questions.md @@ -10,7 +10,7 @@ Each hook documentation subpage (useClassification, useLLM, etc.) contains a sup ### How can I run my own AI model? -To run your own model, you need to directly access the underlying [ExecuTorch Module API](https://pytorch.org/executorch/stable/extension-module.html). We provide [React hook](../03-hooks/03-executorch-bindings/useExecutorchModule.md) along with a [TypeScript alternative](../04-typescript-api/03-executorch-bindings/ExecutorchModule.md), which serve as a way to use the aforementioned API without the need of diving into native code. In order to get a model in a format runnable by the runtime, you'll need to get your hands dirty with some ExecuTorch knowledge. For more guides on exporting models, please refer to the [ExecuTorch tutorials](https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html). Once you obtain your model in a `.pte` format, you can run it with `useExecuTorchModule` and `ExecuTorchModule`. +To run your own model, you need to directly access the underlying [ExecuTorch Module API](https://pytorch.org/executorch/stable/extension-module.html). We provide [React hook](../03-hooks/03-executorch-bindings/useExecutorch.md) along with a [TypeScript alternative](../04-typescript-api/03-executorch-bindings/ExecutorchModule.md), which serve as a way to use the aforementioned API without the need of diving into native code. In order to get a model in a format runnable by the runtime, you'll need to get your hands dirty with some ExecuTorch knowledge. For more guides on exporting models, please refer to the [ExecuTorch tutorials](https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html). Once you obtain your model in a `.pte` format, you can run it with `useExecutorch` and `ExecutorchModule`. ### How React Native ExecuTorch works under the hood? diff --git a/docs/docs/03-hooks/01-natural-language-processing/useTextToSpeech.md b/docs/docs/03-hooks/01-natural-language-processing/useTextToSpeech.md index fe0dd3b2ee..c47edbfe27 100644 --- a/docs/docs/03-hooks/01-natural-language-processing/useTextToSpeech.md +++ b/docs/docs/03-hooks/01-natural-language-processing/useTextToSpeech.md @@ -36,7 +36,9 @@ You can play the generated waveform in any way most suitable to you; however, in import { models, useTextToSpeech } from 'react-native-executorch'; import { AudioContext } from 'react-native-audio-api'; -const model = useTextToSpeech(models.text_to_speech.kokoro.en_us.heart()); +const model = useTextToSpeech({ + model: models.text_to_speech.kokoro.en_us.heart(), +}); const audioContext = new AudioContext({ sampleRate: 24000 }); @@ -56,15 +58,13 @@ const handleSpeech = async (text: string) => { ### Arguments -`useTextToSpeech` takes [`TextToSpeechModelConfig`](../../06-api-reference/interfaces/TextToSpeechModelConfig.md) that consists of: +`useTextToSpeech` takes [`TextToSpeechProps`](../../06-api-reference/interfaces/TextToSpeechProps.md), an object containing: -- `model` of type [`TextToSpeechModelSources`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md) containing the [`durationPredictorSource`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md#durationpredictorsource), [`synthesizerSource`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md#synthesizersource), and [`modelName`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md#modelname). -- [`voiceSource`](../../06-api-reference/interfaces/TextToSpeechModelConfig.md#voicesource) of type [`ResourceSource`](../../06-api-reference/type-aliases/ResourceSource.md) - configuration of specific voice used in TTS. -- [`phonemizerConfig`](../../06-api-reference/interfaces/TextToSpeechModelConfig.md#phonemizerconfig) of type [`TextToSpeechPhonemizerConfig`](../../06-api-reference/interfaces/TextToSpeechPhonemizerConfig.md) - configuration of the phonemizer. - -`useTextToSpeech`'s second optional argument is an object with: - -- `preventLoad` which prevents auto-loading of the model. +- `model` of type [`TextToSpeechModelConfig`](../../06-api-reference/interfaces/TextToSpeechModelConfig.md), which itself consists of: + - [`model`](../../06-api-reference/interfaces/TextToSpeechModelConfig.md#model) of type [`TextToSpeechModelSources`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md) — bundles the [`durationPredictorSource`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md#durationpredictorsource), [`synthesizerSource`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md#synthesizersource), and [`modelName`](../../06-api-reference/type-aliases/TextToSpeechModelSources.md#modelname). + - [`voiceSource`](../../06-api-reference/interfaces/TextToSpeechModelConfig.md#voicesource) of type [`ResourceSource`](../../06-api-reference/type-aliases/ResourceSource.md) — configuration of the specific voice used in TTS. + - [`phonemizerConfig`](../../06-api-reference/interfaces/TextToSpeechModelConfig.md#phonemizerconfig) of type [`TextToSpeechPhonemizerConfig`](../../06-api-reference/interfaces/TextToSpeechPhonemizerConfig.md) — configuration of the phonemizer. +- An optional flag `preventLoad` which prevents auto-loading of the model. You need more details? Check the following resources: @@ -115,7 +115,9 @@ import { models, useTextToSpeech } from 'react-native-executorch'; import { AudioContext } from 'react-native-audio-api'; export default function App() { - const tts = useTextToSpeech(models.text_to_speech.kokoro.en_us.heart()); + const tts = useTextToSpeech({ + model: models.text_to_speech.kokoro.en_us.heart(), + }); const generateAudio = async () => { const audioData = await tts.forward({ @@ -150,7 +152,9 @@ import { models, useTextToSpeech } from 'react-native-executorch'; import { AudioContext } from 'react-native-audio-api'; export default function App() { - const tts = useTextToSpeech(models.text_to_speech.kokoro.en_us.heart()); + const tts = useTextToSpeech({ + model: models.text_to_speech.kokoro.en_us.heart(), + }); const contextRef = useRef(new AudioContext({ sampleRate: 24000 })); @@ -192,7 +196,9 @@ import React from 'react'; import { Button, View } from 'react-native'; import { models, useTextToSpeech } from 'react-native-executorch'; export default function App() { - const tts = useTextToSpeech(models.text_to_speech.kokoro.en_us.heart()); + const tts = useTextToSpeech({ + model: models.text_to_speech.kokoro.en_us.heart(), + }); const synthesizePhonemes = async () => { // Example phonemes for "Hello" diff --git a/docs/docs/03-hooks/02-computer-vision/useTextToImage.md b/docs/docs/03-hooks/02-computer-vision/useTextToImage.md index 5315979983..da874ffa4a 100644 --- a/docs/docs/03-hooks/02-computer-vision/useTextToImage.md +++ b/docs/docs/03-hooks/02-computer-vision/useTextToImage.md @@ -26,7 +26,7 @@ const model = useTextToImage({ const input = 'a castle'; try { - const image = await model.generate(input); + const image = await model.forward(input); } catch (error) { console.error(error); } @@ -52,7 +52,7 @@ You need more details? Check the following resources: ## Running the model -To run the model, you can use the [`generate`](../../06-api-reference/interfaces/TextToImageType.md#generate) method. It accepts four arguments: a text prompt describing the requested image, a size of the image in pixels, a number of denoising steps, and an optional seed value, which enables reproducibility of the results. +To run the model, you can use the [`forward`](../../06-api-reference/interfaces/TextToImageType.md#forward) method. It accepts four arguments: a text prompt describing the requested image, a size of the image in pixels, a number of denoising steps, and an optional seed value, which enables reproducibility of the results. The image size must be a multiple of 32 due to the architecture of the U-Net and VAE models. The seed should be a positive integer. @@ -76,13 +76,13 @@ function App() { const numSteps = 25; try { - image = await model.generate(input, imageSize, numSteps); + image = await model.forward(input, imageSize, numSteps); } catch (error) { console.error(error); } //... - // `generate` returns a `file://` URI to the PNG saved on disk. + // `forward` returns a `file://` URI to the PNG saved on disk. return ; } ``` diff --git a/docs/docs/03-hooks/03-executorch-bindings/useExecutorchModule.md b/docs/docs/03-hooks/03-executorch-bindings/useExecutorch.md similarity index 78% rename from docs/docs/03-hooks/03-executorch-bindings/useExecutorchModule.md rename to docs/docs/03-hooks/03-executorch-bindings/useExecutorch.md index 7429b4b60e..0019e3e983 100644 --- a/docs/docs/03-hooks/03-executorch-bindings/useExecutorchModule.md +++ b/docs/docs/03-hooks/03-executorch-bindings/useExecutorch.md @@ -1,8 +1,8 @@ --- -title: useExecutorchModule +title: useExecutorch --- -useExecutorchModule provides React Native bindings to the ExecuTorch [Module API](https://pytorch.org/executorch/stable/extension-module.html) directly from JavaScript. +useExecutorch provides React Native bindings to the ExecuTorch [Module API](https://pytorch.org/executorch/stable/extension-module.html) directly from JavaScript. :::info These bindings are primarily intended for custom model integration where no dedicated hook exists. If you are considering using a provided model, first verify whether a dedicated hook is available. Dedicated hooks simplify the implementation process by managing necessary pre and post-processing automatically. Utilizing these can save you effort and reduce complexity, ensuring you do not implement additional handling that is already covered. @@ -10,15 +10,15 @@ These bindings are primarily intended for custom model integration where no dedi ## API Reference -- For detailed API Reference for `useExecutorchModule` see: [`useExecutorchModule` API Reference](../../06-api-reference/functions/useExecutorchModule.md). +- For detailed API Reference for `useExecutorch` see: [`useExecutorch` API Reference](../../06-api-reference/functions/useExecutorch.md). ## Initializing ExecuTorch Module -You can initialize the ExecuTorch module in your JavaScript application using the `useExecutorchModule` hook. This hook facilitates the loading of models from the specified source and prepares them for use. +You can initialize the ExecuTorch module in your JavaScript application using the `useExecutorch` hook. This hook facilitates the loading of models from the specified source and prepares them for use. ```typescript -import { useExecutorchModule } from 'react-native-executorch'; -const executorchModule = useExecutorchModule({ +import { useExecutorch } from 'react-native-executorch'; +const executorchModule = useExecutorch({ modelSource: require('../assets/models/model.pte'), }); ``` @@ -29,19 +29,19 @@ For more information on loading resources, take a look at [loading models](../.. ### Arguments -`useExecutorchModule` takes [`ExecutorchModuleProps`](../../06-api-reference/interfaces/ExecutorchModuleProps.md) that consists of: +`useExecutorch` takes [`ExecutorchModuleProps`](../../06-api-reference/interfaces/ExecutorchModuleProps.md) that consists of: - `model` containing [`modelSource`](../../06-api-reference/interfaces/ExecutorchModuleProps.md#modelsource). - An optional flag [`preventLoad`](../../06-api-reference/interfaces/ExecutorchModuleProps.md#preventload) which prevents auto-loading of the model. You need more details? Check the following resources: -- For detailed information about `useExecutorchModule` arguments check this section: [`useExecutorchModule` arguments](../../06-api-reference/functions/useExecutorchModule.md#parameters). +- For detailed information about `useExecutorch` arguments check this section: [`useExecutorch` arguments](../../06-api-reference/functions/useExecutorch.md#parameters). - For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. ### Returns -`useExecutorchModule` returns an object called `ExecutorchModuleType` containing bunch of functions to interact with arbitrarily chosen models. To get more details please read: [`ExecutorchModuleType` API Reference](../../06-api-reference/interfaces/ExecutorchModuleType.md). +`useExecutorch` returns an object called `ExecutorchModuleType` containing bunch of functions to interact with arbitrarily chosen models. To get more details please read: [`ExecutorchModuleType` API Reference](../../06-api-reference/interfaces/ExecutorchModuleType.md). ## TensorPtr @@ -62,13 +62,9 @@ This example demonstrates the integration and usage of the ExecuTorch bindings w First, import the necessary functions from the `react-native-executorch` package and initialize the ExecuTorch module with the specified style transfer model. ```typescript -import { - models, - useExecutorchModule, - ScalarType, -} from 'react-native-executorch'; +import { models, useExecutorch, ScalarType } from 'react-native-executorch'; // Initialize the executorch module with the predefined style transfer model. -const executorchModule = useExecutorchModule({ +const executorchModule = useExecutorch({ modelSource: models.style_transfer.candy(), }); ``` diff --git a/docs/docs/04-typescript-api/03-executorch-bindings/ExecutorchModule.md b/docs/docs/04-typescript-api/03-executorch-bindings/ExecutorchModule.md index 252f7fa74a..0d5d3e3110 100644 --- a/docs/docs/04-typescript-api/03-executorch-bindings/ExecutorchModule.md +++ b/docs/docs/04-typescript-api/03-executorch-bindings/ExecutorchModule.md @@ -5,7 +5,7 @@ title: ExecutorchModule ExecutorchModule provides TypeScript bindings for the underlying ExecuTorch [Module API](https://pytorch.org/executorch/stable/extension-module.html). :::tip -For React applications, consider using the [`useExecutorchModule`](../../03-hooks/03-executorch-bindings/useExecutorchModule.md) hook instead, which provides automatic state management, loading progress tracking, and cleanup on unmount. +For React applications, consider using the [`useExecutorch`](../../03-hooks/03-executorch-bindings/useExecutorch.md) hook instead, which provides automatic state management, loading progress tracking, and cleanup on unmount. ::: ## API Reference @@ -23,11 +23,10 @@ const inputTensor = { scalarType: ScalarType.FLOAT, }; -// Creating an instance -const model = new ExecutorchModule(); - -// Loading the model -await model.load(models.style_transfer.candy()); +// Creating and loading the model in a single step +const model = await ExecutorchModule.fromModelSource( + models.style_transfer.candy() +); // Running the forward method const output = await model.forward([inputTensor]); @@ -57,13 +56,13 @@ First, import the necessary functions from the `react-native-executorch` package ```typescript import { models, ExecutorchModule, ScalarType } from 'react-native-executorch'; -// Initialize the executorch module -const executorchModule = new ExecutorchModule(); - -// Load the model with optional download progress callback -await executorchModule.load(models.style_transfer.candy(), (progress) => { - console.log(`Download progress: ${progress}%`); -}); +// Initialize and load the executorch module with optional download progress callback. +const executorchModule = await ExecutorchModule.fromModelSource( + models.style_transfer.candy(), + (progress) => { + console.log(`Download progress: ${progress}%`); + } +); ``` ### Setting up input parameters diff --git a/docs/docs/05-utilities/model-registry.md b/docs/docs/05-utilities/model-registry.md index 3611731235..7d265a66ba 100644 --- a/docs/docs/05-utilities/model-registry.md +++ b/docs/docs/05-utilities/model-registry.md @@ -96,7 +96,9 @@ const styled = useStyleTransfer({ ```typescript import { models, useTextToSpeech } from 'react-native-executorch'; -const tts = useTextToSpeech(models.text_to_speech.kokoro.en_us.heart()); +const tts = useTextToSpeech({ + model: models.text_to_speech.kokoro.en_us.heart(), +}); // Other languages: // models.text_to_speech.kokoro.en_gb.emma() // models.text_to_speech.kokoro.fr.siwis() diff --git a/packages/react-native-executorch/__tests__/api/__snapshots__/apiSurface.test.ts.snap b/packages/react-native-executorch/__tests__/api/__snapshots__/apiSurface.test.ts.snap new file mode 100644 index 0000000000..7d16a45536 --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/__snapshots__/apiSurface.test.ts.snap @@ -0,0 +1,306 @@ +// Jest Snapshot v1, https://jestjs.io/docs/snapshot-testing + +exports[`Public API surface export names match snapshot 1`] = ` +[ + "ALL_MINILM_L6_V2", + "ALL_MPNET_BASE_V2", + "BIELIK_V3_0_1_5B", + "BIELIK_V3_0_1_5B_QUANTIZED", + "BK_SDM_TINY_VPRED_256", + "BK_SDM_TINY_VPRED_512", + "BaseResourceFetcherClass", + "CLIP_VIT_BASE_PATCH32_IMAGE", + "CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED", + "CLIP_VIT_BASE_PATCH32_TEXT", + "ClassificationModule", + "CocoKeypoint", + "CocoLabel", + "CocoLabelYolo", + "DEEPLAB_V3_MOBILENET_V3_LARGE", + "DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED", + "DEEPLAB_V3_RESNET101", + "DEEPLAB_V3_RESNET101_QUANTIZED", + "DEEPLAB_V3_RESNET50", + "DEEPLAB_V3_RESNET50_QUANTIZED", + "DEFAULT_CHAT_CONFIG", + "DEFAULT_CONTEXT_BUFFER_TOKENS", + "DEFAULT_MESSAGE_HISTORY", + "DEFAULT_STRUCTURED_OUTPUT_PROMPT", + "DEFAULT_SYSTEM_PROMPT", + "DISTILUSE_BASE_MULTILINGUAL_CASED_V2_8DA4W", + "DISTILUSE_BASE_MULTILINGUAL_CASED_V2_8DA4W_MODEL", + "DISTILUSE_BASE_MULTILINGUAL_CASED_V2_TOKENIZER", + "DeeplabLabel", + "DownloadStatus", + "EFFICIENTNET_V2_S", + "EFFICIENTNET_V2_S_COREML_FP16_MODEL", + "EFFICIENTNET_V2_S_COREML_FP32_MODEL", + "EFFICIENTNET_V2_S_QUANTIZED", + "EFFICIENTNET_V2_S_XNNPACK_FP32_MODEL", + "EFFICIENTNET_V2_S_XNNPACK_INT8_MODEL", + "ExecutorchModule", + "FASTSAM_S", + "FASTSAM_S_COREML_FP16_MODEL", + "FASTSAM_S_XNNPACK_FP32_MODEL", + "FASTSAM_X", + "FASTSAM_X_COREML_FP16_MODEL", + "FASTSAM_X_XNNPACK_FP32_MODEL", + "FCN_RESNET101", + "FCN_RESNET101_QUANTIZED", + "FCN_RESNET50", + "FCN_RESNET50_QUANTIZED", + "FSMN_VAD", + "FastSAMLabel", + "HAMMER2_1_0_5B", + "HAMMER2_1_0_5B_QUANTIZED", + "HAMMER2_1_1_5B", + "HAMMER2_1_1_5B_QUANTIZED", + "HAMMER2_1_3B", + "HAMMER2_1_3B_QUANTIZED", + "HTTP_CODE", + "IMAGENET1K_MEAN", + "IMAGENET1K_STD", + "ImageEmbeddingsModule", + "Imagenet1kLabel", + "InstanceSegmentationModule", + "KOKORO_AMERICAN_ENGLISH_FEMALE_HEART", + "KOKORO_AMERICAN_ENGLISH_FEMALE_RIVER", + "KOKORO_AMERICAN_ENGLISH_FEMALE_SARAH", + "KOKORO_AMERICAN_ENGLISH_MALE_ADAM", + "KOKORO_AMERICAN_ENGLISH_MALE_MICHAEL", + "KOKORO_AMERICAN_ENGLISH_MALE_SANTA", + "KOKORO_BRITISH_ENGLISH_FEMALE_EMMA", + "KOKORO_BRITISH_ENGLISH_MALE_DANIEL", + "KOKORO_FRENCH_FEMALE_SIWIS", + "KOKORO_GERMAN", + "KOKORO_GERMAN_FEMALE_ANNA", + "KOKORO_HINDI_FEMALE_ALPHA", + "KOKORO_HINDI_MALE_OMEGA", + "KOKORO_HINDI_MALE_PSI", + "KOKORO_ITALIAN_FEMALE_SARA", + "KOKORO_ITALIAN_MALE_NICOLA", + "KOKORO_POLISH", + "KOKORO_POLISH_MALE_MATEUSZ", + "KOKORO_PORTUGUESE_FEMALE_DORA", + "KOKORO_PORTUGUESE_MALE_SANTA", + "KOKORO_SPANISH_FEMALE_DORA", + "KOKORO_SPANISH_MALE_ALEX", + "KOKORO_STANDARD", + "LFM2_5_1_2B_INSTRUCT", + "LFM2_5_1_2B_INSTRUCT_QUANTIZED", + "LFM2_5_350M", + "LFM2_5_350M_QUANTIZED", + "LFM2_5_VL_1_6B_QUANTIZED", + "LFM2_5_VL_450M_QUANTIZED", + "LFM2_VL_1_6B_QUANTIZED", + "LFM2_VL_450M_QUANTIZED", + "LLAMA3_2_1B", + "LLAMA3_2_1B_QLORA", + "LLAMA3_2_1B_SPINQUANT", + "LLAMA3_2_3B", + "LLAMA3_2_3B_QLORA", + "LLAMA3_2_3B_SPINQUANT", + "LLMModule", + "LRASPP_MOBILENET_V3_LARGE", + "LRASPP_MOBILENET_V3_LARGE_QUANTIZED", + "Logger", + "MULTI_QA_MINILM_L6_COS_V1", + "MULTI_QA_MPNET_BASE_DOT_V1", + "MessageCountContextStrategy", + "NoopContextStrategy", + "OCRModule", + "OCR_ABAZA", + "OCR_ADYGHE", + "OCR_AFRIKAANS", + "OCR_ALBANIAN", + "OCR_AVAR", + "OCR_AZERBAIJANI", + "OCR_BELARUSIAN", + "OCR_BOSNIAN", + "OCR_BULGARIAN", + "OCR_CHECHEN", + "OCR_CROATIAN", + "OCR_CZECH", + "OCR_DANISH", + "OCR_DARGWA", + "OCR_DUTCH", + "OCR_ENGLISH", + "OCR_ESTONIAN", + "OCR_FRENCH", + "OCR_GERMAN", + "OCR_HUNGARIAN", + "OCR_ICELANDIC", + "OCR_INDONESIAN", + "OCR_INGUSH", + "OCR_IRISH", + "OCR_ITALIAN", + "OCR_JAPANESE", + "OCR_KANNADA", + "OCR_KARBADIAN", + "OCR_KOREAN", + "OCR_KURDISH", + "OCR_LAK", + "OCR_LATIN", + "OCR_LATVIAN", + "OCR_LEZGHIAN", + "OCR_LITHUANIAN", + "OCR_MALAY", + "OCR_MALTESE", + "OCR_MAORI", + "OCR_MONGOLIAN", + "OCR_NORWEGIAN", + "OCR_OCCITAN", + "OCR_PALI", + "OCR_POLISH", + "OCR_PORTUGUESE", + "OCR_ROMANIAN", + "OCR_RUSSIAN", + "OCR_SERBIAN_CYRILLIC", + "OCR_SERBIAN_LATIN", + "OCR_SIMPLIFIED_CHINESE", + "OCR_SLOVAK", + "OCR_SLOVENIAN", + "OCR_SPANISH", + "OCR_SWAHILI", + "OCR_SWEDISH", + "OCR_TABASSARAN", + "OCR_TAGALOG", + "OCR_TAJIK", + "OCR_TELUGU", + "OCR_TURKISH", + "OCR_UKRAINIAN", + "OCR_UZBEK", + "OCR_VIETNAMESE", + "OCR_WELSH", + "ObjectDetectionModule", + "PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_QUANTIZED", + "PHI_4_MINI_4B", + "PHI_4_MINI_4B_QUANTIZED", + "PRIVACY_FILTER_NEMOTRON", + "PRIVACY_FILTER_OPENAI", + "PoseEstimationModule", + "PrivacyFilterModule", + "QWEN2_5_0_5B", + "QWEN2_5_0_5B_QUANTIZED", + "QWEN2_5_1_5B", + "QWEN2_5_1_5B_QUANTIZED", + "QWEN2_5_3B", + "QWEN2_5_3B_QUANTIZED", + "QWEN3_0_6B", + "QWEN3_0_6B_QUANTIZED", + "QWEN3_1_7B", + "QWEN3_1_7B_QUANTIZED", + "QWEN3_4B", + "QWEN3_4B_QUANTIZED", + "QWEN3_5_0_8B_QUANTIZED", + "QWEN3_5_2B_QUANTIZED", + "RF_DETR_NANO", + "RF_DETR_NANO_COREML_INT8_MODEL", + "RF_DETR_NANO_SEG", + "RF_DETR_NANO_SEG_COREML_INT8_MODEL", + "RF_DETR_NANO_SEG_XNNPACK_FP32_MODEL", + "RF_DETR_NANO_XNNPACK_FP32_MODEL", + "ResourceFetcher", + "ResourceFetcherUtils", + "RnExecutorchError", + "RnExecutorchErrorCode", + "SELFIE_SEGMENTATION", + "SMOLLM2_1_135M", + "SMOLLM2_1_135M_QUANTIZED", + "SMOLLM2_1_1_7B", + "SMOLLM2_1_1_7B_QUANTIZED", + "SMOLLM2_1_360M", + "SMOLLM2_1_360M_QUANTIZED", + "SPECIAL_TOKENS", + "SSDLITE_320_MOBILENET_V3_LARGE", + "SSDLITE_320_MOBILENET_V3_LARGE_COREML_FP16_MODEL", + "SSDLITE_320_MOBILENET_V3_LARGE_XNNPACK_FP32_MODEL", + "STYLE_TRANSFER_CANDY", + "STYLE_TRANSFER_CANDY_QUANTIZED", + "STYLE_TRANSFER_MOSAIC", + "STYLE_TRANSFER_MOSAIC_QUANTIZED", + "STYLE_TRANSFER_RAIN_PRINCESS", + "STYLE_TRANSFER_RAIN_PRINCESS_QUANTIZED", + "STYLE_TRANSFER_UDNIE", + "STYLE_TRANSFER_UDNIE_QUANTIZED", + "ScalarType", + "SelfieSegmentationLabel", + "SemanticSegmentationModule", + "SlidingWindowContextStrategy", + "SourceType", + "SpeechToTextModule", + "StyleTransferModule", + "TextEmbeddingsModule", + "TextToImageModule", + "TextToSpeechModule", + "TokenizerModule", + "VADModule", + "VerticalOCRModule", + "WHISPER_BASE", + "WHISPER_BASE_EN", + "WHISPER_BASE_EN_MODEL_COREML", + "WHISPER_BASE_EN_MODEL_XNNPACK", + "WHISPER_BASE_EN_TOKENIZER", + "WHISPER_BASE_MODEL_COREML", + "WHISPER_BASE_MODEL_XNNPACK", + "WHISPER_BASE_TOKENIZER", + "WHISPER_SMALL", + "WHISPER_SMALL_EN", + "WHISPER_SMALL_EN_MODEL_COREML", + "WHISPER_SMALL_EN_MODEL_XNNPACK", + "WHISPER_SMALL_EN_TOKENIZER", + "WHISPER_SMALL_MODEL_COREML", + "WHISPER_SMALL_MODEL_XNNPACK", + "WHISPER_SMALL_TOKENIZER", + "WHISPER_TINY", + "WHISPER_TINY_EN", + "WHISPER_TINY_EN_MODEL_COREML", + "WHISPER_TINY_EN_MODEL_XNNPACK", + "WHISPER_TINY_EN_TOKENIZER", + "WHISPER_TINY_MODEL_COREML", + "WHISPER_TINY_MODEL_XNNPACK", + "WHISPER_TINY_TOKENIZER", + "YOLO26L", + "YOLO26L_SEG", + "YOLO26M", + "YOLO26M_SEG", + "YOLO26N", + "YOLO26N_POSE", + "YOLO26N_SEG", + "YOLO26S", + "YOLO26S_SEG", + "YOLO26X", + "YOLO26X_SEG", + "cleanupExecutorch", + "fixAndValidateStructuredOutput", + "getModelNameForUrl", + "getStructuredOutputPrompt", + "initExecutorch", + "isAvailable", + "models", + "parseToolCall", + "selectByBox", + "selectByPoint", + "selectByText", + "styleTransferUrls", + "useClassification", + "useExecutorch", + "useExecutorchModule", + "useImageEmbeddings", + "useInstanceSegmentation", + "useLLM", + "useOCR", + "useObjectDetection", + "usePoseEstimation", + "usePrivacyFilter", + "useSemanticSegmentation", + "useSpeechToText", + "useStyleTransfer", + "useTextEmbeddings", + "useTextToImage", + "useTextToSpeech", + "useTokenizer", + "useVAD", + "useVerticalOCR", +] +`; diff --git a/packages/react-native-executorch/__tests__/api/apiSurface.test.ts b/packages/react-native-executorch/__tests__/api/apiSurface.test.ts new file mode 100644 index 0000000000..13ea6c135a --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/apiSurface.test.ts @@ -0,0 +1,21 @@ +import * as RNE from '../../src'; + +// Snapshots the sorted list of public export names from +// `src/index.ts`. Any addition or removal flips the snapshot so the change is +// surfaced in the diff — a deliberate API tweak just needs `--updateSnapshot`, +// an accidental break does not slip through. +describe('Public API surface', () => { + it('export names match snapshot', () => { + const exportNames = Object.keys(RNE).sort(); + expect(exportNames).toMatchSnapshot(); + }); + + it('every export is non-undefined', () => { + for (const [name, value] of Object.entries(RNE)) { + expect({ name, defined: value !== undefined }).toEqual({ + name, + defined: true, + }); + } + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/errorCodes.test.ts b/packages/react-native-executorch/__tests__/api/errorCodes.test.ts new file mode 100644 index 0000000000..e9ffc35bef --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/errorCodes.test.ts @@ -0,0 +1,45 @@ +import { RnExecutorchErrorCode } from '../../src/errors/ErrorCodes'; +import { RnExecutorchError } from '../../src/errors/errorUtils'; + +// TypeScript enums emit a numeric reverse-mapping: `Enum[42] === 'KeyName'`. +// We use that to walk the enum at runtime as `[name, code]` pairs. +function enumEntries(): Array<[string, number]> { + return Object.entries(RnExecutorchErrorCode) + .filter(([, v]) => typeof v === 'number') + .map(([k, v]) => [k, v as number]); +} + +describe('RnExecutorchErrorCode', () => { + const entries = enumEntries(); + + it('contains entries', () => { + expect(entries.length).toBeGreaterThan(0); + }); + + it('every numeric code is unique', () => { + const codes = entries.map(([, v]) => v); + const dupes = codes.filter((c, i) => codes.indexOf(c) !== i); + expect(dupes).toEqual([]); + }); + + it.each(entries)('%s = %s is a non-negative integer', (_name, code) => { + expect(Number.isInteger(code)).toBe(true); + expect(code).toBeGreaterThanOrEqual(0); + }); + + it.each(entries)('%s = %s has a working reverse lookup', (name, code) => { + expect( + (RnExecutorchErrorCode as unknown as Record)[code] + ).toBe(name); + }); + + it.each(entries)( + 'new RnExecutorchError(%s = %s) produces a non-empty message', + (_name, code) => { + const err = new RnExecutorchError(code); + expect(typeof err.message).toBe('string'); + expect(err.message.length).toBeGreaterThan(0); + expect(err.code).toBe(code); + } + ); +}); diff --git a/packages/react-native-executorch/__tests__/api/hookContracts.test.ts b/packages/react-native-executorch/__tests__/api/hookContracts.test.ts new file mode 100644 index 0000000000..d9ad127f2f --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/hookContracts.test.ts @@ -0,0 +1,80 @@ +import type { + RnExecutorchError, + useClassification, + useExecutorch, + useExecutorchModule, + useImageEmbeddings, + useInstanceSegmentation, + useLLM, + useObjectDetection, + useOCR, + usePoseEstimation, + usePrivacyFilter, + useSemanticSegmentation, + useSpeechToText, + useStyleTransfer, + useTextEmbeddings, + useTextToImage, + useTextToSpeech, + useTokenizer, + useVAD, + useVerticalOCR, +} from '../../src'; + +// Every public `useXxx` hook is expected to expose at least this state shape. +// The contract is enforced at compile time via `satisfies` below — any hook +// whose return type drifts from this contract will fail `tsc -p +// tsconfig.test.json`, naming the offending hook in the error. +type HookBaseState = { + error: RnExecutorchError | null; + isReady: boolean; + isGenerating: boolean; + downloadProgress: number; +}; + +type HookReturn = T extends (...args: never[]) => infer R ? R : never; + +// Allocate a stub value of each hook's return type and assert the whole map +// satisfies `Record`. If a hook's return type does not +// include the base state, `tsc` errors at the `satisfies` clause and reports +// the failing entry. +const HOOK_RETURN_TYPES = { + // computer vision + useClassification: null as unknown as HookReturn, + useImageEmbeddings: null as unknown as HookReturn, + useInstanceSegmentation: null as unknown as HookReturn< + typeof useInstanceSegmentation + >, + useObjectDetection: null as unknown as HookReturn, + useOCR: null as unknown as HookReturn, + usePoseEstimation: null as unknown as HookReturn, + useSemanticSegmentation: null as unknown as HookReturn< + typeof useSemanticSegmentation + >, + useStyleTransfer: null as unknown as HookReturn, + useTextToImage: null as unknown as HookReturn, + useVerticalOCR: null as unknown as HookReturn, + // general + useExecutorch: null as unknown as HookReturn, + useExecutorchModule: null as unknown as HookReturn< + typeof useExecutorchModule + >, + // natural language processing + useLLM: null as unknown as HookReturn, + usePrivacyFilter: null as unknown as HookReturn, + useSpeechToText: null as unknown as HookReturn, + useTextEmbeddings: null as unknown as HookReturn, + useTextToSpeech: null as unknown as HookReturn, + useTokenizer: null as unknown as HookReturn, + useVAD: null as unknown as HookReturn, +} satisfies Record; + +describe('Hook return contracts', () => { + it('every public hook return type satisfies HookBaseState (compile-time)', () => { + // The real assertion is the `satisfies` clause above, checked by tsc. + // This runtime test exists so the file appears in the Jest report and + // so the symbol is referenced (preventing dead-code elimination + // surprises and surfacing import-time regressions). + expect(Object.keys(HOOK_RETURN_TYPES).length).toBeGreaterThan(0); + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/hookPropsContract.test.ts b/packages/react-native-executorch/__tests__/api/hookPropsContract.test.ts new file mode 100644 index 0000000000..6370ed6a5a --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/hookPropsContract.test.ts @@ -0,0 +1,131 @@ +import type { + ClassificationProps, + ExecutorchModuleProps, + ImageEmbeddingsProps, + InstanceSegmentationProps, + LLMProps, + ObjectDetectionProps, + OCRProps, + PoseEstimationProps, + PrivacyFilterProps, + SemanticSegmentationProps, + SpeechToTextProps, + StyleTransferProps, + TextEmbeddingsProps, + TextToImageProps, + TextToSpeechProps, + TokenizerProps, + VADProps, + VerticalOCRProps, +} from '../../src'; +import type { + useClassification, + useExecutorch, + useExecutorchModule, + useImageEmbeddings, + useInstanceSegmentation, + useLLM, + useObjectDetection, + useOCR, + usePoseEstimation, + usePrivacyFilter, + useSemanticSegmentation, + useSpeechToText, + useStyleTransfer, + useTextEmbeddings, + useTextToImage, + useTextToSpeech, + useTokenizer, + useVAD, + useVerticalOCR, +} from '../../src'; + +// ───────────────────────────────────────────────────────────────────────────── +// preventLoad presence on every *Props type. tsc errors on the `satisfies` +// clause if a Props type drops the field. +// ───────────────────────────────────────────────────────────────────────────── + +type HasPreventLoad = { preventLoad?: boolean }; + +const PROPS_TYPES_WITH_PREVENT_LOAD = { + ClassificationProps: null as unknown as ClassificationProps, + ExecutorchModuleProps: null as unknown as ExecutorchModuleProps, + ImageEmbeddingsProps: null as unknown as ImageEmbeddingsProps, + InstanceSegmentationProps: + null as unknown as InstanceSegmentationProps, + LLMProps: null as unknown as LLMProps, + ObjectDetectionProps: null as unknown as ObjectDetectionProps, + OCRProps: null as unknown as OCRProps, + PoseEstimationProps: null as unknown as PoseEstimationProps, + PrivacyFilterProps: null as unknown as PrivacyFilterProps, + SemanticSegmentationProps: + null as unknown as SemanticSegmentationProps, + SpeechToTextProps: null as unknown as SpeechToTextProps, + StyleTransferProps: null as unknown as StyleTransferProps, + TextEmbeddingsProps: null as unknown as TextEmbeddingsProps, + TextToImageProps: null as unknown as TextToImageProps, + TextToSpeechProps: null as unknown as TextToSpeechProps, + TokenizerProps: null as unknown as TokenizerProps, + VADProps: null as unknown as VADProps, + VerticalOCRProps: null as unknown as VerticalOCRProps, +} satisfies Record; + +// ───────────────────────────────────────────────────────────────────────────── +// Hook call shape consistency. Every public `useXxx` takes a single object +// argument, so the second positional parameter must resolve to `undefined` at +// the type level. +// ───────────────────────────────────────────────────────────────────────────── + +type SecondParam = F extends (...args: infer A) => unknown ? A[1] : never; + +// `unknown` for the OK case, an error-bearing object literal otherwise. Used +// as the rhs of `as` so any non-OK type yields a tsc error. +type AssertSingleArg = + SecondParam extends undefined + ? unknown + : { + ERROR: 'hook should take a single object argument'; + actualSecondParam: SecondParam; + }; + +const _HOOKS_TAKE_SINGLE_ARG = { + useClassification: undefined as AssertSingleArg, + useExecutorch: undefined as AssertSingleArg, + useExecutorchModule: undefined as AssertSingleArg, + useImageEmbeddings: undefined as AssertSingleArg, + useInstanceSegmentation: undefined as AssertSingleArg< + typeof useInstanceSegmentation + >, + useLLM: undefined as AssertSingleArg, + useObjectDetection: undefined as AssertSingleArg, + useOCR: undefined as AssertSingleArg, + usePoseEstimation: undefined as AssertSingleArg, + usePrivacyFilter: undefined as AssertSingleArg, + useSemanticSegmentation: undefined as AssertSingleArg< + typeof useSemanticSegmentation + >, + useSpeechToText: undefined as AssertSingleArg, + useStyleTransfer: undefined as AssertSingleArg, + useTextEmbeddings: undefined as AssertSingleArg, + useTextToImage: undefined as AssertSingleArg, + useTextToSpeech: undefined as AssertSingleArg, + useTokenizer: undefined as AssertSingleArg, + useVAD: undefined as AssertSingleArg, + useVerticalOCR: undefined as AssertSingleArg, +}; + +// Suppress noUnusedLocals — the type assertion *is* the test. +// eslint-disable-next-line no-void +void _HOOKS_TAKE_SINGLE_ARG; + +describe('Hook props + signature contracts', () => { + it('every *Props type carries preventLoad (compile-time)', () => { + expect(Object.keys(PROPS_TYPES_WITH_PREVENT_LOAD).length).toBeGreaterThan( + 0 + ); + }); + + it('every public hook takes a single object argument (compile-time)', () => { + expect(Object.keys(_HOOKS_TAKE_SINGLE_ARG).length).toBeGreaterThan(0); + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/modelRegistry.test.ts b/packages/react-native-executorch/__tests__/api/modelRegistry.test.ts new file mode 100644 index 0000000000..ae94d027c0 --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/modelRegistry.test.ts @@ -0,0 +1,86 @@ +import { models } from '../../src/constants/modelRegistry'; + +type Accessor = (...args: unknown[]) => unknown; + +function isAccessor(v: unknown): v is Accessor { + return typeof v === 'function'; +} + +function walk( + node: unknown, + path: string[], + visit: (path: string[], leaf: Accessor) => void +) { + if (isAccessor(node)) { + visit(path, node); + return; + } + if (node && typeof node === 'object') { + for (const [k, v] of Object.entries(node)) { + walk(v, [...path, k], visit); + } + } +} + +type Entry = { name: string; path: string[]; value: unknown }; + +// Accessors that take required arguments and so can't be invoked with no +// args. Inconsistent with the rest of the registry, but kept as-is for now. +// Listed here so the walker can skip them. +const PARAMETERIZED_ACCESSORS = new Set(['ocr.craft']); + +function collect(): Entry[] { + const out: Entry[] = []; + walk(models, [], (path, accessor) => { + const name = path.join('.'); + if (PARAMETERIZED_ACCESSORS.has(name)) return; + out.push({ name, path, value: accessor() }); + }); + return out; +} + +describe('Model registry', () => { + const entries = collect(); + + it('contains accessors', () => { + expect(entries.length).toBeGreaterThan(0); + }); + + it.each(entries.map((e) => [e.name, e.value] as const))( + '%s returns a non-null object', + (_name, value) => { + expect(value).not.toBeNull(); + expect(typeof value).toBe('object'); + } + ); + + // text_to_speech accessors return TextToSpeechModelConfig (no modelName); + // every other branch returns { modelName, modelSource, ... }. + const standard = entries.filter((e) => e.path[0] !== 'text_to_speech'); + + it.each(standard.map((e) => [e.name, e.value] as const))( + '%s exposes a non-empty modelName', + (_name, value) => { + const v = value as { modelName?: unknown }; + expect(typeof v.modelName).toBe('string'); + expect(v.modelName).not.toBe(''); + } + ); + + it('non-TTS modelNames are unique within each category', () => { + const byCategory = new Map(); + for (const { path, value } of standard) { + const cat = path[0]!; + const modelName = (value as { modelName: string }).modelName; + const bucket = byCategory.get(cat) ?? []; + bucket.push(modelName); + byCategory.set(cat, bucket); + } + const collisions: Array<{ category: string; duplicates: string[] }> = []; + for (const [category, names] of byCategory) { + const duplicates = names.filter((n, i) => names.indexOf(n) !== i); + if (duplicates.length > 0) collisions.push({ category, duplicates }); + } + expect(collisions).toEqual([]); + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/modelUrls.test.ts b/packages/react-native-executorch/__tests__/api/modelUrls.test.ts new file mode 100644 index 0000000000..af7e7a1822 --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/modelUrls.test.ts @@ -0,0 +1,81 @@ +import { models } from '../../src/constants/modelRegistry'; +import { URL_PREFIX } from '../../src/constants/versions'; + +type Accessor = (...args: unknown[]) => unknown; + +function isAccessor(v: unknown): v is Accessor { + return typeof v === 'function'; +} + +function walk( + node: unknown, + path: string[], + visit: (path: string[], leaf: Accessor) => void +) { + if (isAccessor(node)) { + visit(path, node); + return; + } + if (node && typeof node === 'object') { + for (const [k, v] of Object.entries(node)) { + walk(v, [...path, k], visit); + } + } +} + +const PARAMETERIZED_ACCESSORS = new Set(['ocr.craft']); + +// Collect every (path, string-valued field) pair from the resolved config of +// every accessor. URL fields are detected by value (starts with "http"), so +// new URL-bearing fields are picked up automatically without per-field opt-in. +type UrlEntry = { path: string; field: string; url: string }; + +function collectUrls(): UrlEntry[] { + const urls: UrlEntry[] = []; + walk(models, [], (path, accessor) => { + const name = path.join('.'); + if (PARAMETERIZED_ACCESSORS.has(name)) return; + const config = accessor(); + collectFromValue(name, config, urls); + }); + return urls; +} + +function collectFromValue(path: string, value: unknown, out: UrlEntry[]) { + if (typeof value === 'string' && /^https?:\/\//.test(value)) { + out.push({ path, field: '', url: value }); + return; + } + if (value && typeof value === 'object') { + for (const [k, v] of Object.entries(value)) { + if (typeof v === 'string' && /^https?:\/\//.test(v)) { + out.push({ path, field: k, url: v }); + } else if (v && typeof v === 'object') { + collectFromValue(`${path}.${k}`, v, out); + } + } + } +} + +describe('Model registry URLs', () => { + const urls = collectUrls(); + + it('contains URL-bearing fields', () => { + expect(urls.length).toBeGreaterThan(0); + }); + + it.each(urls.map((e) => [`${e.path} (${e.field})`, e.url] as const))( + '%s is a non-empty https URL', + (_label, url) => { + expect(url).toMatch(/^https:\/\/\S+$/); + expect(url).not.toBe(''); + } + ); + + it.each(urls.map((e) => [`${e.path} (${e.field})`, e.url] as const))( + '%s points at the software-mansion HuggingFace org', + (_label, url) => { + expect(url.startsWith(URL_PREFIX)).toBe(true); + } + ); +}); diff --git a/packages/react-native-executorch/__tests__/api/moduleConstruction.test.ts b/packages/react-native-executorch/__tests__/api/moduleConstruction.test.ts new file mode 100644 index 0000000000..6e9d5db22e --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/moduleConstruction.test.ts @@ -0,0 +1,224 @@ +import { models } from '../../src/constants/modelRegistry'; +import { + ClassificationModule, + ExecutorchModule, + ImageEmbeddingsModule, + InstanceSegmentationModule, + LLMModule, + OCRModule, + ObjectDetectionModule, + PoseEstimationModule, + PrivacyFilterModule, + ResourceFetcher, + SemanticSegmentationModule, + SpeechToTextModule, + StyleTransferModule, + TextEmbeddingsModule, + TextToImageModule, + TextToSpeechModule, + TokenizerModule, + VADModule, + VerticalOCRModule, +} from '../../src'; + +// Stub adapter: every fetch resolves to a fixed fake path, regardless of how +// many sources are passed. Enough for factories that just thread the path +// into `global.loadXxx` (which is itself stubbed to resolve to `{}`). +function mockAdapter() { + return { + fetch: async ( + _onProgress: (p: number) => void, + ...sources: unknown[] + ): Promise<{ paths: string[]; wasDownloaded: boolean[] }> => ({ + paths: sources.map((_, i) => `/tmp/mock-source-${i}.pte`), + wasDownloaded: sources.map(() => true), + }), + readAsString: async () => '{}', + }; +} + +beforeAll(() => { + ResourceFetcher.setAdapter(mockAdapter()); +}); + +afterAll(() => { + ResourceFetcher.resetAdapter(); +}); + +// Each entry constructs a module via its primary factory using a sample +// config from the registry. The asserted contract is the same for all of +// them: the awaited result is a real instance of the module class and +// `delete()` is callable on it. +// Use `Function` for `ModuleClass` so classes with private constructors +// (Classification, ObjectDetection, …) are accepted. `instanceof` only needs +// a function with a `prototype`. +type Construction = { + name: string; + build: () => Promise<{ delete: () => void }>; + + ModuleClass: Function; +}; + +const constructions: Construction[] = [ + { + name: 'ClassificationModule.fromModelName', + ModuleClass: ClassificationModule, + build: () => + ClassificationModule.fromModelName( + models.classification.efficientnet_v2_s() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'ObjectDetectionModule.fromModelName', + ModuleClass: ObjectDetectionModule, + build: () => + ObjectDetectionModule.fromModelName( + models.object_detection.rf_detr_nano() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'PoseEstimationModule.fromModelName', + ModuleClass: PoseEstimationModule, + build: () => + PoseEstimationModule.fromModelName( + models.pose_estimation.yolo26n() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'SemanticSegmentationModule.fromModelName', + ModuleClass: SemanticSegmentationModule, + build: () => + SemanticSegmentationModule.fromModelName( + models.semantic_segmentation.deeplab_v3_resnet50() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'InstanceSegmentationModule.fromModelName', + ModuleClass: InstanceSegmentationModule, + build: () => + InstanceSegmentationModule.fromModelName( + models.instance_segmentation.yolo26n() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'StyleTransferModule.fromModelName', + ModuleClass: StyleTransferModule, + build: () => + StyleTransferModule.fromModelName( + models.style_transfer.candy() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'ImageEmbeddingsModule.fromModelName', + ModuleClass: ImageEmbeddingsModule, + build: () => + ImageEmbeddingsModule.fromModelName( + models.image_embedding.clip_vit_base_patch32_image() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'TextToImageModule.fromModelName', + ModuleClass: TextToImageModule, + build: () => + TextToImageModule.fromModelName( + models.image_generation.bk_sdm_tiny_vpred_512() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'LLMModule.fromModelName', + ModuleClass: LLMModule, + build: () => + LLMModule.fromModelName(models.llm.qwen3_4b()) as Promise<{ + delete: () => void; + }>, + }, + { + name: 'TextEmbeddingsModule.fromModelName', + ModuleClass: TextEmbeddingsModule, + build: () => + TextEmbeddingsModule.fromModelName( + models.text_embedding.all_minilm_l6_v2() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'PrivacyFilterModule.fromModelName', + ModuleClass: PrivacyFilterModule, + build: () => + PrivacyFilterModule.fromModelName( + models.privacy_filter.openai() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'VADModule.fromModelName', + ModuleClass: VADModule, + build: () => + VADModule.fromModelName(models.vad.fsmn_vad()) as Promise<{ + delete: () => void; + }>, + }, + { + name: 'OCRModule.fromModelName', + ModuleClass: OCRModule, + build: () => + OCRModule.fromModelName(models.ocr.craft({ language: 'en' })) as Promise<{ + delete: () => void; + }>, + }, + { + name: 'VerticalOCRModule.fromModelName', + ModuleClass: VerticalOCRModule, + build: () => + VerticalOCRModule.fromModelName( + models.ocr.craft({ language: 'en' }) + ) as Promise<{ delete: () => void }>, + }, + { + name: 'SpeechToTextModule.fromModelName', + ModuleClass: SpeechToTextModule, + build: () => + SpeechToTextModule.fromModelName( + models.speech_to_text.whisper_tiny_en() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'TextToSpeechModule.fromModelName', + ModuleClass: TextToSpeechModule, + build: () => + TextToSpeechModule.fromModelName( + models.text_to_speech.kokoro.en_us.heart() + ) as Promise<{ delete: () => void }>, + }, + { + name: 'TokenizerModule.fromModelName', + ModuleClass: TokenizerModule, + build: () => + TokenizerModule.fromModelName({ + tokenizerSource: + models.text_embedding.all_minilm_l6_v2().tokenizerSource, + }) as Promise<{ delete: () => void }>, + }, + { + name: 'ExecutorchModule.fromModelSource', + ModuleClass: ExecutorchModule, + build: () => + ExecutorchModule.fromModelSource( + models.text_embedding.all_minilm_l6_v2().modelSource + ) as Promise<{ delete: () => void }>, + }, +]; + +describe('Module construction (mocked native)', () => { + it.each(constructions)( + '$name yields an instance with a callable delete()', + async ({ build, ModuleClass }) => { + const instance = await build(); + expect(instance).toBeInstanceOf(ModuleClass); + expect(typeof instance.delete).toBe('function'); + // Calling delete on the stubbed instance shouldn't throw — the stub + // nativeModule is `{}` and BaseModule.delete is guarded against null + // nativeModule but not against missing `unload`. Modules that rely on + // `nativeModule.unload()` will throw here, which is itself signal. + expect(() => instance.delete()).not.toThrow(); + } + ); +}); diff --git a/packages/react-native-executorch/__tests__/api/moduleContracts.test.ts b/packages/react-native-executorch/__tests__/api/moduleContracts.test.ts new file mode 100644 index 0000000000..657df0ffda --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/moduleContracts.test.ts @@ -0,0 +1,63 @@ +import * as RNE from '../../src'; +import { BaseModule } from '../../src/modules/BaseModule'; + +// Module classes that exist purely as shared bases and have no corresponding +// public hook. Anything not in this set is treated as part of the public API +// surface. +const ABSTRACT_MODULES = new Set([ + 'BaseModule', + 'VisionModule', + 'VisionLabeledModule', +]); + +type ModuleCtor = new (...args: never[]) => unknown; + +function isClassConstructor(value: unknown): value is ModuleCtor { + return ( + typeof value === 'function' && + typeof (value as { prototype?: unknown }).prototype === 'object' && + (value as { prototype: { constructor?: unknown } }).prototype + .constructor === value + ); +} + +function getModuleClasses(): Array<[string, ModuleCtor]> { + return Object.entries(RNE).filter( + ([name, value]) => + name.endsWith('Module') && + !name.startsWith('use') && + !ABSTRACT_MODULES.has(name) && + isClassConstructor(value) + ) as Array<[string, ModuleCtor]>; +} + +describe('Module contracts', () => { + const modules = getModuleClasses(); + + it('exports at least one concrete Module class', () => { + expect(modules.length).toBeGreaterThan(0); + }); + + describe.each(modules)('%s', (_name, ModuleClass) => { + it('extends BaseModule', () => { + expect(ModuleClass.prototype instanceof BaseModule).toBe(true); + }); + + it('declares at least one static factory method (from*)', () => { + const factories = Object.getOwnPropertyNames(ModuleClass).filter( + (n) => + n.startsWith('from') && + typeof (ModuleClass as unknown as Record)[n] === + 'function' + ); + expect(factories.length).toBeGreaterThan(0); + }); + }); + + it.each(modules)('%s has a corresponding hook export', (name) => { + const expected = 'use' + name.replace(/Module$/, ''); + const hook = (RNE as unknown as Record)[expected]; + expect(hook).toBeDefined(); + expect(typeof hook).toBe('function'); + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/moduleHookSignatureAlignment.test.ts b/packages/react-native-executorch/__tests__/api/moduleHookSignatureAlignment.test.ts new file mode 100644 index 0000000000..686ed0515d --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/moduleHookSignatureAlignment.test.ts @@ -0,0 +1,166 @@ +import type { + ExecutorchModule, + ExecutorchModuleType, + ImageEmbeddingsModule, + ImageEmbeddingsType, + LLMModule, + LLMType, + StyleTransferModule, + StyleTransferType, + TextEmbeddingsModule, + TextEmbeddingsType, + TextToImageModule, + TextToImageType, + TokenizerModule, + TokenizerType, + VADModule, + VADType, +} from '../../src'; + +// Compile-time alignment between every non-generic module's primary +// inference method(s) and the matching hook return type's method(s). +// +// The hook wrappers around each module are thin (`(...args) => +// runForward((inst) => inst.method(...args))`), so a Hook → Module signature +// mismatch means the hook silently advertises a narrower or wider surface +// than the module actually supports. The assertions below run via tsc and +// flag any drift naming the (module, method) pair. +// +// Modules with class-level generics (Classification, ObjectDetection, +// PoseEstimation, Semantic/InstanceSegmentation, VerticalOCR) are left out +// of this file because their hook return shape and module prototype shape +// depend on per-call type parameters that don't survive `Parameters<>` / +// `ReturnType<>` extraction. Their alignment is exercised at runtime in +// moduleConstruction.test.ts. + +type EqualParam = Parameters< + F extends (...a: never[]) => unknown ? F : never +>[0] extends Parameters unknown ? G : never>[0] + ? Parameters< + G extends (...a: never[]) => unknown ? G : never + >[0] extends Parameters unknown ? F : never>[0] + ? true + : { + ERROR: 'module accepts inputs the hook does not advertise'; + moduleParam: Parameters< + F extends (...a: never[]) => unknown ? F : never + >[0]; + hookParam: Parameters< + G extends (...a: never[]) => unknown ? G : never + >[0]; + } + : { + ERROR: 'hook accepts inputs the module does not'; + moduleParam: Parameters< + F extends (...a: never[]) => unknown ? F : never + >[0]; + hookParam: Parameters< + G extends (...a: never[]) => unknown ? G : never + >[0]; + }; + +type EqualReturn = + Awaited< + ReturnType unknown ? F : never> + > extends Awaited< + ReturnType unknown ? G : never> + > + ? Awaited< + ReturnType unknown ? G : never> + > extends Awaited< + ReturnType unknown ? F : never> + > + ? true + : { + ERROR: 'module returns more than the hook advertises'; + } + : { ERROR: 'hook returns more than the module produces' }; + +// For each (module, method, hook field) row, both an input-shape and a +// return-shape equality is asserted. Any breakage shows up as the +// satisfies-clause failing with one of the labelled error types above. +const _ALIGNMENT = { + // ExecutorchModule has no `forward` wrapper on its hook return — the hook + // returns the instance's `forward` (Tensor I/O) directly. + executorchModule_forward: { + inputs: true as EqualParam< + ExecutorchModule['forward'], + ExecutorchModuleType['forward'] + >, + returns: true as EqualReturn< + ExecutorchModule['forward'], + ExecutorchModuleType['forward'] + >, + }, + imageEmbeddings_forward: { + inputs: true as EqualParam< + ImageEmbeddingsModule['forward'], + ImageEmbeddingsType['forward'] + >, + returns: true as EqualReturn< + ImageEmbeddingsModule['forward'], + ImageEmbeddingsType['forward'] + >, + }, + // LLM's primary method is `generate` (not `forward`) because it is a + // streaming autoregressive text-generation API — matching the HuggingFace + // transformers / llama.cpp / OpenAI convention — rather than a single-pass + // tensor I/O call. Both the module method and the hook return field are + // named `generate` consistently, so the alignment check still holds. + llm_generate: { + inputs: true as EqualParam, + returns: true as EqualReturn, + }, + styleTransfer_forward: { + inputs: true as EqualParam< + StyleTransferModule['forward'], + StyleTransferType['forward'] + >, + returns: true as EqualReturn< + StyleTransferModule['forward'], + StyleTransferType['forward'] + >, + }, + textEmbeddings_forward: { + inputs: true as EqualParam< + TextEmbeddingsModule['forward'], + TextEmbeddingsType['forward'] + >, + returns: true as EqualReturn< + TextEmbeddingsModule['forward'], + TextEmbeddingsType['forward'] + >, + }, + textToImage_forward: { + inputs: true as EqualParam< + TextToImageModule['forward'], + TextToImageType['forward'] + >, + returns: true as EqualReturn< + TextToImageModule['forward'], + TextToImageType['forward'] + >, + }, + tokenizer_encode: { + inputs: true as EqualParam< + TokenizerModule['encode'], + TokenizerType['encode'] + >, + returns: true as EqualReturn< + TokenizerModule['encode'], + TokenizerType['encode'] + >, + }, + vad_forward: { + inputs: true as EqualParam, + returns: true as EqualReturn, + }, +}; +// eslint-disable-next-line no-void +void _ALIGNMENT; + +describe('Module ↔ hook signature alignment', () => { + it('every checked module method aligns with its hook return field (compile-time)', () => { + expect(typeof _ALIGNMENT).toBe('object'); + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/modulePrototype.test.ts b/packages/react-native-executorch/__tests__/api/modulePrototype.test.ts new file mode 100644 index 0000000000..bb0e217ca9 --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/modulePrototype.test.ts @@ -0,0 +1,82 @@ +import * as RNE from '../../src'; +import { BaseModule } from '../../src/modules/BaseModule'; + +// Mirror the abstract-module set from moduleContracts.test.ts. +const ABSTRACT_MODULES = new Set([ + 'BaseModule', + 'VisionModule', + 'VisionLabeledModule', +]); + +type ModuleClass = new (...args: never[]) => unknown; + +function isClassConstructor(value: unknown): value is ModuleClass { + return ( + typeof value === 'function' && + typeof (value as { prototype?: unknown }).prototype === 'object' && + (value as { prototype: { constructor?: unknown } }).prototype + .constructor === value + ); +} + +function getModuleClasses(): Array<[string, ModuleClass]> { + return Object.entries(RNE).filter( + ([name, value]) => + name.endsWith('Module') && + !name.startsWith('use') && + !ABSTRACT_MODULES.has(name) && + isClassConstructor(value) + ) as Array<[string, ModuleClass]>; +} + +// Walk the prototype chain (excluding Object.prototype) and collect every +// non-constructor, non-private callable surface name. Uses property +// descriptors rather than direct access so accessor properties (getters such +// as VisionModule.runOnFrame) are counted without being invoked — invoking +// them on the prototype with no native module loaded would throw. +function reachablePublicMethods(ModuleClass: ModuleClass): Set { + const out = new Set(); + let proto: object | null = ModuleClass.prototype; + while (proto && proto !== Object.prototype) { + for (const name of Object.getOwnPropertyNames(proto)) { + if (name === 'constructor') continue; + if (name.startsWith('_')) continue; + const desc = Object.getOwnPropertyDescriptor(proto, name); + if (!desc) continue; + if (typeof desc.value === 'function' || typeof desc.get === 'function') { + out.add(name); + } + } + proto = Object.getPrototypeOf(proto); + } + return out; +} + +describe('Module prototype surface', () => { + const modules = getModuleClasses(); + + it.each(modules)( + '%s exposes at least one public instance method on the prototype chain', + (_name, ModuleClass) => { + const methods = reachablePublicMethods(ModuleClass); + expect(methods.size).toBeGreaterThan(0); + } + ); + + it.each(modules)( + '%s has a reachable delete() method', + (_name, ModuleClass) => { + const methods = reachablePublicMethods(ModuleClass); + expect(methods.has('delete')).toBe(true); + } + ); + + it('BaseModule itself exposes the documented base surface', () => { + const surface = Object.getOwnPropertyNames(BaseModule.prototype).sort(); + // Stable, intentionally tiny. If BaseModule grows, the diff makes the + // intent explicit; if a method is renamed accidentally, this fails. + expect(surface).toEqual( + ['constructor', 'delete', 'forwardET', 'getInputShape'].sort() + ); + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/registryHookCompatibility.test.ts b/packages/react-native-executorch/__tests__/api/registryHookCompatibility.test.ts new file mode 100644 index 0000000000..2a06bdd11e --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/registryHookCompatibility.test.ts @@ -0,0 +1,66 @@ +import { models } from '../../src/constants/modelRegistry'; +import type { + ClassificationModelSources, + ImageEmbeddingsProps, + InstanceSegmentationModelSources, + LLMProps, + ObjectDetectionModelSources, + OCRProps, + PoseEstimationModelSources, + PrivacyFilterProps, + SemanticSegmentationModelSources, + SpeechToTextProps, + StyleTransferProps, + TextEmbeddingsProps, + TextToImageProps, + TextToSpeechModelConfig, + VADProps, +} from '../../src'; + +// Compile-time assertion: every registry accessor returns a config that is +// assignable to the corresponding hook's `model` prop type. If the registry +// drifts from the hook prop shape, tsc errors here naming the offending +// (accessor → prop) pair. +// +// One sample per category is enough — all accessors in a category go through +// the same `base`/`pair`/`variant` builders so their static return types are +// structurally identical. Add a row only when a new category lands. +// +// Generic hook props (ClassificationProps, etc.) wrap a source-of-truth +// `XxxModelSources` type, and the props' `model` field is `C`. We assert +// against the unwrapped `XxxModelSources` directly so the generic constraint +// can't collapse to `never`. + +function _assertRegistryAssignability() { + // computer vision + models.classification.efficientnet_v2_s() satisfies ClassificationModelSources; + models.object_detection.rf_detr_nano() satisfies ObjectDetectionModelSources; + models.pose_estimation.yolo26n() satisfies PoseEstimationModelSources; + models.semantic_segmentation.deeplab_v3_resnet50() satisfies SemanticSegmentationModelSources; + models.instance_segmentation.yolo26n() satisfies InstanceSegmentationModelSources; + models.style_transfer.candy() satisfies StyleTransferProps['model']; + models.image_embedding.clip_vit_base_patch32_image() satisfies ImageEmbeddingsProps['model']; + models.image_generation.bk_sdm_tiny_vpred_512() satisfies TextToImageProps['model']; + models.ocr.craft({ language: 'en' }) satisfies OCRProps['model']; + + // natural language processing + models.llm.qwen3_4b() satisfies LLMProps['model']; + models.privacy_filter.openai() satisfies PrivacyFilterProps['model']; + models.speech_to_text.whisper_tiny_en() satisfies SpeechToTextProps['model']; + models.text_embedding.all_minilm_l6_v2() satisfies TextEmbeddingsProps['model']; + models.vad.fsmn_vad() satisfies VADProps['model']; + + // TTS leafs return a TextToSpeechModelConfig directly (no `model:` wrapper + // — useTextToSpeech is the outlier that takes the config as a positional + // arg, tracked in #1202). + models.text_to_speech.kokoro.en_us.heart() satisfies TextToSpeechModelConfig; +} +// eslint-disable-next-line no-void +void _assertRegistryAssignability; + +describe('Registry → hook prop compatibility', () => { + it('every category sample is assignable to its hook prop (compile-time)', () => { + // The real assertion is the `satisfies` clause above, checked by tsc. + expect(typeof _assertRegistryAssignability).toBe('function'); + }); +}); diff --git a/packages/react-native-executorch/__tests__/api/ttsVoices.test.ts b/packages/react-native-executorch/__tests__/api/ttsVoices.test.ts new file mode 100644 index 0000000000..126e6e9a5a --- /dev/null +++ b/packages/react-native-executorch/__tests__/api/ttsVoices.test.ts @@ -0,0 +1,90 @@ +import * as Voices from '../../src/constants/tts/voices'; +import { URL_PREFIX } from '../../src/constants/versions'; + +// Voice variable-name region prefix → expected `phonemizerConfig.lang`. A +// voice constant exported under e.g. `KOKORO_FRENCH_*` is expected to carry +// `lang: 'fr'`. A mismatch is almost always a copy-paste bug, so we keep the +// map narrow and explicit. +const REGION_TO_LANG: Record = { + AMERICAN_ENGLISH: 'en-us', + BRITISH_ENGLISH: 'en-gb', + FRENCH: 'fr', + SPANISH: 'es', + ITALIAN: 'it', + PORTUGUESE: 'pt', + HINDI: 'hi', + POLISH: 'pl', + GERMAN: 'de', +}; + +type VoiceConfig = { + voiceSource: string; + phonemizerConfig: { + lang: string; + taggerSource?: string; + lexiconSource?: string; + neuralModelSource?: string; + }; + model: { modelName?: string }; +}; + +function regionOf(name: string): string | null { + for (const region of Object.keys(REGION_TO_LANG)) { + if (name.startsWith(`KOKORO_${region}_`)) return region; + } + return null; +} + +function getVoiceEntries(): Array<[string, VoiceConfig]> { + return Object.entries(Voices) + .filter(([name]) => name.startsWith('KOKORO_')) + .map(([name, value]) => [name, value as VoiceConfig]); +} + +describe('Kokoro voices', () => { + const voices = getVoiceEntries(); + + it('exports voices', () => { + expect(voices.length).toBeGreaterThan(0); + }); + + it.each(voices)('%s has a known region prefix', (name) => { + expect(regionOf(name)).not.toBeNull(); + }); + + it.each(voices)( + '%s phonemizerConfig.lang matches its region prefix', + (name, voice) => { + const region = regionOf(name); + if (!region) throw new Error(`No region for ${name}`); + expect(voice.phonemizerConfig.lang).toBe(REGION_TO_LANG[region]); + } + ); + + it.each(voices)( + '%s voiceSource points at the Kokoro voices directory', + (_name, voice) => { + expect(voice.voiceSource.startsWith(URL_PREFIX)).toBe(true); + expect(voice.voiceSource).toMatch(/\/voices\/[^/]+\.bin$/); + } + ); + + it.each(voices)( + '%s phonemizer URLs all live under the voice language directory', + (_name, voice) => { + const { lang, taggerSource, lexiconSource, neuralModelSource } = + voice.phonemizerConfig; + const expectedSegment = `/phonemizer/${lang}/`; + for (const url of [taggerSource, lexiconSource, neuralModelSource]) { + if (url === undefined) continue; + expect(url.startsWith(URL_PREFIX)).toBe(true); + expect(url).toContain(expectedSegment); + } + } + ); + + it.each(voices)('%s references a model with a modelName', (_name, voice) => { + expect(typeof voice.model.modelName).toBe('string'); + expect(voice.model.modelName?.length ?? 0).toBeGreaterThan(0); + }); +}); diff --git a/packages/react-native-executorch/__tests__/mocks/react-native.ts b/packages/react-native-executorch/__tests__/mocks/react-native.ts new file mode 100644 index 0000000000..6b2062b488 --- /dev/null +++ b/packages/react-native-executorch/__tests__/mocks/react-native.ts @@ -0,0 +1,21 @@ +// Minimal mock for the bits of `react-native` that the package imports at +// module-load time during these contract tests. Extend as new APIs are +// reached. + +export const Platform = { + OS: 'ios' as 'ios' | 'android' | 'web', + select: (specifics: { + ios?: T; + android?: T; + default?: T; + }): T | undefined => specifics.ios ?? specifics.default, +}; + +export const NativeModules: Record = {}; + +export const TurboModuleRegistry = { + get: () => null, + getEnforcing: () => { + throw new Error('TurboModuleRegistry not available in test env'); + }, +}; diff --git a/packages/react-native-executorch/__tests__/setup-globals.ts b/packages/react-native-executorch/__tests__/setup-globals.ts new file mode 100644 index 0000000000..d16117b2f0 --- /dev/null +++ b/packages/react-native-executorch/__tests__/setup-globals.ts @@ -0,0 +1,40 @@ +// src/index.ts checks for `global.loadXxx` JSI bindings and, if any are missing, +// calls into the native ETInstaller to install them. In Jest there are no JSI +// bindings, so we stub them out here to keep the import path side-effect-free. + +// Each `loadXxx` resolves to a minimal native-module stub that includes the +// methods modules consistently call: `unload` (for BaseModule.delete) and +// `generateFromFrame` (for VisionModule's worklet getter). Modules that need +// more can replace the stub in their own test. +const stub = (() => + Promise.resolve({ + unload: () => {}, + generateFromFrame: () => {}, + })) as unknown as () => Promise; +const g = globalThis as unknown as Record; + +const JSI_GLOBALS = [ + 'loadStyleTransfer', + 'loadSemanticSegmentation', + 'loadInstanceSegmentation', + 'loadTextToImage', + 'loadExecutorchModule', + 'loadClassification', + 'loadObjectDetection', + 'loadPoseEstimation', + 'loadTokenizerModule', + 'loadTextEmbeddings', + 'loadImageEmbeddings', + 'loadVAD', + 'loadLLM', + 'loadPrivacyFilter', + 'loadSpeechToText', + 'loadTextToSpeechKokoro', + 'loadOCR', + 'loadVerticalOCR', +]; + +for (const name of JSI_GLOBALS) { + g[name] = stub; +} +g.__rne_isEmulator = false; diff --git a/packages/react-native-executorch/jest.config.js b/packages/react-native-executorch/jest.config.js new file mode 100644 index 0000000000..bdd2b8e3da --- /dev/null +++ b/packages/react-native-executorch/jest.config.js @@ -0,0 +1,25 @@ +module.exports = { + rootDir: __dirname, + testEnvironment: 'node', + testMatch: ['/__tests__/**/*.test.ts?(x)'], + setupFiles: ['/__tests__/setup-globals.ts'], + moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json'], + transform: { + '^.+\\.(ts|tsx|js|jsx)$': [ + 'babel-jest', + { + babelrc: false, + configFile: false, + presets: [ + ['@babel/preset-env', { targets: { node: 'current' } }], + '@babel/preset-typescript', + ['@babel/preset-react', { runtime: 'automatic' }], + ], + }, + ], + }, + transformIgnorePatterns: ['/node_modules/(?!(@huggingface)/)'], + moduleNameMapper: { + '^react-native$': '/__tests__/mocks/react-native.ts', + }, +}; diff --git a/packages/react-native-executorch/package.json b/packages/react-native-executorch/package.json index 2aceb63d1f..2cfc4fdf84 100644 --- a/packages/react-native-executorch/package.json +++ b/packages/react-native-executorch/package.json @@ -35,6 +35,8 @@ "scripts": { "example": "yarn workspace react-native-executorch-example", "typecheck": "tsc --noEmit", + "typecheck:tests": "tsc --noEmit -p tsconfig.test.json", + "test": "jest", "lint": "eslint \"**/*.{js,ts,tsx}\"", "clean": "del-cli android/build example/android/build example/android/app/build example/ios/build lib", "prepare": "bob build", diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts b/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts index f5ca97fd53..5da8a9de28 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts @@ -77,7 +77,7 @@ export const useTextToImage = ({ preventLoad, ]); - const generate = async ( + const forward = async ( input: string, imageSize?: number, numSteps?: number, @@ -106,7 +106,8 @@ export const useTextToImage = ({ isGenerating, downloadProgress, error, - generate, + forward, + generate: forward, interrupt, }; }; diff --git a/packages/react-native-executorch/src/hooks/general/useExecutorch.ts b/packages/react-native-executorch/src/hooks/general/useExecutorch.ts new file mode 100644 index 0000000000..0cc1f8fdfb --- /dev/null +++ b/packages/react-native-executorch/src/hooks/general/useExecutorch.ts @@ -0,0 +1,39 @@ +import { ExecutorchModule } from '../../modules/general/ExecutorchModule'; +import { + ExecutorchModuleProps, + ExecutorchModuleType, +} from '../../types/executorchModule'; +import { TensorPtr } from '../../types/common'; +import { useModuleFactory } from '../useModuleFactory'; + +/** + * React hook for managing an arbitrary Executorch module instance. + * @category Hooks + * @param props - Configuration object containing `modelSource` and optional `preventLoad` flag. + * @returns Ready to use Executorch module. + */ +export const useExecutorch = ({ + modelSource, + preventLoad = false, +}: ExecutorchModuleProps): ExecutorchModuleType => { + const { error, isReady, isGenerating, downloadProgress, runForward } = + useModuleFactory({ + factory: (source, onProgress) => + ExecutorchModule.fromModelSource(source, onProgress), + config: modelSource, + deps: [modelSource], + preventLoad, + }); + + const forward = (inputTensor: TensorPtr[]) => + runForward((inst) => inst.forward(inputTensor)); + + return { error, isReady, isGenerating, downloadProgress, forward }; +}; + +/** + * @deprecated Use `useExecutorch` instead. `useExecutorchModule` is kept as a + * temporary alias for backward compatibility and will be removed in a future + * release. + */ +export const useExecutorchModule = useExecutorch; diff --git a/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts b/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts deleted file mode 100644 index f41524fc4d..0000000000 --- a/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { ExecutorchModule } from '../../modules/general/ExecutorchModule'; -import { - ExecutorchModuleProps, - ExecutorchModuleType, -} from '../../types/executorchModule'; -import { useModule } from '../useModule'; - -/** - * React hook for managing an arbitrary Executorch module instance. - * @category Hooks - * @param executorchModuleProps - Configuration object containing `modelSource` and optional `preventLoad` flag. - * @returns Ready to use Executorch module. - */ -export const useExecutorchModule = ({ - modelSource, - preventLoad = false, -}: ExecutorchModuleProps): ExecutorchModuleType => - useModule({ - module: ExecutorchModule, - model: modelSource, - preventLoad, - }); diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts index 53ca98b0b1..409f217dcb 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts @@ -3,24 +3,41 @@ import { TextToSpeechModule } from '../../modules/natural_language_processing/Te import { TextToSpeechInput, TextToSpeechModelConfig, + TextToSpeechProps, TextToSpeechStreamingInput, TextToSpeechType, } from '../../types/tts'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; +/** + * @deprecated Pass a single object argument: + * `useTextToSpeech({ model, preventLoad })`. The two-argument form is kept as + * a temporary alias for backward compatibility and will be removed in a + * future release. + * @param model - The Kokoro voice / model bundle to load. + * @param options - Optional flags; currently only `preventLoad`. + * @returns Ready to use Text to Speech model. + */ +export function useTextToSpeech( + model: TextToSpeechModelConfig, + options?: { preventLoad?: boolean } +): TextToSpeechType; /** * React hook for managing Text to Speech instance. * @category Hooks - * @param model - Configuration object containing model config. - * @param options - Additional options for the hook. - * @param options.preventLoad - If true, prevents the model from loading automatically on initialization. + * @param props - Configuration object containing `model` (voice + Kokoro bundle) and optional `preventLoad` flag. * @returns Ready to use Text to Speech model. */ -export const useTextToSpeech = ( - model: TextToSpeechModelConfig, - { preventLoad = false }: { preventLoad?: boolean } = {} -): TextToSpeechType => { +export function useTextToSpeech(props: TextToSpeechProps): TextToSpeechType; +export function useTextToSpeech( + arg1: TextToSpeechProps | TextToSpeechModelConfig, + arg2?: { preventLoad?: boolean } +): TextToSpeechType { + const { model, preventLoad = false }: TextToSpeechProps = + 'voiceSource' in arg1 + ? { model: arg1, preventLoad: arg2?.preventLoad } + : arg1; const [error, setError] = useState(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); @@ -160,4 +177,4 @@ export const useTextToSpeech = ( streamStop, downloadProgress, }; -}; +} diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts index d1ca99c1a4..365bb08c67 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts @@ -1,70 +1,36 @@ -import { useEffect, useState } from 'react'; import { TokenizerModule } from '../../modules/natural_language_processing/TokenizerModule'; -import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; -import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; import { TokenizerProps, TokenizerType } from '../../types/tokenizer'; +import { useModuleFactory } from '../useModuleFactory'; /** * React hook for managing a Tokenizer instance. * @category Hooks - * @param tokenizerProps - Configuration object containing `tokenizer` source and optional `preventLoad` flag. + * @param props - Configuration object containing `tokenizer` source and optional `preventLoad` flag. * @returns Ready to use Tokenizer model. */ export const useTokenizer = ({ tokenizer, preventLoad = false, }: TokenizerProps): TokenizerType => { - const [error, setError] = useState(null); - const [isReady, setIsReady] = useState(false); - const [isGenerating, setIsGenerating] = useState(false); - const [downloadProgress, setDownloadProgress] = useState(0); - const [tokenizerInstance] = useState(() => new TokenizerModule()); - - useEffect(() => { - if (preventLoad) return; - (async () => { - setDownloadProgress(0); - setError(null); - try { - setIsReady(false); - await tokenizerInstance.load( - { tokenizerSource: tokenizer.tokenizerSource }, - setDownloadProgress - ); - setIsReady(true); - } catch (err) { - setError(parseUnknownError(err)); - } - })(); - }, [tokenizerInstance, tokenizer.tokenizerSource, preventLoad]); - - const stateWrapper = Promise>(fn: T) => { - return (...args: Parameters): Promise>> => { - if (!isReady) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - 'The model is currently not loaded. Please load the model before calling this function.' - ); - if (isGenerating) - throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating); - try { - setIsGenerating(true); - return fn.apply(tokenizerInstance, args); - } finally { - setIsGenerating(false); - } - }; - }; + const { error, isReady, isGenerating, downloadProgress, runForward } = + useModuleFactory({ + factory: (config, onProgress) => + TokenizerModule.fromModelName(config, onProgress), + config: { tokenizerSource: tokenizer.tokenizerSource }, + deps: [tokenizer.tokenizerSource], + preventLoad, + }); return { error, isReady, isGenerating, downloadProgress, - decode: stateWrapper(TokenizerModule.prototype.decode), - encode: stateWrapper(TokenizerModule.prototype.encode), - getVocabSize: stateWrapper(TokenizerModule.prototype.getVocabSize), - idToToken: stateWrapper(TokenizerModule.prototype.idToToken), - tokenToId: stateWrapper(TokenizerModule.prototype.tokenToId), + decode: (tokens, skipSpecialTokens) => + runForward((inst) => inst.decode(tokens, skipSpecialTokens)), + encode: (input) => runForward((inst) => inst.encode(input)), + getVocabSize: () => runForward((inst) => inst.getVocabSize()), + idToToken: (tokenId) => runForward((inst) => inst.idToToken(tokenId)), + tokenToId: (token) => runForward((inst) => inst.tokenToId(token)), }; }; diff --git a/packages/react-native-executorch/src/hooks/useModule.ts b/packages/react-native-executorch/src/hooks/useModule.ts deleted file mode 100644 index c26c93c361..0000000000 --- a/packages/react-native-executorch/src/hooks/useModule.ts +++ /dev/null @@ -1,148 +0,0 @@ -import { useEffect, useState } from 'react'; -import { RnExecutorchErrorCode } from '../errors/ErrorCodes'; -import { RnExecutorchError, parseUnknownError } from '../errors/errorUtils'; - -type RunOnFrame = M extends { runOnFrame: infer R } ? R : never; - -interface Module { - load: (...args: any[]) => Promise; - forward: (...args: any[]) => Promise; - delete: () => void; -} - -interface ModuleConstructor { - new (): M; -} - -export const useModule = < - M extends Module, - LoadArgs extends Parameters, - ForwardArgs extends Parameters, - ForwardReturn extends Awaited>, ->({ - module, - model, - preventLoad = false, -}: { - module: ModuleConstructor; - model: LoadArgs[0]; - preventLoad?: boolean; -}) => { - const [error, setError] = useState(null); - const [isReady, setIsReady] = useState(false); - const [isGenerating, setIsGenerating] = useState(false); - const [downloadProgress, setDownloadProgress] = useState(0); - const [moduleInstance] = useState(() => new module()); - const [runOnFrame, setRunOnFrame] = useState | null>(null); - - useEffect(() => { - if (preventLoad) return; - - let isMounted = true; - - (async () => { - setDownloadProgress(0); - setError(null); - try { - setIsReady(false); - await moduleInstance.load(model, (progress: number) => { - if (isMounted) setDownloadProgress(progress); - }); - if (isMounted) setIsReady(true); - - // VisionCamera worklets run on a separate JS thread and can only capture - // serializable values (plain functions, primitives). The module instance - // is a class object and is not serializable, so accessing runOnFrame - // directly inside a worklet would fail at runtime. - // - // By extracting the method and storing it in React state, it becomes a - // standalone function reference that the worklet thread can capture and - // call safely. - // - // Note: setState(fn) triggers React's updater form — it calls fn(prevState) - // and stores the return value, not fn itself. Since runOnFrame is a function, - // we wrap it: setState(() => worklet) so React stores the worklet as the - // state value rather than invoking it. - if ('runOnFrame' in moduleInstance) { - const worklet = moduleInstance.runOnFrame as RunOnFrame; - if (worklet) { - setRunOnFrame(() => worklet); - } - } - } catch (err) { - if (isMounted) setError(parseUnknownError(err)); - } - })(); - - return () => { - isMounted = false; - setIsReady(false); - setRunOnFrame(null); - moduleInstance.delete(); - }; - - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [moduleInstance, ...Object.values(model), preventLoad]); - - const forward = async (...input: ForwardArgs): Promise => { - if (!isReady) - throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded); - if (isGenerating) - throw new RnExecutorchError(RnExecutorchErrorCode.ModelGenerating); - try { - setIsGenerating(true); - return await moduleInstance.forward(...input); - } finally { - setIsGenerating(false); - } - }; - - return { - /** - * Contains the error message if the model failed to load. - */ - error, - - /** - * Indicates whether the model is ready. - */ - isReady, - - /** - * Indicates whether the model is currently generating a response. - */ - isGenerating, - - /** - * Represents the download progress as a value between 0 and 1, indicating the extent of the model file retrieval. - */ - downloadProgress, - forward, - - /** - * Synchronous worklet function for real-time VisionCamera frame processing. - * Automatically handles native buffer extraction and cleanup. - * - * Only available for Computer Vision modules that support real-time frame processing - * (e.g., ObjectDetection, Classification, ImageSegmentation). - * Returns `null` if the module doesn't implement frame processing. - * - * **Use this for VisionCamera frame processing in worklets.** - * For async processing, use `forward()` instead. - * @example - * ```typescript - * const { runOnFrame } = useObjectDetection({ model: MODEL }); - * - * const frameOutput = useFrameOutput({ - * onFrame(frame) { - * 'worklet'; - * if (!runOnFrame) return; - * const detections = runOnFrame(frame, 0.5); - * frame.dispose(); - * } - * }); - * ``` - */ - runOnFrame, - }; -}; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 1f190d41f5..0c4c71bb02 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -183,7 +183,7 @@ export * from './hooks/natural_language_processing/useTextEmbeddings'; export * from './hooks/natural_language_processing/useTokenizer'; export * from './hooks/natural_language_processing/useVAD'; -export * from './hooks/general/useExecutorchModule'; +export * from './hooks/general/useExecutorch'; // modules export * from './modules/computer_vision/ClassificationModule'; diff --git a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts index 69faca6980..06f4da8e01 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts @@ -3,15 +3,17 @@ import { ResourceSource } from '../../types/common'; import { OCRDetection, OCRLanguage, OCRModelName } from '../../types/ocr'; import { Logger } from '../../common/Logger'; import { parseUnknownError } from '../../errors/errorUtils'; +import { BaseModule } from '../BaseModule'; /** * Module for Optical Character Recognition (OCR) tasks. * @category Typescript API */ -export class OCRModule { +export class OCRModule extends BaseModule { private controller: OCRController; private constructor(controller: OCRController) { + super(); this.controller = controller; } diff --git a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts index f9a818d4dc..e81e6b309f 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts @@ -3,15 +3,17 @@ import { VerticalOCRController } from '../../controllers/VerticalOCRController'; import { parseUnknownError } from '../../errors/errorUtils'; import { ResourceSource } from '../../types/common'; import { OCRDetection, OCRLanguage, OCRModelName } from '../../types/ocr'; +import { BaseModule } from '../BaseModule'; /** * Module for Vertical Optical Character Recognition (Vertical OCR) tasks. * @category Typescript API */ -export class VerticalOCRModule { +export class VerticalOCRModule extends BaseModule { private controller: VerticalOCRController; private constructor(controller: VerticalOCRController) { + super(); this.controller = controller; } diff --git a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts index ed98b92698..ea66958638 100644 --- a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts +++ b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts @@ -11,25 +11,31 @@ import { Logger } from '../../common/Logger'; * @category Typescript API */ export class ExecutorchModule extends BaseModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } + /** - * Loads the model, where `modelSource` is a string, number, or object that specifies the location of the model binary. - * Optionally accepts a download progress callback. + * Creates an Executorch instance from a model binary. * @param modelSource - Source of the model to be loaded. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `ExecutorchModule` instance. */ - async load( + static async fromModelSource( modelSource: ResourceSource, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, + onDownloadProgress, modelSource ); if (!paths?.[0]) { throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted); } - this.nativeModule = await global.loadExecutorchModule(paths[0]); + const nativeModule = await global.loadExecutorchModule(paths[0]); + return new ExecutorchModule(nativeModule); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts index bdb5ada699..3c2dc2bf2d 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts @@ -9,12 +9,13 @@ import { LLMTool, Message, } from '../../types/llm'; +import { BaseModule } from '../BaseModule'; /** * Module for managing a Large Language Model (LLM) instance. * @category Typescript API */ -export class LLMModule { +export class LLMModule extends BaseModule { private controller: LLMController; private constructor({ @@ -24,6 +25,7 @@ export class LLMModule { tokenCallback?: (token: string) => void; messageHistoryCallback?: (messageHistory: Message[]) => void; } = {}) { + super(); this.controller = new LLMController({ tokenCallback, messageHistoryCallback, diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts index f7ffe52f1c..81709c0c49 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts @@ -11,19 +11,20 @@ import { ResourceSource } from '../../types/common'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; import { Logger } from '../../common/Logger'; +import { BaseModule } from '../BaseModule'; /** * Module for Speech to Text (STT) functionalities. * @category Typescript API */ -export class SpeechToTextModule { - private nativeModule: any; +export class SpeechToTextModule extends BaseModule { private modelConfig: SpeechToTextModelConfig; private constructor( nativeModule: unknown, modelConfig: SpeechToTextModelConfig ) { + super(); this.nativeModule = nativeModule; this.modelConfig = modelConfig; } diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts index 6b6695f1f7..7d548e5f8b 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts @@ -8,16 +8,17 @@ import { TextToSpeechStreamingInput, } from '../../types/tts'; import { Logger } from '../../common/Logger'; +import { BaseModule } from '../BaseModule'; /** * Module for Text to Speech (TTS) functionalities. * @category Typescript API */ -export class TextToSpeechModule { - private nativeModule: any; +export class TextToSpeechModule extends BaseModule { private isStreaming: boolean = false; private constructor(nativeModule: unknown) { + super(); this.nativeModule = nativeModule; } diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts index aec377b91e..9ddf4b5087 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts @@ -3,37 +3,39 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { Logger } from '../../common/Logger'; +import { BaseModule } from '../BaseModule'; /** * Module for Tokenizer functionalities. * @category Typescript API */ -export class TokenizerModule { - /** - * Native module instance - */ - nativeModule: any; +export class TokenizerModule extends BaseModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } /** - * Loads the tokenizer from the specified source. - * `tokenizerSource` is a string that points to the location of the tokenizer JSON file. - * @param tokenizer - Object containing `tokenizerSource`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * Creates a Tokenizer instance for the provided tokenizer JSON source. + * @param namedSources - Object containing `tokenizerSource` — a fetchable resource pointing at the tokenizer JSON. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `TokenizerModule` instance. */ - async load( - tokenizer: { tokenizerSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + static async fromModelName( + namedSources: { tokenizerSource: ResourceSource }, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - tokenizer.tokenizerSource + onDownloadProgress, + namedSources.tokenizerSource ); const path = paths?.[0]; if (!path) { throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted); } - this.nativeModule = await global.loadTokenizerModule(path); + const nativeModule = await global.loadTokenizerModule(path); + return new TokenizerModule(nativeModule); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); diff --git a/packages/react-native-executorch/src/types/executorchModule.ts b/packages/react-native-executorch/src/types/executorchModule.ts index e058d13eb7..a6ea826ca3 100644 --- a/packages/react-native-executorch/src/types/executorchModule.ts +++ b/packages/react-native-executorch/src/types/executorchModule.ts @@ -2,7 +2,7 @@ import { ResourceSource, TensorPtr } from '../types/common'; import { RnExecutorchError } from '../errors/errorUtils'; /** - * Props for the `useExecutorchModule` hook. + * Props for the `useExecutorch` hook. * @category Types * @property {ResourceSource} modelSource - The source of the ExecuTorch model binary. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. @@ -13,7 +13,7 @@ export interface ExecutorchModuleProps { } /** - * Return type for the `useExecutorchModule` hook. + * Return type for the `useExecutorch` hook. * Manages the state and core execution methods for a general ExecuTorch model. * @category Types */ diff --git a/packages/react-native-executorch/src/types/tti.ts b/packages/react-native-executorch/src/types/tti.ts index 91f0f2dc55..b50c2a4759 100644 --- a/packages/react-native-executorch/src/types/tti.ts +++ b/packages/react-native-executorch/src/types/tti.ts @@ -84,6 +84,17 @@ export interface TextToImageType { * @returns A Promise that resolves to a `file://` URI pointing to the generated PNG on the device, or an empty string if generation was interrupted. * @throws {RnExecutorchError} If the model is not loaded or is currently generating another image. */ + forward: ( + input: string, + imageSize?: number, + numSteps?: number, + seed?: number + ) => Promise; + + /** + * @deprecated Use `forward` instead. `generate` is kept as a temporary alias + * for backward compatibility and will be removed in a future release. + */ generate: ( input: string, imageSize?: number, diff --git a/packages/react-native-executorch/src/types/tts.ts b/packages/react-native-executorch/src/types/tts.ts index a2dbd1905f..ea4dd834b4 100644 --- a/packages/react-native-executorch/src/types/tts.ts +++ b/packages/react-native-executorch/src/types/tts.ts @@ -77,6 +77,17 @@ export interface TextToSpeechModelConfig { phonemizerConfig: TextToSpeechPhonemizerConfig; } +/** + * Props for the `useTextToSpeech` hook. + * @category Types + * @property {TextToSpeechModelConfig} model - The Kokoro voice / model bundle to load. + * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. + */ +export interface TextToSpeechProps { + model: TextToSpeechModelConfig; + preventLoad?: boolean; +} + /** * Text to Speech module input definition * @category Types diff --git a/packages/react-native-executorch/tsconfig.test.json b/packages/react-native-executorch/tsconfig.test.json new file mode 100644 index 0000000000..0806fd4813 --- /dev/null +++ b/packages/react-native-executorch/tsconfig.test.json @@ -0,0 +1,11 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "rootDir": ".", + "composite": false, + "noEmit": true, + "types": ["jest", "react", "node"] + }, + "include": ["src", "__tests__"], + "exclude": ["node_modules", "lib"] +} diff --git a/skills/react-native-executorch/SKILL.md b/skills/react-native-executorch/SKILL.md index 6771e1ad0c..1dfdbd6e25 100644 --- a/skills/react-native-executorch/SKILL.md +++ b/skills/react-native-executorch/SKILL.md @@ -1,6 +1,6 @@ --- name: react-native-executorch -description: Build on-device AI features in React Native and Expo apps with React Native ExecuTorch. Use when adding AI to a mobile app without cloud dependencies — chatbots and assistants, image classification, object detection, OCR, semantic or instance segmentation, style transfer, image generation, pose estimation, speech-to-text, text-to-speech, voice activity detection, semantic search with embeddings, tokenization, privacy filtering / PII redaction, or vision-language image understanding. Also use when the user mentions offline AI, on-device ML, privacy-preserving AI, reducing cloud API cost or latency, running models locally on mobile, or downloading and managing ML models. Covers initExecutorch, every public hook (useLLM, useClassification, useObjectDetection, useOCR, useVerticalOCR, useSemanticSegmentation, useInstanceSegmentation, useStyleTransfer, useTextToImage, useImageEmbeddings, usePoseEstimation, useSpeechToText, useTextToSpeech, useVAD, useTextEmbeddings, useTokenizer, usePrivacyFilter, useExecutorchModule), tool calling, structured output, VLMs, model loading via Expo or bare resource-fetcher adapters, and error handling. +description: Build on-device AI features in React Native and Expo apps with React Native ExecuTorch. Use when adding AI to a mobile app without cloud dependencies — chatbots and assistants, image classification, object detection, OCR, semantic or instance segmentation, style transfer, image generation, pose estimation, speech-to-text, text-to-speech, voice activity detection, semantic search with embeddings, tokenization, privacy filtering / PII redaction, or vision-language image understanding. Also use when the user mentions offline AI, on-device ML, privacy-preserving AI, reducing cloud API cost or latency, running models locally on mobile, or downloading and managing ML models. Covers initExecutorch, every public hook (useLLM, useClassification, useObjectDetection, useOCR, useVerticalOCR, useSemanticSegmentation, useInstanceSegmentation, useStyleTransfer, useTextToImage, useImageEmbeddings, usePoseEstimation, useSpeechToText, useTextToSpeech, useVAD, useTextEmbeddings, useTokenizer, usePrivacyFilter, useExecutorch), tool calling, structured output, VLMs, model loading via Expo or bare resource-fetcher adapters, and error handling. --- # React Native ExecuTorch @@ -46,7 +46,7 @@ What does the feature need? │ └── react-native-rag (sibling library) → see setup.md │ └── Custom `.pte` model not covered by a dedicated hook? - └── useExecutorchModule → see setup.md + └── useExecutorch → see setup.md ``` ## Critical Rules @@ -87,51 +87,51 @@ Full setup, Metro config for bundled `.pte` files, custom adapters, model-loadin ## Hook Quick Reference -| Hook | Purpose | Reference | -|---|---|---| -| `useLLM` | Text generation, chat, tool calling, VLM | [llm.md](./references/llm.md) | -| `useClassification` | Image categorisation | [vision.md](./references/vision.md) | -| `useObjectDetection` | Bounding-box detection (YOLO26, RF-DETR, SSDLite) | [vision.md](./references/vision.md) | -| `useSemanticSegmentation` | Per-pixel class segmentation | [vision.md](./references/vision.md) | -| `useInstanceSegmentation` | Per-instance segmentation | [vision.md](./references/vision.md) | -| `usePoseEstimation` | COCO 17-keypoint human pose | [vision.md](./references/vision.md) | -| `useStyleTransfer` | Artistic image filters | [vision.md](./references/vision.md) | -| `useTextToImage` | Stable Diffusion image generation | [vision.md](./references/vision.md) | -| `useImageEmbeddings` | CLIP image embeddings | [vision.md](./references/vision.md) | -| `useOCR` | Horizontal text OCR | [vision.md](./references/vision.md) | -| `useVerticalOCR` | Vertical text OCR (experimental, CJK) | [vision.md](./references/vision.md) | -| `useTextEmbeddings` | Sentence embeddings for similarity / RAG | [vision.md](./references/vision.md) | -| `useSpeechToText` | Whisper transcription (batch + streaming) | [speech.md](./references/speech.md) | -| `useTextToSpeech` | Kokoro TTS (batch + streaming, phoneme input) | [speech.md](./references/speech.md) | -| `useVAD` | FSMN voice activity detection | [speech.md](./references/speech.md) | -| `useTokenizer` | HuggingFace-compatible tokenization | [setup.md](./references/setup.md) | -| `usePrivacyFilter` | On-device PII / privacy redaction | [setup.md](./references/setup.md) | -| `useExecutorchModule` | Custom `.pte` model inference | [setup.md](./references/setup.md) | +| Hook | Purpose | Reference | +| ------------------------- | ------------------------------------------------- | ----------------------------------- | +| `useLLM` | Text generation, chat, tool calling, VLM | [llm.md](./references/llm.md) | +| `useClassification` | Image categorisation | [vision.md](./references/vision.md) | +| `useObjectDetection` | Bounding-box detection (YOLO26, RF-DETR, SSDLite) | [vision.md](./references/vision.md) | +| `useSemanticSegmentation` | Per-pixel class segmentation | [vision.md](./references/vision.md) | +| `useInstanceSegmentation` | Per-instance segmentation | [vision.md](./references/vision.md) | +| `usePoseEstimation` | COCO 17-keypoint human pose | [vision.md](./references/vision.md) | +| `useStyleTransfer` | Artistic image filters | [vision.md](./references/vision.md) | +| `useTextToImage` | Stable Diffusion image generation | [vision.md](./references/vision.md) | +| `useImageEmbeddings` | CLIP image embeddings | [vision.md](./references/vision.md) | +| `useOCR` | Horizontal text OCR | [vision.md](./references/vision.md) | +| `useVerticalOCR` | Vertical text OCR (experimental, CJK) | [vision.md](./references/vision.md) | +| `useTextEmbeddings` | Sentence embeddings for similarity / RAG | [vision.md](./references/vision.md) | +| `useSpeechToText` | Whisper transcription (batch + streaming) | [speech.md](./references/speech.md) | +| `useTextToSpeech` | Kokoro TTS (batch + streaming, phoneme input) | [speech.md](./references/speech.md) | +| `useVAD` | FSMN voice activity detection | [speech.md](./references/speech.md) | +| `useTokenizer` | HuggingFace-compatible tokenization | [setup.md](./references/setup.md) | +| `usePrivacyFilter` | On-device PII / privacy redaction | [setup.md](./references/setup.md) | +| `useExecutorch` | Custom `.pte` model inference | [setup.md](./references/setup.md) | Every hook also has a non-React `Module` counterpart (e.g. `LLMModule.fromModelName(...)`, `ClassificationModule.fromModelName(...)`) for use outside React components. ## Common Pitfalls -| Symptom | Likely cause | Fix | -|---|---|---| -| `ResourceFetcherAdapterNotInitialized` | `initExecutorch` not called | Call it at app entry with an adapter | -| `ModuleNotLoaded` | Inference before model finished loading | Gate calls on `isReady` | -| `MemoryAllocationFailed` on launch | Model too large for device | Switch to `_QUANTIZED` variant or smaller parameter count | -| App crashes on screen navigation | Unmount during active generation | `llm.interrupt()` and await `isGenerating === false` | -| Whisper produces garbled text | Wrong sample rate | Decode audio at 16 kHz mono | -| TTS output sounds chipmunked | Playback context at wrong rate | Create `AudioContext({ sampleRate: 24000 })` | -| Build fails on iOS simulator (release) | Simulator lacks Metal APIs | Build release on real device | +| Symptom | Likely cause | Fix | +| -------------------------------------- | --------------------------------------- | --------------------------------------------------------- | +| `ResourceFetcherAdapterNotInitialized` | `initExecutorch` not called | Call it at app entry with an adapter | +| `ModuleNotLoaded` | Inference before model finished loading | Gate calls on `isReady` | +| `MemoryAllocationFailed` on launch | Model too large for device | Switch to `_QUANTIZED` variant or smaller parameter count | +| App crashes on screen navigation | Unmount during active generation | `llm.interrupt()` and await `isGenerating === false` | +| Whisper produces garbled text | Wrong sample rate | Decode audio at 16 kHz mono | +| TTS output sounds chipmunked | Playback context at wrong rate | Create `AudioContext({ sampleRate: 24000 })` | +| Build fails on iOS simulator (release) | Simulator lacks Metal APIs | Build release on real device | Full error code list and recovery patterns: [setup.md](./references/setup.md). ## References -| File | When to read | -|---|---| -| [llm.md](./references/llm.md) | `useLLM` functional + managed modes, tool calling, structured output (JSON Schema / Zod), interrupting, vision-language models, generation config | -| [vision.md](./references/vision.md) | Image classification, object detection, semantic + instance segmentation, pose estimation, OCR (horizontal + vertical), style transfer, text-to-image, image + text embeddings | -| [speech.md](./references/speech.md) | Speech-to-text (Whisper batch + streaming with timestamps), text-to-speech (Kokoro batch + streaming, phoneme input, voice catalogue), voice activity detection, audio sample-rate requirements | -| [setup.md](./references/setup.md) | `initExecutorch`, Expo / bare resource-fetcher adapters, model loading strategies, Metro config, error codes and recovery, `useExecutorchModule` for custom `.pte` models, `useTokenizer`, `usePrivacyFilter`, full model catalogue | +| File | When to read | +| ----------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [llm.md](./references/llm.md) | `useLLM` functional + managed modes, tool calling, structured output (JSON Schema / Zod), interrupting, vision-language models, generation config | +| [vision.md](./references/vision.md) | Image classification, object detection, semantic + instance segmentation, pose estimation, OCR (horizontal + vertical), style transfer, text-to-image, image + text embeddings | +| [speech.md](./references/speech.md) | Speech-to-text (Whisper batch + streaming with timestamps), text-to-speech (Kokoro batch + streaming, phoneme input, voice catalogue), voice activity detection, audio sample-rate requirements | +| [setup.md](./references/setup.md) | `initExecutorch`, Expo / bare resource-fetcher adapters, model loading strategies, Metro config, error codes and recovery, `useExecutorch` for custom `.pte` models, `useTokenizer`, `usePrivacyFilter`, full model catalogue | ## External Resources diff --git a/skills/react-native-executorch/references/setup.md b/skills/react-native-executorch/references/setup.md index 987e0a677d..f1fddc4d79 100644 --- a/skills/react-native-executorch/references/setup.md +++ b/skills/react-native-executorch/references/setup.md @@ -99,11 +99,18 @@ How large is the model? `models..({ quant?, backend? })` — typed accessors that resolve to the right URL and backend per platform. Default is the quantized variant when one is published; iOS prefers CoreML, Android prefers XNNPACK, for multi-backend models. ```tsx -import { useLLM, useObjectDetection, useOCR, models } from 'react-native-executorch'; - -useLLM({ model: models.llm.llama3_2_3b() }); // platform default, quantized +import { + useLLM, + useObjectDetection, + useOCR, + models, +} from 'react-native-executorch'; + +useLLM({ model: models.llm.llama3_2_3b() }); // platform default, quantized useLLM({ model: models.llm.llama3_2_3b({ quant: false }) }); // full precision -useObjectDetection({ model: models.object_detection.rf_detr_nano({ backend: 'xnnpack' }) }); +useObjectDetection({ + model: models.object_detection.rf_detr_nano({ backend: 'xnnpack' }), +}); useOCR({ model: models.ocr.craft({ language: 'en' }) }); ``` @@ -124,7 +131,7 @@ Hooks expose `downloadProgress` (0–1): ```tsx const llm = useLLM({ model: models.llm.llama3_2_1b() }); -{Math.round(llm.downloadProgress * 100)}% +{Math.round(llm.downloadProgress * 100)}%; ``` --- @@ -152,7 +159,9 @@ await ExpoResourceFetcher.cancelFetching('https://…/model.pte'); const files = await ExpoResourceFetcher.listDownloadedFiles(); const models = await ExpoResourceFetcher.listDownloadedModels(); -const bytes = await ExpoResourceFetcher.getFilesTotalSize('https://…/model.pte'); +const bytes = await ExpoResourceFetcher.getFilesTotalSize( + 'https://…/model.pte' +); await ExpoResourceFetcher.deleteResources('https://…/model.pte'); ``` @@ -166,25 +175,28 @@ Downloaded files are stored in the app's documents directory. All errors inherit from `RnExecutorchError` with a `code` from `RnExecutorchErrorCode`. For the full table, webfetch [Error Handling](https://docs.swmansion.com/react-native-executorch/docs/utilities/error-handling). -| Error code | When | Recovery | -|---|---|---| -| `ResourceFetcherAdapterNotInitialized` | Any API used before `initExecutorch()` | Call `initExecutorch({ resourceFetcher })` at app entry | -| `ModuleNotLoaded` | Inference before `isReady === true` | Gate on `isReady` | -| `ModelGenerating` | New inference while one is running | Wait or call `interrupt()` | -| `InvalidConfig` | Bad params (e.g. `topp > 1`) | Validate config | -| `ResourceFetcherDownloadFailed` | Network error during download | Retry with backoff | -| `MemoryAllocationFailed` | Model too large for device | Switch to a smaller / quantized accessor | -| `DownloadInterrupted` | Download did not complete | Retry | -| `StreamingNotStarted` | `streamInsert` before `stream()` is active | Start `stream()` first | -| `StreamingInProgress` | `stream()` while one is active | Wait or call `streamStop()` | -| `InvalidUserInput` | Empty / malformed input | Validate before calling | -| `FileReadFailed` | Bad image path, unsupported format | Verify path and format | -| `LanguageNotSupported` | OCR / multilingual model asked for an unpublished language | Use a supported code | +| Error code | When | Recovery | +| -------------------------------------- | ---------------------------------------------------------- | ------------------------------------------------------- | +| `ResourceFetcherAdapterNotInitialized` | Any API used before `initExecutorch()` | Call `initExecutorch({ resourceFetcher })` at app entry | +| `ModuleNotLoaded` | Inference before `isReady === true` | Gate on `isReady` | +| `ModelGenerating` | New inference while one is running | Wait or call `interrupt()` | +| `InvalidConfig` | Bad params (e.g. `topp > 1`) | Validate config | +| `ResourceFetcherDownloadFailed` | Network error during download | Retry with backoff | +| `MemoryAllocationFailed` | Model too large for device | Switch to a smaller / quantized accessor | +| `DownloadInterrupted` | Download did not complete | Retry | +| `StreamingNotStarted` | `streamInsert` before `stream()` is active | Start `stream()` first | +| `StreamingInProgress` | `stream()` while one is active | Wait or call `streamStop()` | +| `InvalidUserInput` | Empty / malformed input | Validate before calling | +| `FileReadFailed` | Bad image path, unsupported format | Verify path and format | +| `LanguageNotSupported` | OCR / multilingual model asked for an unpublished language | Use a supported code | ### Pattern ```tsx -import { RnExecutorchError, RnExecutorchErrorCode } from 'react-native-executorch'; +import { + RnExecutorchError, + RnExecutorchErrorCode, +} from 'react-native-executorch'; try { await model.forward(imageUri); @@ -211,9 +223,9 @@ try { --- -## Custom models — `useExecutorchModule` +## Custom models — `useExecutorch` -For `.pte` models not covered by a dedicated hook, use `useExecutorchModule` to run arbitrary tensor I/O. +For `.pte` models not covered by a dedicated hook, use `useExecutorch` to run arbitrary tensor I/O. ### Exporting @@ -224,9 +236,9 @@ For `.pte` models not covered by a dedicated hook, use `useExecutorchModule` to ### Running ```tsx -import { useExecutorchModule, ScalarType } from 'react-native-executorch'; +import { useExecutorch, ScalarType } from 'react-native-executorch'; -const m = useExecutorchModule({ +const m = useExecutorch({ modelSource: require('../assets/custom_model.pte'), }); @@ -253,7 +265,9 @@ For services or non-React contexts, use the module class directly via `fromModel ```ts import { ClassificationModule, models } from 'react-native-executorch'; -const m = await ClassificationModule.fromModelName(models.classification.efficientnet_v2_s()); +const m = await ClassificationModule.fromModelName( + models.classification.efficientnet_v2_s() +); ``` Every hook has a corresponding module: `LLMModule`, `ObjectDetectionModule`, `OCRModule`, `SpeechToTextModule`, `TextToSpeechModule`, etc. @@ -267,7 +281,9 @@ HuggingFace-compatible BPE / WordPiece tokenizer. Mostly useful for counting tok ```tsx import { useTokenizer, models } from 'react-native-executorch'; -const tokenizer = useTokenizer({ tokenizer: models.text_embedding.all_minilm_l6_v2() }); +const tokenizer = useTokenizer({ + tokenizer: models.text_embedding.all_minilm_l6_v2(), +}); const ids = await tokenizer.encode('Hello, world!'); const text = await tokenizer.decode(ids); @@ -321,11 +337,18 @@ npm install @react-native-rag/op-sqlite ``` ```tsx -import { useRAG, MemoryVectorStore, ExecuTorchEmbeddings, ExecuTorchLLM } from 'react-native-rag'; +import { + useRAG, + MemoryVectorStore, + ExecuTorchEmbeddings, + ExecuTorchLLM, +} from 'react-native-rag'; import { models } from 'react-native-executorch'; const vectorStore = new MemoryVectorStore({ - embeddings: new ExecuTorchEmbeddings(models.text_embedding.all_minilm_l6_v2()), + embeddings: new ExecuTorchEmbeddings( + models.text_embedding.all_minilm_l6_v2() + ), }); const llm = new ExecuTorchLLM(models.llm.lfm2_5_1_2b_instruct()); @@ -349,11 +372,11 @@ Both `ExecuTorchEmbeddings` and `ExecuTorchLLM` accept any model accessor from t ## Device constraints -| Tier | Parameter range | Examples | -|---|---|---| -| Low-end | 135M–500M | `models.llm.smollm2_1_135m`, `models.llm.smollm2_1_360m` | -| Mid-range | 500M–1.7B | `models.llm.qwen3_0_6b`, `models.llm.smollm2_1_1_7b`, `models.llm.llama3_2_1b` | -| High-end | 1.7B–4B | `models.llm.qwen3_4b`, `models.llm.phi_4_mini_4b`, `models.llm.llama3_2_3b` | +| Tier | Parameter range | Examples | +| --------- | --------------- | ------------------------------------------------------------------------------ | +| Low-end | 135M–500M | `models.llm.smollm2_1_135m`, `models.llm.smollm2_1_360m` | +| Mid-range | 500M–1.7B | `models.llm.qwen3_0_6b`, `models.llm.smollm2_1_1_7b`, `models.llm.llama3_2_1b` | +| High-end | 1.7B–4B | `models.llm.qwen3_4b`, `models.llm.phi_4_mini_4b`, `models.llm.llama3_2_3b` | For per-model memory and inference benchmarks: webfetch [Benchmarks](https://docs.swmansion.com/react-native-executorch/docs/benchmarks/inference-time). @@ -373,6 +396,6 @@ For per-model memory and inference benchmarks: webfetch [Benchmarks](https://doc - [Loading models](https://docs.swmansion.com/react-native-executorch/docs/fundamentals/loading-models) - [Resource fetcher](https://docs.swmansion.com/react-native-executorch/docs/utilities/resource-fetcher) - [Error handling](https://docs.swmansion.com/react-native-executorch/docs/utilities/error-handling) -- [useExecutorchModule API reference](https://docs.swmansion.com/react-native-executorch/docs/api-reference/functions/useExecutorchModule) +- [useExecutorch API reference](https://docs.swmansion.com/react-native-executorch/docs/api-reference/functions/useExecutorch) - [useTokenizer API reference](https://docs.swmansion.com/react-native-executorch/docs/api-reference/functions/useTokenizer) - [usePrivacyFilter API reference](https://docs.swmansion.com/react-native-executorch/docs/api-reference/functions/usePrivacyFilter) diff --git a/skills/react-native-executorch/references/vision.md b/skills/react-native-executorch/references/vision.md index e01564de0e..e9e7f913ef 100644 --- a/skills/react-native-executorch/references/vision.md +++ b/skills/react-native-executorch/references/vision.md @@ -13,7 +13,9 @@ Every hook accepts image input as one of: a remote URL (`https://…`), a local ```tsx import { useClassification, models } from 'react-native-executorch'; -const model = useClassification({ model: models.classification.efficientnet_v2_s() }); +const model = useClassification({ + model: models.classification.efficientnet_v2_s(), +}); const labels = await model.forward('https://example.com/puppy.png'); // labels: Record — ImageNet1k label → probability @@ -36,8 +38,8 @@ const model = useObjectDetection({ model: models.object_detection.yolo26n() }); const detections = await model.forward('https://example.com/street.jpg', { detectionThreshold: 0.5, // minimum confidence (0–1) - iouThreshold: 0.45, // NMS aggressiveness (0–1) - inputSize: 640, // for multi-size YOLO models (384 / 512 / 640) + iouThreshold: 0.45, // NMS aggressiveness (0–1) + inputSize: 640, // for multi-size YOLO models (384 / 512 / 640) classesOfInterest: ['PERSON', 'CAR'], // filter }); @@ -59,7 +61,11 @@ YOLO models support multiple input sizes — call `model.getAvailableInputSizes( Pixel-level classification. ```tsx -import { useSemanticSegmentation, models, DeeplabLabel } from 'react-native-executorch'; +import { + useSemanticSegmentation, + models, + DeeplabLabel, +} from 'react-native-executorch'; const model = useSemanticSegmentation({ model: models.semantic_segmentation.deeplab_v3_resnet50(), @@ -67,8 +73,8 @@ const model = useSemanticSegmentation({ // Pass classesOfInterest + resizeToInput to also get per-class probability maps const out = await model.forward(imageUri, ['CAT', 'DOG', 'PERSON'], true); -const argmax = out[DeeplabLabel.ARGMAX]; // class id per pixel -const catProbs = out['CAT']; // probability per pixel +const argmax = out[DeeplabLabel.ARGMAX]; // class id per pixel +const catProbs = out['CAT']; // probability per pixel ``` **Tradeoff:** `resizeToInput: true` upsamples to the original image size — more memory and slower. With `false`, indices map to a 224×224 grid. @@ -101,7 +107,11 @@ const instances = await model.forward('https://example.com/street.jpg'); Detects humans and their COCO 17-keypoint skeletons (nose, eyes, ears, shoulders, elbows, wrists, hips, knees, ankles). ```tsx -import { usePoseEstimation, models, CocoKeypoint } from 'react-native-executorch'; +import { + usePoseEstimation, + models, + CocoKeypoint, +} from 'react-native-executorch'; const model = usePoseEstimation({ model: models.pose_estimation.yolo26n() }); @@ -134,6 +144,7 @@ for (const d of detections) { ``` `OCRDetection`: + ```ts interface OCRDetection { bbox: { x: number; y: number }[]; // 4 corner points (supports rotated/skewed text) @@ -188,9 +199,11 @@ On-device Stable Diffusion (BK-SDM tiny). ```tsx import { useTextToImage, models } from 'react-native-executorch'; -const model = useTextToImage({ model: models.image_generation.bk_sdm_tiny_vpred_256() }); +const model = useTextToImage({ + model: models.image_generation.bk_sdm_tiny_vpred_256(), +}); -const image = await model.generate('a medieval castle by the sea', 256, 25); +const image = await model.forward('a medieval castle by the sea', 256, 25); // image: base64 PNG. Render with ``` @@ -236,15 +249,15 @@ const v2 = await model.forward('Greetings everyone'); const cosine = v1.reduce((s, x, i) => s + x * v2[i], 0); // pre-normalized ``` -| Accessor | Max tokens | Dim | Use case | -|---|---|---|---| -| `models.text_embedding.all_minilm_l6_v2` | 254 | 384 | General purpose | -| `models.text_embedding.all_mpnet_base_v2` | 382 | 768 | Higher quality, slower | -| `models.text_embedding.multi_qa_minilm_l6_cos_v1` | 509 | 384 | Q&A / semantic search | -| `models.text_embedding.multi_qa_mpnet_base_dot_v1` | 510 | 768 | Q&A / semantic search | -| `models.text_embedding.distiluse_base_multilingual_cased_v2` | 128 | 512 | Multilingual | -| `models.text_embedding.paraphrase_multilingual_minilm_l12_v2` | 128 | 384 | Multilingual paraphrase | -| `models.text_embedding.clip_vit_base_patch32_text` | 74 | 512 | Pair with image embeddings (CLIP) | +| Accessor | Max tokens | Dim | Use case | +| ------------------------------------------------------------- | ---------- | --- | --------------------------------- | +| `models.text_embedding.all_minilm_l6_v2` | 254 | 384 | General purpose | +| `models.text_embedding.all_mpnet_base_v2` | 382 | 768 | Higher quality, slower | +| `models.text_embedding.multi_qa_minilm_l6_cos_v1` | 509 | 384 | Q&A / semantic search | +| `models.text_embedding.multi_qa_mpnet_base_dot_v1` | 510 | 768 | Q&A / semantic search | +| `models.text_embedding.distiluse_base_multilingual_cased_v2` | 128 | 512 | Multilingual | +| `models.text_embedding.paraphrase_multilingual_minilm_l12_v2` | 128 | 384 | Multilingual paraphrase | +| `models.text_embedding.clip_vit_base_patch32_text` | 74 | 512 | Pair with image embeddings (CLIP) | Text exceeding `Max tokens` is truncated. Use `useTokenizer` (see [setup.md](./setup.md)) to count first.