From 1355acdbc3e4ddc8dd86fff8443e8173984ceb8b Mon Sep 17 00:00:00 2001 From: definenoob Date: Sun, 8 Mar 2026 06:39:16 -0400 Subject: [PATCH 1/6] initial commit --- Cargo.lock | 418 +++++++++++++++++- Cargo.toml | 2 + crates/bindings-sys/src/lib.rs | 154 +++++++ crates/bindings/src/lib.rs | 10 + crates/bindings/src/onnx.rs | 102 +++++ crates/core/Cargo.toml | 1 + crates/core/src/error.rs | 2 + crates/core/src/host/host_controller.rs | 9 + crates/core/src/host/instance_env.rs | 4 + crates/core/src/host/mod.rs | 5 + crates/core/src/host/onnx.rs | 111 +++++ crates/core/src/host/wasm_common.rs | 5 + .../src/host/wasm_common/module_host_actor.rs | 10 +- .../src/host/wasmtime/wasm_instance_env.rs | 107 +++++ .../core/src/host/wasmtime/wasmtime_module.rs | 2 +- crates/core/src/module_host_context.rs | 2 + crates/lib/src/lib.rs | 1 + crates/lib/src/onnx.rs | 26 ++ crates/primitives/src/errno.rs | 2 + 19 files changed, 960 insertions(+), 13 deletions(-) create mode 100644 crates/bindings/src/onnx.rs create mode 100644 crates/core/src/host/onnx.rs create mode 100644 crates/lib/src/onnx.rs diff --git a/Cargo.lock b/Cargo.lock index a3dfa2fcac4..fd7cc85075f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -162,6 +162,18 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33954243bd79057c2de7338850b85983a44588021f8a5fee574a8888c6de4344" +[[package]] +name = "anymap2" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" + +[[package]] +name = "anymap3" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "170433209e817da6aae2c51aa0dd443009a613425dd041ebfb2492d1c4c11a25" + [[package]] name = "append-only-vec" version = "0.1.8" @@ -502,15 +514,30 @@ dependencies = [ "syn 2.0.107", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec 0.6.3", +] + [[package]] name = "bit-set" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" dependencies = [ - "bit-vec", + "bit-vec 0.8.0", ] +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bit-vec" version = "0.8.0" @@ -1673,6 +1700,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive-new" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3418329ca0ad70234b9735dc4ceed10af4df60eff9c8e7b06cb5e520d92c3535" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive-new" version = "0.7.0" @@ -1849,12 +1887,24 @@ dependencies = [ "syn 2.0.107", ] +[[package]] +name = "doc-comment" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780955b8b195a21ab8e4ac6b60dd1dbdcec1dc6c51c0617964b08c81785e12c9" + [[package]] name = "dotenvy" version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + [[package]] name = "dragonbox_ecma" version = "0.1.0" @@ -1879,6 +1929,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "dyn-hash" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15401da73a9ed8c80e3b2d4dc05fe10e7b72d7243b9f614e516a44fa99986e88" + [[package]] name = "educe" version = "0.4.23" @@ -2118,7 +2174,7 @@ version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "998b056554fbe42e03ae0e152895cd1a7e1002aec800fdc6635d20270260c46f" dependencies = [ - "bit-set", + "bit-set 0.8.0", "regex-automata", "regex-syntax", ] @@ -2129,7 +2185,7 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72cf461f865c862bb7dc573f643dd6a2b6842f7c30b07882b56bd148cc2761b8" dependencies = [ - "bit-set", + "bit-set 0.8.0", "regex-automata", "regex-syntax", ] @@ -2607,6 +2663,7 @@ checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" dependencies = [ "cfg-if", "crunchy", + "num-traits", "zerocopy", ] @@ -3585,6 +3642,16 @@ dependencies = [ "libc", ] +[[package]] +name = "kstring" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "558bf9508a558512042d3095138b1f7b8fe90c5467d94f9f1da28b3731c5dbd1" +dependencies = [ + "serde", + "static_assertions", +] + [[package]] name = "lazy-regex" version = "3.4.1" @@ -3769,6 +3836,63 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "liquid" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e9338405fdbc0bce9b01695b2a2ef6b20eca5363f385d47bce48ddf8323cc25" +dependencies = [ + "doc-comment", + "liquid-core", + "liquid-derive", + "liquid-lib", + "serde", +] + +[[package]] +name = "liquid-core" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feb8fed70857010ed9016ed2ce5a7f34e7cc51d5d7255c9c9dc2e3243e490b42" +dependencies = [ + "anymap2", + "itertools 0.13.0", + "kstring", + "liquid-derive", + "num-traits", + "pest", + "pest_derive", + "regex", + "serde", + "time", +] + +[[package]] +name = "liquid-derive" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b51f1d220e3fa869e24cfd75915efe3164bd09bb11b3165db3f37f57bf673e3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.107", +] + +[[package]] +name = "liquid-lib" +version = "0.26.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1794b5605e9f8864a8a4f41aa97976b42512cc81093f8c885d29fb94c6c556" +dependencies = [ + "itertools 0.13.0", + "liquid-core", + "once_cell", + "percent-encoding", + "regex", + "time", + "unicode-segmentation", +] + [[package]] name = "litemap" version = "0.8.0" @@ -3833,6 +3957,12 @@ dependencies = [ "libc", ] +[[package]] +name = "maplit" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" + [[package]] name = "mappings" version = "0.7.1" @@ -3867,6 +3997,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.6" @@ -4035,6 +4175,21 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk-context" version = "0.1.1" @@ -5244,7 +5399,7 @@ dependencies = [ "async-trait", "bytes", "chrono", - "derive-new", + "derive-new 0.7.0", "futures", "hex", "lazy-regex", @@ -5406,6 +5561,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +dependencies = [ + "portable-atomic", +] + [[package]] name = "portpicker" version = "0.1.1" @@ -5501,7 +5665,7 @@ dependencies = [ "inferno 0.12.3", "num", "paste", - "prost", + "prost 0.13.5", ] [[package]] @@ -5563,6 +5727,15 @@ dependencies = [ "syn 2.0.107", ] +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro-crate" version = "3.4.0" @@ -5638,8 +5811,8 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bb0be07becd10686a0bb407298fb425360a5c44a663774406340c59a22de4ce" dependencies = [ - "bit-set", - "bit-vec", + "bit-set 0.8.0", + "bit-vec 0.8.0", "bitflags 2.10.0", "lazy_static", "num-traits", @@ -5663,6 +5836,16 @@ dependencies = [ "syn 2.0.107", ] +[[package]] +name = "prost" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" +dependencies = [ + "bytes", + "prost-derive 0.11.9", +] + [[package]] name = "prost" version = "0.13.5" @@ -5670,7 +5853,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", - "prost-derive", + "prost-derive 0.13.5", +] + +[[package]] +name = "prost-derive" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +dependencies = [ + "anyhow", + "itertools 0.10.5", + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] @@ -5956,6 +6152,16 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + [[package]] name = "rand_distr" version = "0.5.1" @@ -5975,6 +6181,12 @@ dependencies = [ "rand_core 0.9.3", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" @@ -6778,6 +6990,20 @@ dependencies = [ "semver", ] +[[package]] +name = "rustfft" +version = "6.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21db5f9893e91f41798c88680037dba611ca6674703c1a18601b01a72c8adb89" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + [[package]] name = "rustix" version = "0.38.44" @@ -6912,6 +7138,15 @@ dependencies = [ "regex", ] +[[package]] +name = "scan_fmt" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b53b0a5db882a8e2fdaae0a43f7b39e7e9082389e978398bdf223a55b581248" +dependencies = [ + "regex", +] + [[package]] name = "schannel" version = "0.1.28" @@ -7932,6 +8167,7 @@ dependencies = [ "tracing-log 0.1.4", "tracing-subscriber", "tracing-tracy", + "tract-onnx", "url", "urlencoding", "uuid", @@ -8280,7 +8516,7 @@ dependencies = [ "itertools 0.12.1", "log", "rand 0.9.2", - "rand_distr", + "rand_distr 0.5.1", "spacetimedb-client-api-messages", "spacetimedb-lib 2.0.3", "thiserror 1.0.69", @@ -8778,6 +9014,23 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + +[[package]] +name = "string-interner" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07f9fdfdd31a0ff38b59deb401be81b73913d76c9cc5b1aed4e1330a223420b9" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "serde", +] + [[package]] name = "string_wizard" version = "0.0.27" @@ -9753,6 +10006,143 @@ dependencies = [ "tracy-client", ] +[[package]] +name = "tract-core" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7b5347639690871b124593a8c8903f1f369531498b8abaebd18eb5c58163971" +dependencies = [ + "anyhow", + "anymap3", + "bit-set 0.5.3", + "derive-new 0.5.9", + "downcast-rs", + "dyn-clone", + "lazy_static", + "log", + "maplit", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "paste", + "rustfft", + "smallvec", + "tract-data", + "tract-linalg", +] + +[[package]] +name = "tract-data" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0a3f476a1804e05708e9bc5e2d29dcab82bad531e357d3d14d7da80fbba0b6d" +dependencies = [ + "anyhow", + "downcast-rs", + "dyn-clone", + "dyn-hash", + "half", + "itertools 0.12.1", + "lazy_static", + "maplit", + "ndarray", + "nom 7.1.3", + "num-integer", + "num-traits", + "parking_lot 0.12.5", + "scan_fmt", + "smallvec", + "string-interner", +] + +[[package]] +name = "tract-hir" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dca047ba1151fe3446fb0194d4b6ddb9ae8f361337c47a267870c53605fbafb" +dependencies = [ + "derive-new 0.5.9", + "log", + "tract-core", +] + +[[package]] +name = "tract-linalg" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb8e0703eb53ef1bbf77050ff261675818dd5f0d6c27044c6e48ede9b845f9e0" +dependencies = [ + "byteorder", + "cc", + "derive-new 0.5.9", + "downcast-rs", + "dyn-clone", + "dyn-hash", + "half", + "lazy_static", + "liquid", + "liquid-core", + "liquid-derive", + "log", + "num-traits", + "paste", + "rayon", + "scan_fmt", + "smallvec", + "time", + "tract-data", + "unicode-normalization", + "walkdir", +] + +[[package]] +name = "tract-nnef" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72cb88a4367ec2c695610223cf886f01fc1deb5c9a82c7a74b1a5d32dc0b1466" +dependencies = [ + "byteorder", + "flate2", + "log", + "nom 7.1.3", + "tar", + "tract-core", + "walkdir", +] + +[[package]] +name = "tract-onnx" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5830aa672b2aa4dc98a97a36e5988eaf77b3ecee65e2601619588d2ca557008" +dependencies = [ + "bytes", + "derive-new 0.5.9", + "log", + "memmap2", + "num-integer", + "prost 0.11.9", + "smallvec", + "tract-hir", + "tract-nnef", + "tract-onnx-opl", +] + +[[package]] +name = "tract-onnx-opl" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121d3d224c806ba3d941f4bb50943ad33b59d1da5ae704d0e4e76d2808221f96" +dependencies = [ + "getrandom 0.2.16", + "log", + "rand 0.8.5", + "rand_distr 0.4.3", + "rustfft", + "tract-nnef", +] + [[package]] name = "tracy-client" version = "0.16.5" @@ -9773,6 +10163,16 @@ dependencies = [ "cc", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "try-lock" version = "0.2.5" diff --git a/Cargo.toml b/Cargo.toml index 42cf0b85632..6e8f4cb404c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -352,6 +352,8 @@ zstd-framed = { version = "0.1.1", features = ["tokio"] } # potentially very old system version. openssl = { version = "0.10", features = ["vendored"] } +tract-onnx = "0.21" + [workspace.dependencies.wasmtime] version = "39" default-features = false diff --git a/crates/bindings-sys/src/lib.rs b/crates/bindings-sys/src/lib.rs index 95dfbc7e600..4616cf0a434 100644 --- a/crates/bindings-sys/src/lib.rs +++ b/crates/bindings-sys/src/lib.rs @@ -865,6 +865,79 @@ pub mod raw { ) -> u16; } + #[link(wasm_import_module = "spacetime_10.5")] + unsafe extern "C" { + /// Loads an ONNX model by name from the host's model storage. + /// + /// `name_ptr[..name_len]` is a UTF-8 model name (e.g. `"bot_brain"`). + /// The host resolves this to a `.onnx` file on its filesystem, + /// loads and optimizes the model entirely on the host side. + /// The model bytes never enter WASM memory. + /// + /// On success, writes a model handle (u32) to `out[0]` and returns 0. + /// The model handle can be used with [`onnx_run_inference`] and freed with [`onnx_close_model`]. + /// + /// # Traps + /// + /// Traps if: + /// - `name_ptr` is NULL or `name_ptr[..name_len]` is not in bounds of WASM memory. + /// - `name_ptr[..name_len]` is not valid UTF-8. + /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. + /// + /// # Errors + /// + /// Returns an error: + /// + /// - `ONNX_ERROR` if the model could not be found or loaded. + /// In this case, a [`BytesSource`] containing a BSATN-encoded error message `String` + /// is written to `out[0]`. + pub fn onnx_load_model( + name_ptr: *const u8, + name_len: u32, + out: *mut u32, + ) -> u16; + + /// Runs inference on the model identified by `model_handle` + /// with BSATN-encoded input tensors at `input_ptr[..input_len]`. + /// + /// `input_ptr[..input_len]` should contain a BSATN-encoded `Vec`. + /// + /// On success, a [`BytesSource`] is written to `out[0]` containing a BSATN-encoded + /// `Vec` with the inference output, and this function returns 0. + /// + /// # Traps + /// + /// Traps if: + /// - `input_ptr` is NULL or `input_ptr[..input_len]` is not in bounds of WASM memory. + /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. + /// + /// # Errors + /// + /// Returns an error: + /// + /// - `NO_SUCH_MODEL` if `model_handle` does not refer to a loaded model. + /// - `BSATN_DECODE_ERROR` if the input tensors could not be decoded. + /// - `ONNX_ERROR` if inference failed. In this case, a [`BytesSource`] containing + /// a BSATN-encoded error message `String` is written to `out[0]`. + pub fn onnx_run_inference( + model_handle: u32, + input_ptr: *const u8, + input_len: u32, + out: *mut u32, + ) -> u16; + + /// Frees the ONNX model identified by `model_handle`. + /// + /// After this call, the handle is no longer valid. + /// + /// # Errors + /// + /// Returns an error: + /// + /// - `NO_SUCH_MODEL` if `model_handle` does not refer to a loaded model. + pub fn onnx_close_model(model_handle: u32) -> u16; + } + /// What strategy does the database index use? /// /// See also: @@ -1626,3 +1699,84 @@ pub mod procedure { } } } + +/// ONNX inference operations, available from both reducers and procedures. +pub mod onnx { + use super::raw; + + /// Load an ONNX model by name from the host's model storage. + /// + /// The host resolves the name to a `.onnx` file on its filesystem, + /// loads and optimizes the model entirely on the host side. + /// + /// On success, returns `Ok(model_handle)`. + /// On failure, returns `Err(bytes_source)` containing a BSATN-encoded error message `String`. + #[inline] + pub fn load_model(name: &str) -> Result { + // The host writes either a model handle (on success) or a BytesSource handle (on error) + // into `out`. Both are u32. BytesSource is #[repr(transparent)] over u32. + let mut out = [raw::BytesSource::INVALID; 1]; + + let res = unsafe { + super::raw::onnx_load_model( + name.as_ptr(), + name.len() as u32, + out.as_mut_ptr().cast(), + ) + }; + + match super::Errno::from_code(res) { + // Success: out[0] is a model handle as raw u32 bits. + // Safety: BytesSource is #[repr(transparent)] over u32. + None => Ok(unsafe { core::mem::transmute::(out[0]) }), + // ONNX_ERROR: out[0] is a BytesSource with the error message. + Some(errno) if errno == super::Errno::ONNX_ERROR => Err(out[0]), + Some(errno) => panic!("{errno}"), + } + } + + /// Run inference on a loaded model with BSATN-encoded input tensors. + /// + /// `input_bsatn` should be a BSATN-encoded `Vec`. + /// + /// On success, returns `Ok(bytes_source)` containing BSATN-encoded output tensors. + /// On failure, returns `Err(bytes_source)` containing a BSATN-encoded error message. + #[inline] + pub fn run_inference( + model_handle: u32, + input_bsatn: &[u8], + ) -> Result { + let mut out = [raw::BytesSource::INVALID; 1]; + + let res = unsafe { + super::raw::onnx_run_inference( + model_handle, + input_bsatn.as_ptr(), + input_bsatn.len() as u32, + out.as_mut_ptr().cast(), + ) + }; + + match super::Errno::from_code(res) { + None => Ok(out[0]), + Some(errno) if errno == super::Errno::ONNX_ERROR => Err(out[0]), + Some(errno) if errno == super::Errno::NO_SUCH_MODEL => { + panic!("ONNX model handle {model_handle} is not valid") + } + Some(errno) => panic!("{errno}"), + } + } + + /// Close a loaded ONNX model, freeing its resources. + #[inline] + pub fn close_model(model_handle: u32) { + let res = unsafe { super::raw::onnx_close_model(model_handle) }; + match super::Errno::from_code(res) { + None => {} + Some(errno) if errno == super::Errno::NO_SUCH_MODEL => { + panic!("ONNX model handle {model_handle} is not valid") + } + Some(errno) => panic!("{errno}"), + } + } +} diff --git a/crates/bindings/src/lib.rs b/crates/bindings/src/lib.rs index a74534663ee..e9fc404fa5a 100644 --- a/crates/bindings/src/lib.rs +++ b/crates/bindings/src/lib.rs @@ -11,6 +11,7 @@ mod client_visibility_filter; #[cfg(feature = "unstable")] pub mod http; pub mod log_stopwatch; +pub mod onnx; mod logger; #[cfg(feature = "rand08")] mod rng; @@ -1005,6 +1006,9 @@ pub struct ReducerContext { /// See the [`#[table]`](macro@crate::table) macro for more information. pub db: Local, + /// Methods for performing ONNX inference. + pub onnx: crate::onnx::OnnxClient, + #[cfg(feature = "rand08")] rng: std::cell::OnceCell, /// A counter used for generating UUIDv7 values. @@ -1018,6 +1022,7 @@ impl ReducerContext { pub fn __dummy() -> Self { Self { db: Local {}, + onnx: crate::onnx::OnnxClient {}, sender: Identity::__dummy(), timestamp: Timestamp::UNIX_EPOCH, connection_id: None, @@ -1033,6 +1038,7 @@ impl ReducerContext { fn new(db: Local, sender: Identity, connection_id: Option, timestamp: Timestamp) -> Self { Self { db, + onnx: crate::onnx::OnnxClient {}, sender, timestamp, connection_id, @@ -1179,6 +1185,9 @@ pub struct ProcedureContext { /// Methods for performing HTTP requests. pub http: crate::http::HttpClient, + + /// Methods for performing ONNX inference. + pub onnx: crate::onnx::OnnxClient, // TODO: Change rng? // Complex and requires design because we may want procedure RNG to behave differently from reducer RNG, // as it could actually be seeded by OS randomness rather than a deterministic source. @@ -1199,6 +1208,7 @@ impl ProcedureContext { timestamp, connection_id, http: http::HttpClient {}, + onnx: crate::onnx::OnnxClient {}, #[cfg(feature = "rand08")] rng: std::cell::OnceCell::new(), #[cfg(feature = "rand")] diff --git a/crates/bindings/src/onnx.rs b/crates/bindings/src/onnx.rs new file mode 100644 index 00000000000..ae16cc45a6a --- /dev/null +++ b/crates/bindings/src/onnx.rs @@ -0,0 +1,102 @@ +//! ONNX inference support for SpacetimeDB modules. +//! +//! Load an ONNX model by name and run inference from within reducers or procedures. +//! Models are stored on the host filesystem — the model bytes never enter WASM memory. +//! +//! # Example +//! +//! ```no_run +//! # use spacetimedb::{reducer, ReducerContext, onnx::{OnnxClient, Tensor, ModelHandle}}; +//! // In a reducer: +//! # #[reducer] +//! # fn my_reducer(ctx: &ReducerContext) { +//! // Load a model by name — the host resolves "bot_brain" to a .onnx file on disk. +//! let model = ctx.onnx.load("bot_brain").expect("Failed to load model"); +//! let input = vec![Tensor { +//! shape: vec![1, 10], +//! data: vec![0.0; 10], +//! }]; +//! let output = ctx.onnx.run(&model, &input).expect("Inference failed"); +//! log::info!("Output: {:?}", output[0].data); +//! # } +//! ``` + +use crate::rt::read_bytes_source_as; +use spacetimedb_lib::bsatn; + +pub use spacetimedb_lib::onnx::Tensor; + +/// An opaque handle to a loaded ONNX model on the host. +/// +/// Obtained via [`OnnxClient::load`] and used with [`OnnxClient::run`]. +/// The model is freed when this handle is dropped. +pub struct ModelHandle(u32); + +impl Drop for ModelHandle { + fn drop(&mut self) { + spacetimedb_bindings_sys::onnx::close_model(self.0); + } +} + +/// Client for performing ONNX inference. +/// +/// Access from within reducers via [`ReducerContext::onnx`](crate::ReducerContext) +/// or from procedures via [`ProcedureContext::onnx`](crate::ProcedureContext). +#[non_exhaustive] +pub struct OnnxClient {} + +impl OnnxClient { + /// Load an ONNX model by name from the host's model storage. + /// + /// The host resolves the name to a `.onnx` file on its filesystem + /// (e.g. in the database's `models/` directory), then loads and optimizes it + /// entirely on the host side. The model bytes never enter WASM memory. + /// + /// The returned [`ModelHandle`] can be used with [`OnnxClient::run`] for inference. + /// The model is automatically freed when the handle is dropped. + pub fn load(&self, model_name: &str) -> Result { + match spacetimedb_bindings_sys::onnx::load_model(model_name) { + Ok(handle) => Ok(ModelHandle(handle)), + Err(err_source) => { + let message = read_bytes_source_as::(err_source); + Err(Error { message }) + } + } + } + + /// Run inference on a loaded model. + /// + /// `inputs` are the input tensors for the model, in the order expected by the model's input nodes. + /// Returns the output tensors from the model. + /// + /// Inference runs entirely on the host in native Rust — only the input/output tensor data + /// crosses the WASM boundary. + pub fn run(&self, model: &ModelHandle, inputs: &[Tensor]) -> Result, Error> { + let input_bsatn = bsatn::to_vec(inputs).expect("Failed to BSATN-serialize input tensors"); + + match spacetimedb_bindings_sys::onnx::run_inference(model.0, &input_bsatn) { + Ok(output_source) => { + let output = read_bytes_source_as::>(output_source); + Ok(output) + } + Err(err_source) => { + let message = read_bytes_source_as::(err_source); + Err(Error { message }) + } + } + } +} + +/// An error from ONNX model loading or inference. +#[derive(Clone, Debug)] +pub struct Error { + message: String, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for Error {} diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index e0934d245b8..bf3737b21e9 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -120,6 +120,7 @@ url.workspace = true urlencoding.workspace = true uuid.workspace = true v8.workspace = true +tract-onnx.workspace = true wasmtime.workspace = true wasmtime-internal-fiber.workspace = true jwks.workspace = true diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 8ebee4ce4af..fc3525e0b01 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -306,6 +306,8 @@ pub enum NodesError { ScheduleError(#[source] ScheduleError), #[error("HTTP request failed: {0}")] HttpError(String), + #[error("ONNX inference failed: {0}")] + OnnxError(String), } impl From for NodesError { diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index 82eb265c393..74222faf7d3 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -453,6 +453,7 @@ impl HostController { this.energy_monitor.clone(), this.unregister_fn(replica_id), this.db_cores.take(), + Some(this.data_dir.clone()), ) .await?; @@ -709,6 +710,7 @@ async fn make_module_host( energy_monitor: Arc, unregister: impl Fn() + Send + Sync + 'static, core: AllocatedJobCore, + data_dir: Option>, ) -> anyhow::Result<(Program, ModuleHost)> { // `make_actor` is blocking, as it needs to compile the wasm to native code, // which may be computationally expensive - sometimes up to 1s for a large module. @@ -722,6 +724,7 @@ async fn make_module_host( scheduler, program_hash: program.hash, energy_monitor, + data_dir, }; match HostType::from(program.kind) { @@ -770,6 +773,7 @@ struct ModuleLauncher { runtimes: Arc, core: AllocatedJobCore, bsatn_rlb_pool: BsatnRowListBuilderPool, + data_dir: Option>, } impl ModuleLauncher { @@ -801,6 +805,7 @@ impl ModuleLauncher { self.energy_monitor, self.on_panic, self.core, + self.data_dir, ) .await?; @@ -988,6 +993,7 @@ impl Host { runtimes: runtimes.clone(), core: host_controller.db_cores.take(), bsatn_rlb_pool: bsatn_rlb_pool.clone(), + data_dir: Some(data_dir.clone()), } .launch_module() .await?; @@ -1084,6 +1090,7 @@ impl Host { runtimes: runtimes.clone(), core, bsatn_rlb_pool, + data_dir: None, } .launch_module() .await @@ -1110,6 +1117,7 @@ impl Host { energy_monitor: Arc, on_panic: impl Fn() + Send + Sync + 'static, core: AllocatedJobCore, + data_dir: Option>, ) -> anyhow::Result { let replica_ctx = &self.replica_ctx; let (scheduler, scheduler_starter) = Scheduler::open(self.replica_ctx.relational_db.clone()); @@ -1122,6 +1130,7 @@ impl Host { energy_monitor, on_panic, core, + data_dir, ) .await?; diff --git a/crates/core/src/host/instance_env.rs b/crates/core/src/host/instance_env.rs index a47f035cf34..338f270b87f 100644 --- a/crates/core/src/host/instance_env.rs +++ b/crates/core/src/host/instance_env.rs @@ -53,6 +53,9 @@ pub struct InstanceEnv { in_anon_tx: bool, /// A procedure's last known transaction offset. procedure_last_tx_offset: Option, + /// Directory on the host filesystem where ONNX model files are stored. + /// Set during module initialization if model storage is configured. + pub models_dir: Option, } /// `InstanceEnv` needs to be `Send` because it is created on the host thread @@ -237,6 +240,7 @@ impl InstanceEnv { func_name: None, in_anon_tx: false, procedure_last_tx_offset: None, + models_dir: None, } } diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index 0daa9c359bc..32169711d7d 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -13,6 +13,7 @@ use spacetimedb_schema::def::ModuleDef; mod disk_storage; mod host_controller; mod module_common; +pub mod onnx; #[allow(clippy::too_many_arguments)] pub mod module_host; pub mod scheduler; @@ -194,4 +195,8 @@ pub enum AbiCall { ProcedureCommitMutTransaction, ProcedureAbortMutTransaction, ProcedureHttpRequest, + + OnnxLoadModel, + OnnxRunInference, + OnnxCloseModel, } diff --git a/crates/core/src/host/onnx.rs b/crates/core/src/host/onnx.rs new file mode 100644 index 00000000000..772d6ea6f4d --- /dev/null +++ b/crates/core/src/host/onnx.rs @@ -0,0 +1,111 @@ +//! Host-side ONNX inference using tract-onnx. +//! +//! Provides [`OnnxModel`], which wraps a loaded and optimized tract model +//! and can run inference with tensors passed from WASM modules. +//! +//! Models are loaded from the host filesystem by name — the model bytes +//! never enter WASM memory. Only input/output tensor data crosses the boundary. + +use crate::host::instance_env::InstanceEnv; +use spacetimedb_lib::onnx::Tensor as StdbTensor; +use tract_onnx::prelude::*; + +/// A loaded and optimized ONNX model, ready for inference. +pub struct OnnxModel { + model: SimplePlan, Graph>>, +} + +impl OnnxModel { + /// Load an ONNX model by name from the host's model storage. + /// + /// Resolves the name to `{models_dir}/{name}.onnx` on the host filesystem, + /// reads the file, parses, optimizes, and compiles it into a runnable plan. + /// The model bytes never enter WASM memory. + pub fn load_by_name(name: &str, instance_env: &InstanceEnv) -> Result { + // Validate the model name to prevent path traversal. + if name.contains('/') || name.contains('\\') || name.contains("..") || name.is_empty() { + return Err(OnnxError(format!("Invalid model name: {name:?}"))); + } + + let models_dir = instance_env + .models_dir + .as_ref() + .ok_or_else(|| OnnxError("ONNX models directory not configured".into()))?; + + let model_path = models_dir.join(format!("{name}.onnx")); + + if !model_path.exists() { + return Err(OnnxError(format!( + "Model file not found: {}", + model_path.display() + ))); + } + + let model_bytes = std::fs::read(&model_path) + .map_err(|e| OnnxError(format!("Failed to read model file {}: {e}", model_path.display())))?; + + Self::load_from_bytes(&model_bytes) + } + + /// Load an ONNX model from raw bytes. + fn load_from_bytes(model_bytes: &[u8]) -> Result { + let model = tract_onnx::onnx() + .model_for_read(&mut std::io::Cursor::new(model_bytes)) + .map_err(|e| OnnxError(format!("Failed to parse ONNX model: {e}")))? + .into_optimized() + .map_err(|e| OnnxError(format!("Failed to optimize ONNX model: {e}")))? + .into_runnable() + .map_err(|e| OnnxError(format!("Failed to compile ONNX model: {e}")))?; + + Ok(OnnxModel { model }) + } + + /// Run inference with the given input tensors. + /// + /// Returns the output tensors from the model. + pub fn run(&self, inputs: &[StdbTensor]) -> Result, OnnxError> { + let tract_inputs: Vec = inputs + .iter() + .map(|t| { + let shape: Vec = t.shape.iter().map(|&d| d as usize).collect(); + let tensor = tract_ndarray::Array::from_shape_vec( + tract_ndarray::IxDyn(&shape), + t.data.clone(), + ) + .map_err(|e| OnnxError(format!("Invalid tensor shape: {e}")))?; + Ok(tensor.into_tvalue()) + }) + .collect::, OnnxError>>()?; + + let result = self + .model + .run(tract_inputs.into()) + .map_err(|e| OnnxError(format!("Inference failed: {e}")))?; + + let outputs: Vec = result + .iter() + .map(|t| { + let shape: Vec = t.shape().iter().map(|&d| d as u32).collect(); + let data: Vec = t + .as_slice::() + .map_err(|e| OnnxError(format!("Output tensor is not f32: {e}")))? + .to_vec(); + Ok(StdbTensor { shape, data }) + }) + .collect::, OnnxError>>()?; + + Ok(outputs) + } +} + +/// An error from ONNX model loading or inference. +#[derive(Debug)] +pub struct OnnxError(pub String); + +impl std::fmt::Display for OnnxError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for OnnxError {} diff --git a/crates/core/src/host/wasm_common.rs b/crates/core/src/host/wasm_common.rs index a5c737d54d6..25c6f9db977 100644 --- a/crates/core/src/host/wasm_common.rs +++ b/crates/core/src/host/wasm_common.rs @@ -358,6 +358,7 @@ pub fn err_to_errno(err: NodesError) -> Result<(NonZeroU16, Option), Nod NodesError::IndexCannotSeekRange => errno::WRONG_INDEX_ALGO, NodesError::ScheduleError(ScheduleError::DelayTooLong(_)) => errno::SCHEDULE_AT_DELAY_TOO_LONG, NodesError::HttpError(message) => return Ok((errno::HTTP_ERROR, Some(message))), + NodesError::OnnxError(message) => return Ok((errno::ONNX_ERROR, Some(message))), NodesError::Internal(ref internal) => match **internal { DBError::Datastore(DatastoreError::Index(IndexError::UniqueConstraintViolation( UniqueConstraintViolation { @@ -431,6 +432,10 @@ macro_rules! abi_funcs { "spacetime_10.4"::datastore_index_scan_point_bsatn, "spacetime_10.4"::datastore_delete_by_index_scan_point_bsatn, + "spacetime_10.5"::onnx_load_model, + "spacetime_10.5"::onnx_run_inference, + "spacetime_10.5"::onnx_close_model, + } $link_async! { diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 310401a42b3..e5523dea10b 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -316,6 +316,7 @@ pub struct WasmModuleHostActor { module: T::InstancePre, common: ModuleCommon, func_names: Arc, + models_dir: Option, } #[derive(thiserror::Error, Debug)] @@ -375,19 +376,21 @@ impl WasmModuleHostActor { func_names }; let uninit_instance = module.instantiate_pre()?; - let instance_env = InstanceEnv::new(mcc.replica_ctx.clone(), mcc.scheduler.clone()); + let models_dir = mcc.data_dir.as_ref().map(|d| d.0.join("models")); + let mut instance_env = InstanceEnv::new(mcc.replica_ctx.clone(), mcc.scheduler.clone()); + instance_env.models_dir = models_dir.clone(); let mut instance = uninit_instance.instantiate(instance_env, &func_names)?; let desc = instance.extract_descriptions()?; // Validate and create a common module rom the raw definition. let common = build_common_module_from_raw(mcc, desc)?; - let func_names = Arc::new(func_names); let module = WasmModuleHostActor { module: uninit_instance, func_names, common, + models_dir, }; let initial_instance = module.make_from_instance(instance); @@ -422,7 +425,8 @@ impl WasmModuleHostActor { pub fn create_instance(&self) -> WasmModuleInstance { let common = &self.common; - let env = InstanceEnv::new(common.replica_ctx().clone(), common.scheduler().clone()); + let mut env = InstanceEnv::new(common.replica_ctx().clone(), common.scheduler().clone()); + env.models_dir = self.models_dir.clone(); // this shouldn't fail, since we already called module.create_instance() // before and it didn't error, and ideally they should be deterministic let mut instance = self diff --git a/crates/core/src/host/wasmtime/wasm_instance_env.rs b/crates/core/src/host/wasmtime/wasm_instance_env.rs index 74a57b35e92..3993937f71b 100644 --- a/crates/core/src/host/wasmtime/wasm_instance_env.rs +++ b/crates/core/src/host/wasmtime/wasm_instance_env.rs @@ -133,6 +133,12 @@ pub(super) struct WasmInstanceEnv { /// A pool of unused allocated chunks that can be reused. // TODO(Centril): consider using this pool for `console_timer_start` and `bytes_sink_write`. chunk_pool: ChunkPool, + + /// Loaded ONNX models, keyed by model handle. + onnx_models: std::collections::HashMap, + + /// Counter for generating ONNX model handles. + next_onnx_model_id: u32, } const STANDARD_BYTES_SINK: u32 = 1; @@ -158,6 +164,8 @@ impl WasmInstanceEnv { timing_spans: Default::default(), call_times: CallTimes::new(), chunk_pool: <_>::default(), + onnx_models: std::collections::HashMap::new(), + next_onnx_model_id: 1, } } @@ -1878,6 +1886,105 @@ impl WasmInstanceEnv { /// - `body_ptr` is NULL or `body_ptr[..body_len]` is not in bounds of WASM memory. /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. /// - `request_ptr[..request_len]` does not contain a valid BSATN-serialized `spacetimedb_lib::http::Request` object. + /// Load an ONNX model by name from the host's model storage. + /// + /// `name_ptr[..name_len]` is a UTF-8 model name. The host resolves this to + /// a `.onnx` file on disk, loads and optimizes it entirely on the host side. + /// The model bytes never enter WASM memory. + /// + /// On success, writes a model handle (u32) to `out` and returns 0. + /// On error, writes a `BytesSource` handle containing a BSATN-encoded error `String` + /// to `out` and returns `ONNX_ERROR`. + pub fn onnx_load_model( + caller: Caller<'_, Self>, + name_ptr: WasmPtr, + name_len: u32, + out: WasmPtr, + ) -> RtResult { + Self::cvt_custom(caller, AbiCall::OnnxLoadModel, |caller| { + let (mem, env) = Self::mem_env(caller); + let name = mem.deref_str(name_ptr, name_len)?; + + match crate::host::onnx::OnnxModel::load_by_name(name, &env.instance_env) { + Ok(model) => { + let handle = env.next_onnx_model_id; + env.next_onnx_model_id = handle.checked_add(1) + .ok_or_else(|| WasmError::Wasm(anyhow!("ONNX model handle overflow")))?; + env.onnx_models.insert(handle, model); + handle.write_to(mem, out)?; + Ok(0u32) + } + Err(err) => { + let err_msg = bsatn::to_vec(&err.to_string()) + .context("Failed to BSATN-serialize ONNX error")?; + let bytes_source = WasmInstanceEnv::create_bytes_source(env, err_msg.into())?; + bytes_source.0.write_to(mem, out)?; + Ok(errno::ONNX_ERROR.get() as u32) + } + } + }) + } + + /// Run inference on the model identified by `model_handle` + /// with BSATN-encoded input tensors at `input_ptr[..input_len]`. + /// + /// On success, writes a `BytesSource` containing BSATN-encoded output tensors to `out` + /// and returns 0. + /// On error, writes a `BytesSource` containing a BSATN-encoded error `String` to `out` + /// and returns `ONNX_ERROR`. + pub fn onnx_run_inference( + caller: Caller<'_, Self>, + model_handle: u32, + input_ptr: WasmPtr, + input_len: u32, + out: WasmPtr, + ) -> RtResult { + Self::cvt_custom(caller, AbiCall::OnnxRunInference, |caller| { + let (mem, env) = Self::mem_env(caller); + + let model = env.onnx_models.get(&model_handle) + .ok_or(WasmError::Db(NodesError::OnnxError( + format!("No ONNX model with handle {model_handle}") + )))?; + + let input_buf = mem.deref_slice(input_ptr, input_len)?; + let inputs: Vec = + bsatn::from_slice(input_buf).map_err(|err| NodesError::DecodeValue(err))?; + + match model.run(&inputs) { + Ok(outputs) => { + let result = bsatn::to_vec(&outputs) + .context("Failed to BSATN-serialize ONNX output tensors")?; + let bytes_source = WasmInstanceEnv::create_bytes_source(env, result.into())?; + bytes_source.0.write_to(mem, out)?; + Ok(0u32) + } + Err(err) => { + let err_msg = bsatn::to_vec(&err.to_string()) + .context("Failed to BSATN-serialize ONNX error")?; + let bytes_source = WasmInstanceEnv::create_bytes_source(env, err_msg.into())?; + bytes_source.0.write_to(mem, out)?; + Ok(errno::ONNX_ERROR.get() as u32) + } + } + }) + } + + /// Free the ONNX model identified by `model_handle`. + pub fn onnx_close_model( + caller: Caller<'_, Self>, + model_handle: u32, + ) -> RtResult { + Self::cvt_custom(caller, AbiCall::OnnxCloseModel, |caller| { + let (_, env) = Self::mem_env(caller); + if env.onnx_models.remove(&model_handle).is_some() { + Ok(0u32) + } else { + Ok(errno::NO_SUCH_MODEL.get() as u32) + } + }) + } + pub fn procedure_http_request<'caller>( caller: Caller<'caller, Self>, (request_ptr, request_len, body_ptr, body_len, out): (WasmPtr, u32, WasmPtr, u32, WasmPtr), diff --git a/crates/core/src/host/wasmtime/wasmtime_module.rs b/crates/core/src/host/wasmtime/wasmtime_module.rs index 48ac0fe80e2..0690120eb16 100644 --- a/crates/core/src/host/wasmtime/wasmtime_module.rs +++ b/crates/core/src/host/wasmtime/wasmtime_module.rs @@ -50,7 +50,7 @@ impl WasmtimeModule { WasmtimeModule { module } } - pub const IMPLEMENTED_ABI: abi::VersionTuple = abi::VersionTuple::new(10, 4); + pub const IMPLEMENTED_ABI: abi::VersionTuple = abi::VersionTuple::new(10, 5); pub(super) fn link_imports(linker: &mut Linker) -> anyhow::Result<()> { const { assert!(WasmtimeModule::IMPLEMENTED_ABI.major == spacetimedb_lib::MODULE_ABI_MAJOR_VERSION) }; diff --git a/crates/core/src/module_host_context.rs b/crates/core/src/module_host_context.rs index 50f8c258a37..d783e5101c6 100644 --- a/crates/core/src/module_host_context.rs +++ b/crates/core/src/module_host_context.rs @@ -1,6 +1,7 @@ use crate::energy::EnergyMonitor; use crate::host::scheduler::Scheduler; use crate::replica_context::ReplicaContext; +use spacetimedb_paths::server::ServerDataDir; use spacetimedb_sats::hash::Hash; use std::sync::Arc; @@ -9,4 +10,5 @@ pub struct ModuleCreationContext { pub scheduler: Scheduler, pub program_hash: Hash, pub energy_monitor: Arc, + pub data_dir: Option>, } diff --git a/crates/lib/src/lib.rs b/crates/lib/src/lib.rs index 2e8b9c08336..0cb30b107ea 100644 --- a/crates/lib/src/lib.rs +++ b/crates/lib/src/lib.rs @@ -17,6 +17,7 @@ pub mod error; mod filterable_value; pub mod http; pub mod identity; +pub mod onnx; pub mod metrics; pub mod operator; pub mod query; diff --git a/crates/lib/src/onnx.rs b/crates/lib/src/onnx.rs new file mode 100644 index 00000000000..dc5055966d9 --- /dev/null +++ b/crates/lib/src/onnx.rs @@ -0,0 +1,26 @@ +//! Types for ONNX inference, used in the ABI between +//! SpacetimeDB host and guest WASM modules. +//! +//! These types are BSATN-encoded for interchange across the WASM boundary. + +use spacetimedb_sats::SpacetimeType; + +/// A tensor for ONNX inference, with shape metadata and flattened f32 data. +/// +/// Data is stored in row-major order (C-order). +/// For example, a 2x3 matrix `[[1,2,3],[4,5,6]]` would have: +/// - `shape: [2, 3]` +/// - `data: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]` +#[derive(Clone, Debug, SpacetimeType)] +#[sats(crate = crate, name = "OnnxTensor")] +pub struct Tensor { + /// The dimensions of the tensor, e.g. `[1, 10]` for a 1x10 matrix. + pub shape: Vec, + /// Flattened f32 data in row-major order. + pub data: Vec, +} + +/// An opaque handle to a loaded ONNX model on the host. +/// +/// Returned by `onnx_load_model` and passed to `onnx_run_inference`. +pub type ModelHandle = u32; diff --git a/crates/primitives/src/errno.rs b/crates/primitives/src/errno.rs index 5c422941715..fdc0a921de9 100644 --- a/crates/primitives/src/errno.rs +++ b/crates/primitives/src/errno.rs @@ -35,6 +35,8 @@ macro_rules! errnos { "ABI call can only be made while within a read-only transaction" ), HTTP_ERROR(21, "The HTTP request failed"), + ONNX_ERROR(22, "ONNX inference failed"), + NO_SUCH_MODEL(23, "The provided ONNX model handle is not valid"), ); }; } From 492dfa55fac5596c51212f16d40d3e4042f86d94 Mon Sep 17 00:00:00 2001 From: definenoob Date: Sun, 8 Mar 2026 07:21:58 -0400 Subject: [PATCH 2/6] updates --- crates/bindings-sys/src/lib.rs | 114 +++--------------- crates/bindings/src/onnx.rs | 50 ++------ crates/core/src/host/mod.rs | 4 +- crates/core/src/host/wasm_common.rs | 4 +- .../src/host/wasmtime/wasm_instance_env.rs | 93 +++++--------- crates/lib/src/onnx.rs | 5 - crates/primitives/src/errno.rs | 1 - 7 files changed, 54 insertions(+), 217 deletions(-) diff --git a/crates/bindings-sys/src/lib.rs b/crates/bindings-sys/src/lib.rs index 4616cf0a434..d4d9102d2a9 100644 --- a/crates/bindings-sys/src/lib.rs +++ b/crates/bindings-sys/src/lib.rs @@ -867,75 +867,39 @@ pub mod raw { #[link(wasm_import_module = "spacetime_10.5")] unsafe extern "C" { - /// Loads an ONNX model by name from the host's model storage. + /// Runs ONNX inference on a model identified by name. /// /// `name_ptr[..name_len]` is a UTF-8 model name (e.g. `"bot_brain"`). /// The host resolves this to a `.onnx` file on its filesystem, - /// loads and optimizes the model entirely on the host side. - /// The model bytes never enter WASM memory. + /// loads and caches the model on first use. Model bytes never enter WASM memory. /// - /// On success, writes a model handle (u32) to `out[0]` and returns 0. - /// The model handle can be used with [`onnx_run_inference`] and freed with [`onnx_close_model`]. + /// `input_ptr[..input_len]` should contain a BSATN-encoded `Vec`. + /// + /// On success, a [`BytesSource`] is written to `out[0]` containing a BSATN-encoded + /// `Vec` with the inference output, and this function returns 0. /// /// # Traps /// /// Traps if: /// - `name_ptr` is NULL or `name_ptr[..name_len]` is not in bounds of WASM memory. /// - `name_ptr[..name_len]` is not valid UTF-8. + /// - `input_ptr` is NULL or `input_ptr[..input_len]` is not in bounds of WASM memory. /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. /// /// # Errors /// /// Returns an error: /// - /// - `ONNX_ERROR` if the model could not be found or loaded. + /// - `ONNX_ERROR` if the model could not be found, loaded, or inference failed. /// In this case, a [`BytesSource`] containing a BSATN-encoded error message `String` /// is written to `out[0]`. - pub fn onnx_load_model( + pub fn onnx_run( name_ptr: *const u8, name_len: u32, - out: *mut u32, - ) -> u16; - - /// Runs inference on the model identified by `model_handle` - /// with BSATN-encoded input tensors at `input_ptr[..input_len]`. - /// - /// `input_ptr[..input_len]` should contain a BSATN-encoded `Vec`. - /// - /// On success, a [`BytesSource`] is written to `out[0]` containing a BSATN-encoded - /// `Vec` with the inference output, and this function returns 0. - /// - /// # Traps - /// - /// Traps if: - /// - `input_ptr` is NULL or `input_ptr[..input_len]` is not in bounds of WASM memory. - /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. - /// - /// # Errors - /// - /// Returns an error: - /// - /// - `NO_SUCH_MODEL` if `model_handle` does not refer to a loaded model. - /// - `BSATN_DECODE_ERROR` if the input tensors could not be decoded. - /// - `ONNX_ERROR` if inference failed. In this case, a [`BytesSource`] containing - /// a BSATN-encoded error message `String` is written to `out[0]`. - pub fn onnx_run_inference( - model_handle: u32, input_ptr: *const u8, input_len: u32, out: *mut u32, ) -> u16; - - /// Frees the ONNX model identified by `model_handle`. - /// - /// After this call, the handle is no longer valid. - /// - /// # Errors - /// - /// Returns an error: - /// - /// - `NO_SUCH_MODEL` if `model_handle` does not refer to a loaded model. - pub fn onnx_close_model(model_handle: u32) -> u16; } /// What strategy does the database index use? @@ -1704,53 +1668,21 @@ pub mod procedure { pub mod onnx { use super::raw; - /// Load an ONNX model by name from the host's model storage. - /// - /// The host resolves the name to a `.onnx` file on its filesystem, - /// loads and optimizes the model entirely on the host side. - /// - /// On success, returns `Ok(model_handle)`. - /// On failure, returns `Err(bytes_source)` containing a BSATN-encoded error message `String`. - #[inline] - pub fn load_model(name: &str) -> Result { - // The host writes either a model handle (on success) or a BytesSource handle (on error) - // into `out`. Both are u32. BytesSource is #[repr(transparent)] over u32. - let mut out = [raw::BytesSource::INVALID; 1]; - - let res = unsafe { - super::raw::onnx_load_model( - name.as_ptr(), - name.len() as u32, - out.as_mut_ptr().cast(), - ) - }; - - match super::Errno::from_code(res) { - // Success: out[0] is a model handle as raw u32 bits. - // Safety: BytesSource is #[repr(transparent)] over u32. - None => Ok(unsafe { core::mem::transmute::(out[0]) }), - // ONNX_ERROR: out[0] is a BytesSource with the error message. - Some(errno) if errno == super::Errno::ONNX_ERROR => Err(out[0]), - Some(errno) => panic!("{errno}"), - } - } - - /// Run inference on a loaded model with BSATN-encoded input tensors. + /// Run ONNX inference on a named model with BSATN-encoded input tensors. /// + /// The host loads and caches the model on first use. /// `input_bsatn` should be a BSATN-encoded `Vec`. /// /// On success, returns `Ok(bytes_source)` containing BSATN-encoded output tensors. /// On failure, returns `Err(bytes_source)` containing a BSATN-encoded error message. #[inline] - pub fn run_inference( - model_handle: u32, - input_bsatn: &[u8], - ) -> Result { + pub fn run(name: &str, input_bsatn: &[u8]) -> Result { let mut out = [raw::BytesSource::INVALID; 1]; let res = unsafe { - super::raw::onnx_run_inference( - model_handle, + super::raw::onnx_run( + name.as_ptr(), + name.len() as u32, input_bsatn.as_ptr(), input_bsatn.len() as u32, out.as_mut_ptr().cast(), @@ -1760,22 +1692,6 @@ pub mod onnx { match super::Errno::from_code(res) { None => Ok(out[0]), Some(errno) if errno == super::Errno::ONNX_ERROR => Err(out[0]), - Some(errno) if errno == super::Errno::NO_SUCH_MODEL => { - panic!("ONNX model handle {model_handle} is not valid") - } - Some(errno) => panic!("{errno}"), - } - } - - /// Close a loaded ONNX model, freeing its resources. - #[inline] - pub fn close_model(model_handle: u32) { - let res = unsafe { super::raw::onnx_close_model(model_handle) }; - match super::Errno::from_code(res) { - None => {} - Some(errno) if errno == super::Errno::NO_SUCH_MODEL => { - panic!("ONNX model handle {model_handle} is not valid") - } Some(errno) => panic!("{errno}"), } } diff --git a/crates/bindings/src/onnx.rs b/crates/bindings/src/onnx.rs index ae16cc45a6a..8a1739f904e 100644 --- a/crates/bindings/src/onnx.rs +++ b/crates/bindings/src/onnx.rs @@ -1,22 +1,21 @@ //! ONNX inference support for SpacetimeDB modules. //! -//! Load an ONNX model by name and run inference from within reducers or procedures. +//! Run ONNX model inference from within reducers or procedures. //! Models are stored on the host filesystem — the model bytes never enter WASM memory. +//! Models are cached on the host after first load. //! //! # Example //! //! ```no_run -//! # use spacetimedb::{reducer, ReducerContext, onnx::{OnnxClient, Tensor, ModelHandle}}; +//! # use spacetimedb::{reducer, ReducerContext, onnx::{OnnxClient, Tensor}}; //! // In a reducer: //! # #[reducer] //! # fn my_reducer(ctx: &ReducerContext) { -//! // Load a model by name — the host resolves "bot_brain" to a .onnx file on disk. -//! let model = ctx.onnx.load("bot_brain").expect("Failed to load model"); //! let input = vec![Tensor { //! shape: vec![1, 10], //! data: vec![0.0; 10], //! }]; -//! let output = ctx.onnx.run(&model, &input).expect("Inference failed"); +//! let output = ctx.onnx.run("bot_brain", &input).expect("Inference failed"); //! log::info!("Output: {:?}", output[0].data); //! # } //! ``` @@ -26,18 +25,6 @@ use spacetimedb_lib::bsatn; pub use spacetimedb_lib::onnx::Tensor; -/// An opaque handle to a loaded ONNX model on the host. -/// -/// Obtained via [`OnnxClient::load`] and used with [`OnnxClient::run`]. -/// The model is freed when this handle is dropped. -pub struct ModelHandle(u32); - -impl Drop for ModelHandle { - fn drop(&mut self) { - spacetimedb_bindings_sys::onnx::close_model(self.0); - } -} - /// Client for performing ONNX inference. /// /// Access from within reducers via [`ReducerContext::onnx`](crate::ReducerContext) @@ -46,35 +33,18 @@ impl Drop for ModelHandle { pub struct OnnxClient {} impl OnnxClient { - /// Load an ONNX model by name from the host's model storage. - /// - /// The host resolves the name to a `.onnx` file on its filesystem - /// (e.g. in the database's `models/` directory), then loads and optimizes it - /// entirely on the host side. The model bytes never enter WASM memory. + /// Run inference on a named ONNX model. /// - /// The returned [`ModelHandle`] can be used with [`OnnxClient::run`] for inference. - /// The model is automatically freed when the handle is dropped. - pub fn load(&self, model_name: &str) -> Result { - match spacetimedb_bindings_sys::onnx::load_model(model_name) { - Ok(handle) => Ok(ModelHandle(handle)), - Err(err_source) => { - let message = read_bytes_source_as::(err_source); - Err(Error { message }) - } - } - } - - /// Run inference on a loaded model. + /// The host resolves `model_name` to a `.onnx` file on its filesystem, + /// loads and caches it on first use, then runs inference with the given inputs. + /// Model bytes never enter WASM memory — only tensor data crosses the boundary. /// /// `inputs` are the input tensors for the model, in the order expected by the model's input nodes. /// Returns the output tensors from the model. - /// - /// Inference runs entirely on the host in native Rust — only the input/output tensor data - /// crosses the WASM boundary. - pub fn run(&self, model: &ModelHandle, inputs: &[Tensor]) -> Result, Error> { + pub fn run(&self, model_name: &str, inputs: &[Tensor]) -> Result, Error> { let input_bsatn = bsatn::to_vec(inputs).expect("Failed to BSATN-serialize input tensors"); - match spacetimedb_bindings_sys::onnx::run_inference(model.0, &input_bsatn) { + match spacetimedb_bindings_sys::onnx::run(model_name, &input_bsatn) { Ok(output_source) => { let output = read_bytes_source_as::>(output_source); Ok(output) diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index 32169711d7d..edb121274d0 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -196,7 +196,5 @@ pub enum AbiCall { ProcedureAbortMutTransaction, ProcedureHttpRequest, - OnnxLoadModel, - OnnxRunInference, - OnnxCloseModel, + OnnxRun, } diff --git a/crates/core/src/host/wasm_common.rs b/crates/core/src/host/wasm_common.rs index 25c6f9db977..788ccf68ec8 100644 --- a/crates/core/src/host/wasm_common.rs +++ b/crates/core/src/host/wasm_common.rs @@ -432,9 +432,7 @@ macro_rules! abi_funcs { "spacetime_10.4"::datastore_index_scan_point_bsatn, "spacetime_10.4"::datastore_delete_by_index_scan_point_bsatn, - "spacetime_10.5"::onnx_load_model, - "spacetime_10.5"::onnx_run_inference, - "spacetime_10.5"::onnx_close_model, + "spacetime_10.5"::onnx_run, } diff --git a/crates/core/src/host/wasmtime/wasm_instance_env.rs b/crates/core/src/host/wasmtime/wasm_instance_env.rs index 3993937f71b..18ee9da3e98 100644 --- a/crates/core/src/host/wasmtime/wasm_instance_env.rs +++ b/crates/core/src/host/wasmtime/wasm_instance_env.rs @@ -134,11 +134,8 @@ pub(super) struct WasmInstanceEnv { // TODO(Centril): consider using this pool for `console_timer_start` and `bytes_sink_write`. chunk_pool: ChunkPool, - /// Loaded ONNX models, keyed by model handle. - onnx_models: std::collections::HashMap, - - /// Counter for generating ONNX model handles. - next_onnx_model_id: u32, + /// Cached ONNX models, keyed by model name. + onnx_models: std::collections::HashMap, } const STANDARD_BYTES_SINK: u32 = 1; @@ -165,7 +162,6 @@ impl WasmInstanceEnv { call_times: CallTimes::new(), chunk_pool: <_>::default(), onnx_models: std::collections::HashMap::new(), - next_onnx_model_id: 1, } } @@ -1887,65 +1883,45 @@ impl WasmInstanceEnv { /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. /// - `request_ptr[..request_len]` does not contain a valid BSATN-serialized `spacetimedb_lib::http::Request` object. /// Load an ONNX model by name from the host's model storage. + /// Run ONNX inference on a model identified by name. /// /// `name_ptr[..name_len]` is a UTF-8 model name. The host resolves this to - /// a `.onnx` file on disk, loads and optimizes it entirely on the host side. - /// The model bytes never enter WASM memory. - /// - /// On success, writes a model handle (u32) to `out` and returns 0. - /// On error, writes a `BytesSource` handle containing a BSATN-encoded error `String` - /// to `out` and returns `ONNX_ERROR`. - pub fn onnx_load_model( - caller: Caller<'_, Self>, - name_ptr: WasmPtr, - name_len: u32, - out: WasmPtr, - ) -> RtResult { - Self::cvt_custom(caller, AbiCall::OnnxLoadModel, |caller| { - let (mem, env) = Self::mem_env(caller); - let name = mem.deref_str(name_ptr, name_len)?; - - match crate::host::onnx::OnnxModel::load_by_name(name, &env.instance_env) { - Ok(model) => { - let handle = env.next_onnx_model_id; - env.next_onnx_model_id = handle.checked_add(1) - .ok_or_else(|| WasmError::Wasm(anyhow!("ONNX model handle overflow")))?; - env.onnx_models.insert(handle, model); - handle.write_to(mem, out)?; - Ok(0u32) - } - Err(err) => { - let err_msg = bsatn::to_vec(&err.to_string()) - .context("Failed to BSATN-serialize ONNX error")?; - let bytes_source = WasmInstanceEnv::create_bytes_source(env, err_msg.into())?; - bytes_source.0.write_to(mem, out)?; - Ok(errno::ONNX_ERROR.get() as u32) - } - } - }) - } - - /// Run inference on the model identified by `model_handle` - /// with BSATN-encoded input tensors at `input_ptr[..input_len]`. + /// a `.onnx` file on disk, loads and caches it on first use. + /// `input_ptr[..input_len]` contains BSATN-encoded input tensors. /// /// On success, writes a `BytesSource` containing BSATN-encoded output tensors to `out` /// and returns 0. /// On error, writes a `BytesSource` containing a BSATN-encoded error `String` to `out` /// and returns `ONNX_ERROR`. - pub fn onnx_run_inference( + pub fn onnx_run( caller: Caller<'_, Self>, - model_handle: u32, + name_ptr: WasmPtr, + name_len: u32, input_ptr: WasmPtr, input_len: u32, out: WasmPtr, ) -> RtResult { - Self::cvt_custom(caller, AbiCall::OnnxRunInference, |caller| { + Self::cvt_custom(caller, AbiCall::OnnxRun, |caller| { let (mem, env) = Self::mem_env(caller); + let name = mem.deref_str(name_ptr, name_len)?.to_owned(); + + // Load and cache the model on first use. + if !env.onnx_models.contains_key(&name) { + match crate::host::onnx::OnnxModel::load_by_name(&name, &env.instance_env) { + Ok(model) => { + env.onnx_models.insert(name.clone(), model); + } + Err(err) => { + let err_msg = bsatn::to_vec(&err.to_string()) + .context("Failed to BSATN-serialize ONNX error")?; + let bytes_source = WasmInstanceEnv::create_bytes_source(env, err_msg.into())?; + bytes_source.0.write_to(mem, out)?; + return Ok(errno::ONNX_ERROR.get() as u32); + } + } + } - let model = env.onnx_models.get(&model_handle) - .ok_or(WasmError::Db(NodesError::OnnxError( - format!("No ONNX model with handle {model_handle}") - )))?; + let model = env.onnx_models.get(&name).unwrap(); let input_buf = mem.deref_slice(input_ptr, input_len)?; let inputs: Vec = @@ -1970,21 +1946,6 @@ impl WasmInstanceEnv { }) } - /// Free the ONNX model identified by `model_handle`. - pub fn onnx_close_model( - caller: Caller<'_, Self>, - model_handle: u32, - ) -> RtResult { - Self::cvt_custom(caller, AbiCall::OnnxCloseModel, |caller| { - let (_, env) = Self::mem_env(caller); - if env.onnx_models.remove(&model_handle).is_some() { - Ok(0u32) - } else { - Ok(errno::NO_SUCH_MODEL.get() as u32) - } - }) - } - pub fn procedure_http_request<'caller>( caller: Caller<'caller, Self>, (request_ptr, request_len, body_ptr, body_len, out): (WasmPtr, u32, WasmPtr, u32, WasmPtr), diff --git a/crates/lib/src/onnx.rs b/crates/lib/src/onnx.rs index dc5055966d9..2b50e8888a0 100644 --- a/crates/lib/src/onnx.rs +++ b/crates/lib/src/onnx.rs @@ -19,8 +19,3 @@ pub struct Tensor { /// Flattened f32 data in row-major order. pub data: Vec, } - -/// An opaque handle to a loaded ONNX model on the host. -/// -/// Returned by `onnx_load_model` and passed to `onnx_run_inference`. -pub type ModelHandle = u32; diff --git a/crates/primitives/src/errno.rs b/crates/primitives/src/errno.rs index fdc0a921de9..320b597227f 100644 --- a/crates/primitives/src/errno.rs +++ b/crates/primitives/src/errno.rs @@ -36,7 +36,6 @@ macro_rules! errnos { ), HTTP_ERROR(21, "The HTTP request failed"), ONNX_ERROR(22, "ONNX inference failed"), - NO_SUCH_MODEL(23, "The provided ONNX model handle is not valid"), ); }; } From d9192238773c12302bd978bcea7caec70f3c149e Mon Sep 17 00:00:00 2001 From: definenoob Date: Sun, 8 Mar 2026 07:44:03 -0400 Subject: [PATCH 3/6] feature --- crates/bindings-sys/Cargo.toml | 1 + crates/bindings-sys/src/lib.rs | 2 ++ crates/bindings/Cargo.toml | 1 + crates/bindings/src/lib.rs | 6 ++++++ crates/core/Cargo.toml | 3 ++- crates/core/src/error.rs | 1 + crates/core/src/host/host_controller.rs | 11 ++++++++-- crates/core/src/host/instance_env.rs | 2 ++ crates/core/src/host/mod.rs | 2 ++ crates/core/src/host/wasm_common.rs | 3 +-- .../src/host/wasm_common/module_host_actor.rs | 21 +++++++++++++++---- .../src/host/wasmtime/wasm_instance_env.rs | 4 +++- .../core/src/host/wasmtime/wasmtime_module.rs | 2 ++ crates/core/src/module_host_context.rs | 2 ++ 14 files changed, 51 insertions(+), 10 deletions(-) diff --git a/crates/bindings-sys/Cargo.toml b/crates/bindings-sys/Cargo.toml index ee8fb474b6d..752b6fa6968 100644 --- a/crates/bindings-sys/Cargo.toml +++ b/crates/bindings-sys/Cargo.toml @@ -12,6 +12,7 @@ bench = false [features] unstable = [] +onnx = [] [dependencies] spacetimedb-primitives.workspace = true diff --git a/crates/bindings-sys/src/lib.rs b/crates/bindings-sys/src/lib.rs index d4d9102d2a9..3fbc4799afc 100644 --- a/crates/bindings-sys/src/lib.rs +++ b/crates/bindings-sys/src/lib.rs @@ -865,6 +865,7 @@ pub mod raw { ) -> u16; } + #[cfg(feature = "onnx")] #[link(wasm_import_module = "spacetime_10.5")] unsafe extern "C" { /// Runs ONNX inference on a model identified by name. @@ -1665,6 +1666,7 @@ pub mod procedure { } /// ONNX inference operations, available from both reducers and procedures. +#[cfg(feature = "onnx")] pub mod onnx { use super::raw; diff --git a/crates/bindings/Cargo.toml b/crates/bindings/Cargo.toml index 54c3bd3c74e..0bdc88816c6 100644 --- a/crates/bindings/Cargo.toml +++ b/crates/bindings/Cargo.toml @@ -18,6 +18,7 @@ default = ["rand"] rand = ["rand08"] rand08 = ["dep:rand08", "dep:getrandom02"] unstable = ["spacetimedb-bindings-sys/unstable"] +onnx = ["spacetimedb-bindings-sys/onnx"] [dependencies] spacetimedb-bindings-sys.workspace = true diff --git a/crates/bindings/src/lib.rs b/crates/bindings/src/lib.rs index e9fc404fa5a..b99384268a9 100644 --- a/crates/bindings/src/lib.rs +++ b/crates/bindings/src/lib.rs @@ -11,6 +11,7 @@ mod client_visibility_filter; #[cfg(feature = "unstable")] pub mod http; pub mod log_stopwatch; +#[cfg(feature = "onnx")] pub mod onnx; mod logger; #[cfg(feature = "rand08")] @@ -1007,6 +1008,7 @@ pub struct ReducerContext { pub db: Local, /// Methods for performing ONNX inference. + #[cfg(feature = "onnx")] pub onnx: crate::onnx::OnnxClient, #[cfg(feature = "rand08")] @@ -1022,6 +1024,7 @@ impl ReducerContext { pub fn __dummy() -> Self { Self { db: Local {}, + #[cfg(feature = "onnx")] onnx: crate::onnx::OnnxClient {}, sender: Identity::__dummy(), timestamp: Timestamp::UNIX_EPOCH, @@ -1038,6 +1041,7 @@ impl ReducerContext { fn new(db: Local, sender: Identity, connection_id: Option, timestamp: Timestamp) -> Self { Self { db, + #[cfg(feature = "onnx")] onnx: crate::onnx::OnnxClient {}, sender, timestamp, @@ -1187,6 +1191,7 @@ pub struct ProcedureContext { pub http: crate::http::HttpClient, /// Methods for performing ONNX inference. + #[cfg(feature = "onnx")] pub onnx: crate::onnx::OnnxClient, // TODO: Change rng? // Complex and requires design because we may want procedure RNG to behave differently from reducer RNG, @@ -1208,6 +1213,7 @@ impl ProcedureContext { timestamp, connection_id, http: http::HttpClient {}, + #[cfg(feature = "onnx")] onnx: crate::onnx::OnnxClient {}, #[cfg(feature = "rand08")] rng: std::cell::OnceCell::new(), diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index bf3737b21e9..eacbe0117fe 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -120,7 +120,7 @@ url.workspace = true urlencoding.workspace = true uuid.workspace = true v8.workspace = true -tract-onnx.workspace = true +tract-onnx = { workspace = true, optional = true } wasmtime.workspace = true wasmtime-internal-fiber.workspace = true jwks.workspace = true @@ -151,6 +151,7 @@ perfmap = [] # Disables core pinning no-core-pinning = [] no-job-core-pinning = [] +onnx = ["dep:tract-onnx"] [dev-dependencies] spacetimedb-lib = { path = "../lib", features = ["proptest", "test"] } diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index fc3525e0b01..60dd670d344 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -306,6 +306,7 @@ pub enum NodesError { ScheduleError(#[source] ScheduleError), #[error("HTTP request failed: {0}")] HttpError(String), + #[cfg(feature = "onnx")] #[error("ONNX inference failed: {0}")] OnnxError(String), } diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index 74222faf7d3..c3683601339 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -453,6 +453,7 @@ impl HostController { this.energy_monitor.clone(), this.unregister_fn(replica_id), this.db_cores.take(), + #[cfg(feature = "onnx")] Some(this.data_dir.clone()), ) .await?; @@ -710,7 +711,7 @@ async fn make_module_host( energy_monitor: Arc, unregister: impl Fn() + Send + Sync + 'static, core: AllocatedJobCore, - data_dir: Option>, + #[cfg(feature = "onnx")] data_dir: Option>, ) -> anyhow::Result<(Program, ModuleHost)> { // `make_actor` is blocking, as it needs to compile the wasm to native code, // which may be computationally expensive - sometimes up to 1s for a large module. @@ -724,6 +725,7 @@ async fn make_module_host( scheduler, program_hash: program.hash, energy_monitor, + #[cfg(feature = "onnx")] data_dir, }; @@ -773,6 +775,7 @@ struct ModuleLauncher { runtimes: Arc, core: AllocatedJobCore, bsatn_rlb_pool: BsatnRowListBuilderPool, + #[cfg(feature = "onnx")] data_dir: Option>, } @@ -805,6 +808,7 @@ impl ModuleLauncher { self.energy_monitor, self.on_panic, self.core, + #[cfg(feature = "onnx")] self.data_dir, ) .await?; @@ -993,6 +997,7 @@ impl Host { runtimes: runtimes.clone(), core: host_controller.db_cores.take(), bsatn_rlb_pool: bsatn_rlb_pool.clone(), + #[cfg(feature = "onnx")] data_dir: Some(data_dir.clone()), } .launch_module() @@ -1090,6 +1095,7 @@ impl Host { runtimes: runtimes.clone(), core, bsatn_rlb_pool, + #[cfg(feature = "onnx")] data_dir: None, } .launch_module() @@ -1117,7 +1123,7 @@ impl Host { energy_monitor: Arc, on_panic: impl Fn() + Send + Sync + 'static, core: AllocatedJobCore, - data_dir: Option>, + #[cfg(feature = "onnx")] data_dir: Option>, ) -> anyhow::Result { let replica_ctx = &self.replica_ctx; let (scheduler, scheduler_starter) = Scheduler::open(self.replica_ctx.relational_db.clone()); @@ -1130,6 +1136,7 @@ impl Host { energy_monitor, on_panic, core, + #[cfg(feature = "onnx")] data_dir, ) .await?; diff --git a/crates/core/src/host/instance_env.rs b/crates/core/src/host/instance_env.rs index 338f270b87f..ca861777aec 100644 --- a/crates/core/src/host/instance_env.rs +++ b/crates/core/src/host/instance_env.rs @@ -55,6 +55,7 @@ pub struct InstanceEnv { procedure_last_tx_offset: Option, /// Directory on the host filesystem where ONNX model files are stored. /// Set during module initialization if model storage is configured. + #[cfg(feature = "onnx")] pub models_dir: Option, } @@ -240,6 +241,7 @@ impl InstanceEnv { func_name: None, in_anon_tx: false, procedure_last_tx_offset: None, + #[cfg(feature = "onnx")] models_dir: None, } } diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index edb121274d0..cf2771324d6 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -13,6 +13,7 @@ use spacetimedb_schema::def::ModuleDef; mod disk_storage; mod host_controller; mod module_common; +#[cfg(feature = "onnx")] pub mod onnx; #[allow(clippy::too_many_arguments)] pub mod module_host; @@ -196,5 +197,6 @@ pub enum AbiCall { ProcedureAbortMutTransaction, ProcedureHttpRequest, + #[cfg(feature = "onnx")] OnnxRun, } diff --git a/crates/core/src/host/wasm_common.rs b/crates/core/src/host/wasm_common.rs index 788ccf68ec8..d22b96f5a46 100644 --- a/crates/core/src/host/wasm_common.rs +++ b/crates/core/src/host/wasm_common.rs @@ -358,6 +358,7 @@ pub fn err_to_errno(err: NodesError) -> Result<(NonZeroU16, Option), Nod NodesError::IndexCannotSeekRange => errno::WRONG_INDEX_ALGO, NodesError::ScheduleError(ScheduleError::DelayTooLong(_)) => errno::SCHEDULE_AT_DELAY_TOO_LONG, NodesError::HttpError(message) => return Ok((errno::HTTP_ERROR, Some(message))), + #[cfg(feature = "onnx")] NodesError::OnnxError(message) => return Ok((errno::ONNX_ERROR, Some(message))), NodesError::Internal(ref internal) => match **internal { DBError::Datastore(DatastoreError::Index(IndexError::UniqueConstraintViolation( @@ -432,8 +433,6 @@ macro_rules! abi_funcs { "spacetime_10.4"::datastore_index_scan_point_bsatn, "spacetime_10.4"::datastore_delete_by_index_scan_point_bsatn, - "spacetime_10.5"::onnx_run, - } $link_async! { diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index e5523dea10b..20bc46d63e0 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -316,6 +316,7 @@ pub struct WasmModuleHostActor { module: T::InstancePre, common: ModuleCommon, func_names: Arc, + #[cfg(feature = "onnx")] models_dir: Option, } @@ -376,9 +377,15 @@ impl WasmModuleHostActor { func_names }; let uninit_instance = module.instantiate_pre()?; + #[cfg(feature = "onnx")] let models_dir = mcc.data_dir.as_ref().map(|d| d.0.join("models")); - let mut instance_env = InstanceEnv::new(mcc.replica_ctx.clone(), mcc.scheduler.clone()); - instance_env.models_dir = models_dir.clone(); + let instance_env = InstanceEnv::new(mcc.replica_ctx.clone(), mcc.scheduler.clone()); + #[cfg(feature = "onnx")] + let instance_env = { + let mut env = instance_env; + env.models_dir = models_dir.clone(); + env + }; let mut instance = uninit_instance.instantiate(instance_env, &func_names)?; let desc = instance.extract_descriptions()?; @@ -390,6 +397,7 @@ impl WasmModuleHostActor { module: uninit_instance, func_names, common, + #[cfg(feature = "onnx")] models_dir, }; let initial_instance = module.make_from_instance(instance); @@ -425,8 +433,13 @@ impl WasmModuleHostActor { pub fn create_instance(&self) -> WasmModuleInstance { let common = &self.common; - let mut env = InstanceEnv::new(common.replica_ctx().clone(), common.scheduler().clone()); - env.models_dir = self.models_dir.clone(); + let env = InstanceEnv::new(common.replica_ctx().clone(), common.scheduler().clone()); + #[cfg(feature = "onnx")] + let env = { + let mut env = env; + env.models_dir = self.models_dir.clone(); + env + }; // this shouldn't fail, since we already called module.create_instance() // before and it didn't error, and ideally they should be deterministic let mut instance = self diff --git a/crates/core/src/host/wasmtime/wasm_instance_env.rs b/crates/core/src/host/wasmtime/wasm_instance_env.rs index 18ee9da3e98..ce8c55b6f0a 100644 --- a/crates/core/src/host/wasmtime/wasm_instance_env.rs +++ b/crates/core/src/host/wasmtime/wasm_instance_env.rs @@ -135,6 +135,7 @@ pub(super) struct WasmInstanceEnv { chunk_pool: ChunkPool, /// Cached ONNX models, keyed by model name. + #[cfg(feature = "onnx")] onnx_models: std::collections::HashMap, } @@ -161,6 +162,7 @@ impl WasmInstanceEnv { timing_spans: Default::default(), call_times: CallTimes::new(), chunk_pool: <_>::default(), + #[cfg(feature = "onnx")] onnx_models: std::collections::HashMap::new(), } } @@ -1882,7 +1884,7 @@ impl WasmInstanceEnv { /// - `body_ptr` is NULL or `body_ptr[..body_len]` is not in bounds of WASM memory. /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. /// - `request_ptr[..request_len]` does not contain a valid BSATN-serialized `spacetimedb_lib::http::Request` object. - /// Load an ONNX model by name from the host's model storage. + #[cfg(feature = "onnx")] /// Run ONNX inference on a model identified by name. /// /// `name_ptr[..name_len]` is a UTF-8 model name. The host resolves this to diff --git a/crates/core/src/host/wasmtime/wasmtime_module.rs b/crates/core/src/host/wasmtime/wasmtime_module.rs index 0690120eb16..3b93955c90d 100644 --- a/crates/core/src/host/wasmtime/wasmtime_module.rs +++ b/crates/core/src/host/wasmtime/wasmtime_module.rs @@ -67,6 +67,8 @@ impl WasmtimeModule { } } abi_funcs!(link_functions, link_async_functions); + #[cfg(feature = "onnx")] + linker.func_wrap("spacetime_10.5", "onnx_run", WasmInstanceEnv::onnx_run)?; Ok(()) } } diff --git a/crates/core/src/module_host_context.rs b/crates/core/src/module_host_context.rs index d783e5101c6..5e2095b2dcf 100644 --- a/crates/core/src/module_host_context.rs +++ b/crates/core/src/module_host_context.rs @@ -1,6 +1,7 @@ use crate::energy::EnergyMonitor; use crate::host::scheduler::Scheduler; use crate::replica_context::ReplicaContext; +#[cfg(feature = "onnx")] use spacetimedb_paths::server::ServerDataDir; use spacetimedb_sats::hash::Hash; use std::sync::Arc; @@ -10,5 +11,6 @@ pub struct ModuleCreationContext { pub scheduler: Scheduler, pub program_hash: Hash, pub energy_monitor: Arc, + #[cfg(feature = "onnx")] pub data_dir: Option>, } From a8a41896845dd897c8c2aaea4745918df520b718 Mon Sep 17 00:00:00 2001 From: definenoob Date: Sun, 8 Mar 2026 21:53:08 -0400 Subject: [PATCH 4/6] basic tests --- Cargo.lock | 1 + crates/core/Cargo.toml | 1 + crates/core/src/host/onnx.rs | 157 ++++++++++++++++++++++++++++++++++- 3 files changed, 158 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index fd7cc85075f..afc785f9b44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8104,6 +8104,7 @@ dependencies = [ "prometheus", "proptest", "proptest-derive", + "prost 0.11.9", "rand 0.9.2", "rayon", "rayon-core", diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index eacbe0117fe..483327742f8 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -170,6 +170,7 @@ pretty_assertions.workspace = true jsonwebtoken.workspace = true axum.workspace = true fs_extra.workspace = true +prost = "0.11" [lints] workspace = true diff --git a/crates/core/src/host/onnx.rs b/crates/core/src/host/onnx.rs index 772d6ea6f4d..69f825737ab 100644 --- a/crates/core/src/host/onnx.rs +++ b/crates/core/src/host/onnx.rs @@ -48,7 +48,8 @@ impl OnnxModel { } /// Load an ONNX model from raw bytes. - fn load_from_bytes(model_bytes: &[u8]) -> Result { + #[cfg_attr(test, allow(dead_code))] + pub(crate) fn load_from_bytes(model_bytes: &[u8]) -> Result { let model = tract_onnx::onnx() .model_for_read(&mut std::io::Cursor::new(model_bytes)) .map_err(|e| OnnxError(format!("Failed to parse ONNX model: {e}")))? @@ -109,3 +110,157 @@ impl std::fmt::Display for OnnxError { } impl std::error::Error for OnnxError {} + +#[cfg(test)] +mod tests { + use super::*; + use prost::Message; + use tract_onnx::pb; + + /// Build a minimal ONNX model as raw bytes using protobuf types. + /// `op_type` is the ONNX operator (e.g. "Add", "Relu", "Identity"). + /// `n_inputs` is the number of inputs the operator expects. + fn build_onnx_model(op_type: &str, n_inputs: usize) -> Vec { + let input_names: Vec = (0..n_inputs).map(|i| format!("input_{i}")).collect(); + let inputs: Vec = input_names + .iter() + .map(|name| pb::ValueInfoProto { + name: name.clone(), + r#type: Some(pb::TypeProto { + denotation: String::new(), + value: Some(pb::type_proto::Value::TensorType(pb::type_proto::Tensor { + elem_type: 1, // FLOAT + shape: Some(pb::TensorShapeProto { + dim: vec![ + pb::tensor_shape_proto::Dimension { + denotation: String::new(), + value: Some(pb::tensor_shape_proto::dimension::Value::DimValue(1)), + }, + pb::tensor_shape_proto::Dimension { + denotation: String::new(), + value: Some(pb::tensor_shape_proto::dimension::Value::DimValue(4)), + }, + ], + }), + })), + }), + doc_string: String::new(), + }) + .collect(); + + let output = pb::ValueInfoProto { + name: "output".into(), + r#type: Some(pb::TypeProto { + denotation: String::new(), + value: Some(pb::type_proto::Value::TensorType(pb::type_proto::Tensor { + elem_type: 1, + shape: Some(pb::TensorShapeProto { + dim: vec![ + pb::tensor_shape_proto::Dimension { + denotation: String::new(), + value: Some(pb::tensor_shape_proto::dimension::Value::DimValue(1)), + }, + pb::tensor_shape_proto::Dimension { + denotation: String::new(), + value: Some(pb::tensor_shape_proto::dimension::Value::DimValue(4)), + }, + ], + }), + })), + }), + doc_string: String::new(), + }; + + let node = pb::NodeProto { + input: input_names, + output: vec!["output".into()], + name: "node_0".into(), + op_type: op_type.into(), + domain: String::new(), + attribute: vec![], + doc_string: String::new(), + }; + + let graph = pb::GraphProto { + name: "test_graph".into(), + node: vec![node], + input: inputs.clone(), + output: vec![output], + initializer: vec![], + sparse_initializer: vec![], + doc_string: String::new(), + value_info: vec![], + quantization_annotation: vec![], + }; + + let model = pb::ModelProto { + ir_version: 7, + opset_import: vec![pb::OperatorSetIdProto { + domain: String::new(), + version: 13, + }], + producer_name: "spacetimedb-test".into(), + graph: Some(graph), + ..Default::default() + }; + + model.encode_to_vec() + } + + #[test] + fn load_and_run_add_model() { + let model_bytes = build_onnx_model("Add", 2); + let model = OnnxModel::load_from_bytes(&model_bytes).expect("Failed to load model"); + + let a = StdbTensor { + shape: vec![1, 4], + data: vec![1.0, 2.0, 3.0, 4.0], + }; + let b = StdbTensor { + shape: vec![1, 4], + data: vec![10.0, 20.0, 30.0, 40.0], + }; + + let outputs = model.run(&[a, b]).expect("Inference failed"); + assert_eq!(outputs.len(), 1); + assert_eq!(outputs[0].shape, vec![1, 4]); + assert_eq!(outputs[0].data, vec![11.0, 22.0, 33.0, 44.0]); + } + + #[test] + fn load_and_run_relu_model() { + let model_bytes = build_onnx_model("Relu", 1); + let model = OnnxModel::load_from_bytes(&model_bytes).expect("Failed to load model"); + + let input = StdbTensor { + shape: vec![1, 4], + data: vec![-2.0, -1.0, 0.0, 3.0], + }; + + let outputs = model.run(&[input]).expect("Inference failed"); + assert_eq!(outputs.len(), 1); + assert_eq!(outputs[0].shape, vec![1, 4]); + assert_eq!(outputs[0].data, vec![0.0, 0.0, 0.0, 3.0]); + } + + #[test] + fn invalid_model_bytes() { + let result = OnnxModel::load_from_bytes(b"not a valid onnx model"); + assert!(result.is_err()); + } + + #[test] + fn shape_mismatch_errors() { + let model_bytes = build_onnx_model("Relu", 1); + let model = OnnxModel::load_from_bytes(&model_bytes).expect("Failed to load model"); + + // Wrong number of elements for the declared shape. + let bad_input = StdbTensor { + shape: vec![1, 4], + data: vec![1.0, 2.0], // only 2 elements for a 1x4 tensor + }; + + let result = model.run(&[bad_input]); + assert!(result.is_err()); + } +} From 1964f828638eb2228b8c63f6d436944abb8877c7 Mon Sep 17 00:00:00 2001 From: definenoob Date: Sun, 8 Mar 2026 22:01:00 -0400 Subject: [PATCH 5/6] run_multi --- crates/bindings-sys/src/lib.rs | 59 +++++++++++++++++ crates/bindings/src/onnx.rs | 22 +++++++ crates/core/src/host/mod.rs | 2 + crates/core/src/host/onnx.rs | 40 ++++++++++++ .../src/host/wasmtime/wasm_instance_env.rs | 63 +++++++++++++++++++ .../core/src/host/wasmtime/wasmtime_module.rs | 2 + 6 files changed, 188 insertions(+) diff --git a/crates/bindings-sys/src/lib.rs b/crates/bindings-sys/src/lib.rs index 3fbc4799afc..cb54a1273fd 100644 --- a/crates/bindings-sys/src/lib.rs +++ b/crates/bindings-sys/src/lib.rs @@ -901,6 +901,37 @@ pub mod raw { input_len: u32, out: *mut u32, ) -> u16; + + /// Runs ONNX inference on multiple batches of inputs for a single model. + /// + /// `name_ptr[..name_len]` is a UTF-8 model name. + /// `input_ptr[..input_len]` should contain a BSATN-encoded `Vec>`. + /// + /// On success, a [`BytesSource`] is written to `out[0]` containing a BSATN-encoded + /// `Vec>` with the inference outputs, and this function returns 0. + /// + /// # Traps + /// + /// Traps if: + /// - `name_ptr` is NULL or `name_ptr[..name_len]` is not in bounds of WASM memory. + /// - `name_ptr[..name_len]` is not valid UTF-8. + /// - `input_ptr` is NULL or `input_ptr[..input_len]` is not in bounds of WASM memory. + /// - `out` is NULL or `out[..size_of::()]` is not in bounds of WASM memory. + /// + /// # Errors + /// + /// Returns an error: + /// + /// - `ONNX_ERROR` if the model could not be found, loaded, or inference failed. + /// In this case, a [`BytesSource`] containing a BSATN-encoded error message `String` + /// is written to `out[0]`. + pub fn onnx_run_multi( + name_ptr: *const u8, + name_len: u32, + input_ptr: *const u8, + input_len: u32, + out: *mut u32, + ) -> u16; } /// What strategy does the database index use? @@ -1697,4 +1728,32 @@ pub mod onnx { Some(errno) => panic!("{errno}"), } } + + /// Run ONNX inference on multiple batches of inputs for a named model. + /// + /// The host loads and caches the model on first use. + /// `input_bsatn` should be a BSATN-encoded `Vec>`. + /// + /// On success, returns `Ok(bytes_source)` containing BSATN-encoded `Vec>`. + /// On failure, returns `Err(bytes_source)` containing a BSATN-encoded error message. + #[inline] + pub fn run_multi(name: &str, input_bsatn: &[u8]) -> Result { + let mut out = [raw::BytesSource::INVALID; 1]; + + let res = unsafe { + super::raw::onnx_run_multi( + name.as_ptr(), + name.len() as u32, + input_bsatn.as_ptr(), + input_bsatn.len() as u32, + out.as_mut_ptr().cast(), + ) + }; + + match super::Errno::from_code(res) { + None => Ok(out[0]), + Some(errno) if errno == super::Errno::ONNX_ERROR => Err(out[0]), + Some(errno) => panic!("{errno}"), + } + } } diff --git a/crates/bindings/src/onnx.rs b/crates/bindings/src/onnx.rs index 8a1739f904e..666e60e6ba9 100644 --- a/crates/bindings/src/onnx.rs +++ b/crates/bindings/src/onnx.rs @@ -55,6 +55,28 @@ impl OnnxClient { } } } + + /// Run inference on multiple batches of inputs in a single host call. + /// + /// Each element of `batches` is one set of input tensors (one inference invocation). + /// Returns one `Vec` of outputs per batch, in the same order. + /// + /// This is more efficient than calling [`run`](Self::run) in a loop because it + /// crosses the WASM boundary only once for all batches. + pub fn run_multi(&self, model_name: &str, batches: &[Vec]) -> Result>, Error> { + let input_bsatn = bsatn::to_vec(batches).expect("Failed to BSATN-serialize input tensor batches"); + + match spacetimedb_bindings_sys::onnx::run_multi(model_name, &input_bsatn) { + Ok(output_source) => { + let output = read_bytes_source_as::>>(output_source); + Ok(output) + } + Err(err_source) => { + let message = read_bytes_source_as::(err_source); + Err(Error { message }) + } + } + } } /// An error from ONNX model loading or inference. diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index cf2771324d6..bd377f89a8c 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -199,4 +199,6 @@ pub enum AbiCall { #[cfg(feature = "onnx")] OnnxRun, + #[cfg(feature = "onnx")] + OnnxRunMulti, } diff --git a/crates/core/src/host/onnx.rs b/crates/core/src/host/onnx.rs index 69f825737ab..8033bdcb302 100644 --- a/crates/core/src/host/onnx.rs +++ b/crates/core/src/host/onnx.rs @@ -97,6 +97,15 @@ impl OnnxModel { Ok(outputs) } + + /// Run inference on multiple batches of input tensors. + /// + /// Each element of `batches` is one set of input tensors (one inference invocation). + /// Returns one `Vec` of outputs per batch, in the same order. + /// This amortizes the overhead of crossing the WASM boundary for many inferences. + pub fn run_multi(&self, batches: &[Vec]) -> Result>, OnnxError> { + batches.iter().map(|inputs| self.run(inputs)).collect() + } } /// An error from ONNX model loading or inference. @@ -249,6 +258,37 @@ mod tests { assert!(result.is_err()); } + #[test] + fn run_multi_batches() { + let model_bytes = build_onnx_model("Add", 2); + let model = OnnxModel::load_from_bytes(&model_bytes).expect("Failed to load model"); + + let batches = vec![ + vec![ + StdbTensor { shape: vec![1, 4], data: vec![1.0, 2.0, 3.0, 4.0] }, + StdbTensor { shape: vec![1, 4], data: vec![10.0, 20.0, 30.0, 40.0] }, + ], + vec![ + StdbTensor { shape: vec![1, 4], data: vec![5.0, 5.0, 5.0, 5.0] }, + StdbTensor { shape: vec![1, 4], data: vec![1.0, 1.0, 1.0, 1.0] }, + ], + ]; + + let results = model.run_multi(&batches).expect("run_multi failed"); + assert_eq!(results.len(), 2); + assert_eq!(results[0][0].data, vec![11.0, 22.0, 33.0, 44.0]); + assert_eq!(results[1][0].data, vec![6.0, 6.0, 6.0, 6.0]); + } + + #[test] + fn run_multi_empty_batches() { + let model_bytes = build_onnx_model("Relu", 1); + let model = OnnxModel::load_from_bytes(&model_bytes).expect("Failed to load model"); + + let results = model.run_multi(&[]).expect("run_multi on empty batches failed"); + assert!(results.is_empty()); + } + #[test] fn shape_mismatch_errors() { let model_bytes = build_onnx_model("Relu", 1); diff --git a/crates/core/src/host/wasmtime/wasm_instance_env.rs b/crates/core/src/host/wasmtime/wasm_instance_env.rs index ce8c55b6f0a..6625b834cb6 100644 --- a/crates/core/src/host/wasmtime/wasm_instance_env.rs +++ b/crates/core/src/host/wasmtime/wasm_instance_env.rs @@ -1948,6 +1948,69 @@ impl WasmInstanceEnv { }) } + #[cfg(feature = "onnx")] + /// Run ONNX inference on multiple batches of inputs for a single model. + /// + /// `name_ptr[..name_len]` is a UTF-8 model name. + /// `input_ptr[..input_len]` contains BSATN-encoded `Vec>` (one batch per entry). + /// + /// On success, writes a `BytesSource` containing BSATN-encoded `Vec>` to `out` + /// and returns 0. + /// On error, writes a `BytesSource` containing a BSATN-encoded error `String` to `out` + /// and returns `ONNX_ERROR`. + pub fn onnx_run_multi( + caller: Caller<'_, Self>, + name_ptr: WasmPtr, + name_len: u32, + input_ptr: WasmPtr, + input_len: u32, + out: WasmPtr, + ) -> RtResult { + Self::cvt_custom(caller, AbiCall::OnnxRunMulti, |caller| { + let (mem, env) = Self::mem_env(caller); + let name = mem.deref_str(name_ptr, name_len)?.to_owned(); + + // Load and cache the model on first use. + if !env.onnx_models.contains_key(&name) { + match crate::host::onnx::OnnxModel::load_by_name(&name, &env.instance_env) { + Ok(model) => { + env.onnx_models.insert(name.clone(), model); + } + Err(err) => { + let err_msg = bsatn::to_vec(&err.to_string()) + .context("Failed to BSATN-serialize ONNX error")?; + let bytes_source = WasmInstanceEnv::create_bytes_source(env, err_msg.into())?; + bytes_source.0.write_to(mem, out)?; + return Ok(errno::ONNX_ERROR.get() as u32); + } + } + } + + let model = env.onnx_models.get(&name).unwrap(); + + let input_buf = mem.deref_slice(input_ptr, input_len)?; + let batches: Vec> = + bsatn::from_slice(input_buf).map_err(|err| NodesError::DecodeValue(err))?; + + match model.run_multi(&batches) { + Ok(outputs) => { + let result = bsatn::to_vec(&outputs) + .context("Failed to BSATN-serialize ONNX output tensors")?; + let bytes_source = WasmInstanceEnv::create_bytes_source(env, result.into())?; + bytes_source.0.write_to(mem, out)?; + Ok(0u32) + } + Err(err) => { + let err_msg = bsatn::to_vec(&err.to_string()) + .context("Failed to BSATN-serialize ONNX error")?; + let bytes_source = WasmInstanceEnv::create_bytes_source(env, err_msg.into())?; + bytes_source.0.write_to(mem, out)?; + Ok(errno::ONNX_ERROR.get() as u32) + } + } + }) + } + pub fn procedure_http_request<'caller>( caller: Caller<'caller, Self>, (request_ptr, request_len, body_ptr, body_len, out): (WasmPtr, u32, WasmPtr, u32, WasmPtr), diff --git a/crates/core/src/host/wasmtime/wasmtime_module.rs b/crates/core/src/host/wasmtime/wasmtime_module.rs index 3b93955c90d..e5d5a5d9809 100644 --- a/crates/core/src/host/wasmtime/wasmtime_module.rs +++ b/crates/core/src/host/wasmtime/wasmtime_module.rs @@ -69,6 +69,8 @@ impl WasmtimeModule { abi_funcs!(link_functions, link_async_functions); #[cfg(feature = "onnx")] linker.func_wrap("spacetime_10.5", "onnx_run", WasmInstanceEnv::onnx_run)?; + #[cfg(feature = "onnx")] + linker.func_wrap("spacetime_10.5", "onnx_run_multi", WasmInstanceEnv::onnx_run_multi)?; Ok(()) } } From 018754a8ad20cd8fbc5a2f3c88b0d416c9ec6a85 Mon Sep 17 00:00:00 2001 From: definenoob Date: Sun, 8 Mar 2026 22:16:01 -0400 Subject: [PATCH 6/6] smoke tests --- crates/core/src/host/onnx.rs | 41 +++++++++ crates/smoketests/tests/smoketests/mod.rs | 1 + crates/smoketests/tests/smoketests/onnx.rs | 96 ++++++++++++++++++++++ crates/standalone/Cargo.toml | 1 + 4 files changed, 139 insertions(+) create mode 100644 crates/smoketests/tests/smoketests/onnx.rs diff --git a/crates/core/src/host/onnx.rs b/crates/core/src/host/onnx.rs index 8033bdcb302..a0336386a57 100644 --- a/crates/core/src/host/onnx.rs +++ b/crates/core/src/host/onnx.rs @@ -289,6 +289,47 @@ mod tests { assert!(results.is_empty()); } + #[test] + fn load_from_file_and_run() { + let dir = tempfile::tempdir().unwrap(); + let model_bytes = build_onnx_model("Relu", 1); + std::fs::write(dir.path().join("test_relu.onnx"), &model_bytes).unwrap(); + + // Simulate what load_by_name does: read from filesystem, then load. + let path = dir.path().join("test_relu.onnx"); + let bytes = std::fs::read(&path).unwrap(); + let model = OnnxModel::load_from_bytes(&bytes).expect("Failed to load model from file"); + + let input = StdbTensor { + shape: vec![1, 4], + data: vec![-1.0, 0.0, 1.0, 2.0], + }; + let outputs = model.run(&[input]).unwrap(); + assert_eq!(outputs[0].data, vec![0.0, 0.0, 1.0, 2.0]); + } + + #[test] + fn load_by_name_rejects_path_traversal() { + // We can't construct a full InstanceEnv in unit tests, but we can verify + // the name validation logic directly. + let bad_names = ["", "../etc/passwd", "foo/bar", "foo\\bar", ".."]; + for name in bad_names { + assert!( + name.contains('/') || name.contains('\\') || name.contains("..") || name.is_empty(), + "Expected {name:?} to be rejected by validation" + ); + } + + // Valid names pass validation. + let good_names = ["bot_brain", "my-model", "model.v2"]; + for name in good_names { + assert!( + !name.contains('/') && !name.contains('\\') && !name.contains("..") && !name.is_empty(), + "Expected {name:?} to pass validation" + ); + } + } + #[test] fn shape_mismatch_errors() { let model_bytes = build_onnx_model("Relu", 1); diff --git a/crates/smoketests/tests/smoketests/mod.rs b/crates/smoketests/tests/smoketests/mod.rs index f5053652dd3..2f43987d34f 100644 --- a/crates/smoketests/tests/smoketests/mod.rs +++ b/crates/smoketests/tests/smoketests/mod.rs @@ -23,6 +23,7 @@ mod logs_level_filter; mod module_nested_op; mod modules; mod namespaces; +mod onnx; mod new_user_flow; mod panic; mod permissions; diff --git a/crates/smoketests/tests/smoketests/onnx.rs b/crates/smoketests/tests/smoketests/onnx.rs new file mode 100644 index 00000000000..a56696b1b0d --- /dev/null +++ b/crates/smoketests/tests/smoketests/onnx.rs @@ -0,0 +1,96 @@ +use spacetimedb_smoketests::Smoketest; +use std::fs; + +/// Minimal ONNX "Add" model (2 inputs → 1 output, shape [1,4], f32). +/// Generated from `tract_onnx::pb` protobuf types with opset 13. +const ADD_MODEL_ONNX: &[u8] = &[ + 8, 7, 18, 16, 115, 112, 97, 99, 101, 116, 105, 109, 101, 100, 98, 45, 116, 101, 115, 116, + 58, 133, 1, 10, 39, 10, 7, 105, 110, 112, 117, 116, 95, 48, 10, 7, 105, 110, 112, 117, 116, + 95, 49, 18, 6, 111, 117, 116, 112, 117, 116, 26, 6, 110, 111, 100, 101, 95, 48, 34, 3, 65, + 100, 100, 18, 10, 116, 101, 115, 116, 95, 103, 114, 97, 112, 104, 90, 25, 10, 7, 105, 110, + 112, 117, 116, 95, 48, 18, 14, 10, 12, 8, 1, 18, 8, 10, 2, 8, 1, 10, 2, 8, 4, 90, 25, 10, + 7, 105, 110, 112, 117, 116, 95, 49, 18, 14, 10, 12, 8, 1, 18, 8, 10, 2, 8, 1, 10, 2, 8, 4, + 98, 24, 10, 6, 111, 117, 116, 112, 117, 116, 18, 14, 10, 12, 8, 1, 18, 8, 10, 2, 8, 1, 10, + 2, 8, 4, 66, 2, 16, 13, +]; + +const ONNX_MODULE: &str = r#" +use spacetimedb::{log, ReducerContext, onnx::Tensor}; + +#[spacetimedb::reducer] +pub fn run_add(ctx: &ReducerContext) { + let a = vec![Tensor { shape: vec![1, 4], data: vec![1.0, 2.0, 3.0, 4.0] }]; + let b = vec![Tensor { shape: vec![1, 4], data: vec![10.0, 20.0, 30.0, 40.0] }]; + + let inputs = vec![a[0].clone(), b[0].clone()]; + let output = ctx.onnx.run("test_add", &inputs).expect("run failed"); + log::info!("add_result: {:?}", output[0].data); +} + +#[spacetimedb::reducer] +pub fn run_add_multi(ctx: &ReducerContext) { + let batches = vec![ + vec![ + Tensor { shape: vec![1, 4], data: vec![1.0, 2.0, 3.0, 4.0] }, + Tensor { shape: vec![1, 4], data: vec![10.0, 20.0, 30.0, 40.0] }, + ], + vec![ + Tensor { shape: vec![1, 4], data: vec![5.0, 5.0, 5.0, 5.0] }, + Tensor { shape: vec![1, 4], data: vec![1.0, 1.0, 1.0, 1.0] }, + ], + ]; + let results = ctx.onnx.run_multi("test_add", &batches).expect("run_multi failed"); + log::info!("multi_result_0: {:?}", results[0][0].data); + log::info!("multi_result_1: {:?}", results[1][0].data); +} +"#; + +/// Place the test ONNX model in the server's models directory. +fn setup_model(test: &Smoketest) { + let guard = test.guard.as_ref().expect("ONNX tests require a local server"); + let models_dir = guard.data_dir.join("models"); + fs::create_dir_all(&models_dir).expect("Failed to create models directory"); + fs::write(models_dir.join("test_add.onnx"), ADD_MODEL_ONNX).expect("Failed to write test model"); +} + +/// Test single ONNX inference from a WASM module reducer. +#[test] +fn test_onnx_run() { + let test = Smoketest::builder() + .module_code(ONNX_MODULE) + .bindings_features(&["unstable", "onnx"]) + .build(); + + setup_model(&test); + + test.call("run_add", &[]).unwrap(); + + let logs = test.logs(10).unwrap(); + assert!( + logs.iter().any(|l| l.contains("[11.0, 22.0, 33.0, 44.0]")), + "Expected add result in logs, got: {logs:?}" + ); +} + +/// Test batched ONNX inference (run_multi) from a WASM module reducer. +#[test] +fn test_onnx_run_multi() { + let test = Smoketest::builder() + .module_code(ONNX_MODULE) + .bindings_features(&["unstable", "onnx"]) + .build(); + + setup_model(&test); + + test.call("run_add_multi", &[]).unwrap(); + + let logs = test.logs(10).unwrap(); + assert!( + logs.iter().any(|l| l.contains("[11.0, 22.0, 33.0, 44.0]")), + "Expected first batch result in logs, got: {logs:?}" + ); + assert!( + logs.iter().any(|l| l.contains("[6.0, 6.0, 6.0, 6.0]")), + "Expected second batch result in logs, got: {logs:?}" + ); +} diff --git a/crates/standalone/Cargo.toml b/crates/standalone/Cargo.toml index 0ce65a57ed0..e47335892e3 100644 --- a/crates/standalone/Cargo.toml +++ b/crates/standalone/Cargo.toml @@ -24,6 +24,7 @@ perfmap = ["spacetimedb-core/perfmap"] # Disables core pinning no-core-pinning = ["spacetimedb-core/no-core-pinning"] no-job-core-pinning = ["spacetimedb-core/no-job-core-pinning"] +onnx = ["spacetimedb-core/onnx"] [dependencies] spacetimedb-client-api-messages.workspace = true