diff --git a/Cargo.lock b/Cargo.lock index a2d520b..d6c3ef7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,6 +31,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -126,10 +132,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] -name = "base64ct" -version = "1.8.3" +name = "bindgen_cuda" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +checksum = "282be55fb326843bb67cccceeeaf21c961ef303f60018f9a2ab69494dad8eaf9" +dependencies = [ + "glob", + "num_cpus", + "rayon", +] [[package]] name = "bit-set" @@ -158,6 +169,12 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.10.4" @@ -167,6 +184,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdeb9d870516001442e364c5220d3574d2da8dc765554b4a617230d33fa58ef5" +dependencies = [ + "objc2", +] + [[package]] name = "bstr" version = "1.12.1" @@ -183,6 +209,26 @@ version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +[[package]] +name = "bytemuck" +version = "1.25.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "byteorder" version = "1.5.0" @@ -195,6 +241,105 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "candle-core" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c15b675b80d994b2eadb20a4bbe434eabeb454eac3ee5e2b4cf6f147ee9be091" +dependencies = [ + "byteorder", + "candle-kernels", + "candle-metal-kernels", + "candle-ug", + "cudarc 0.19.4", + "float8 0.6.1", + "gemm 0.19.0", + "half", + "libm", + "memmap2", + "num-traits", + "num_cpus", + "objc2-foundation", + "objc2-metal", + "rand", + "rand_distr", + "rayon", + "safetensors 0.7.0", + "thiserror 2.0.18", + "yoke 0.8.1", + "zip", +] + +[[package]] +name = "candle-kernels" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8455f84bd810047c7c41216683c1020c915a9f8a740b3b0eabdd4fb2fbaa660" +dependencies = [ + "bindgen_cuda", +] + +[[package]] +name = "candle-metal-kernels" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fdfe9d06de16ce49961e49084e5b79a75a9bdf157246e7c7b6328e87a7aa25d" +dependencies = [ + "half", + "objc2", + "objc2-foundation", + "objc2-metal", + "once_cell", + "thiserror 2.0.18", + "tracing", +] + +[[package]] +name = "candle-nn" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3045fa9e7aef8567d209a27d56b692f60b96f4d0569f4c3011f8ca6715c65e03" +dependencies = [ + "candle-core", + "half", + "libc", + "num-traits", + "rayon", + "safetensors 0.7.0", + "serde", + "thiserror 2.0.18", +] + +[[package]] +name = "candle-transformers" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b538ec4aa807c416a2ddd3621044888f188827862e2a6fcacba4738e89795d01" +dependencies = [ + "byteorder", + "candle-core", + "candle-nn", + "fancy-regex 0.17.0", + "num-traits", + "rand", + "rayon", + "serde", + "serde_json", + "serde_plain", + "tracing", +] + +[[package]] +name = "candle-ug" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c22d62be69068bf58987a45f690612739d8d2ea1bf508c1b87dc6815a019575d" +dependencies = [ + "ug", + "ug-cuda", + "ug-metal", +] + [[package]] name = "castaway" version = "0.2.4" @@ -310,9 +455,9 @@ dependencies = [ [[package]] name = "core-foundation" -version = "0.10.1" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" dependencies = [ "core-foundation-sys", "libc", @@ -324,6 +469,17 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -367,6 +523,12 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" 
+[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.7" @@ -377,6 +539,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "cudarc" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf99ab37ee7072d64d906aa2dada9a3422f1d975cdf8c8055a573bc84897ed8" +dependencies = [ + "half", + "libloading 0.8.9", +] + +[[package]] +name = "cudarc" +version = "0.19.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f071cd6a7b5d51607df76aa2d426aaabc7a74bc6bdb885b8afa63a880572ad9b" +dependencies = [ + "float8 0.7.0", + "half", + "libloading 0.9.0", +] + [[package]] name = "darling" version = "0.20.11" @@ -455,16 +638,6 @@ dependencies = [ "serde", ] -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "pem-rfc7468", - "zeroize", -] - [[package]] name = "deranged" version = "0.5.8" @@ -536,6 +709,16 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "dispatch2" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0e367e4e7da84520dedcac1901e4da967309406d1e51017ae1abfb97adbd38" +dependencies = [ + "bitflags 2.11.0", + "objc2", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -553,6 +736,22 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "dyn-stack" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8" +dependencies = [ + "bytemuck", + 
"dyn-stack-macros", +] + +[[package]] +name = "dyn-stack-macros" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d926b4d407d372f141f93bb444696142c29d32962ccbd3531117cf3aa0bfa9" + [[package]] name = "either" version = "1.15.0" @@ -567,17 +766,18 @@ checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "engraph" -version = "0.7.0" +version = "1.0.0" dependencies = [ "anyhow", + "candle-core", + "candle-nn", + "candle-transformers", "clap", "dirs", "ignore", "indicatif", - "ndarray", "notify", "notify-debouncer-full", - "ort", "rayon", "rmcp", "rusqlite", @@ -593,10 +793,22 @@ dependencies = [ "toml", "tracing", "tracing-subscriber", - "ureq 2.12.1", + "ureq", "zerocopy 0.7.35", ] +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -642,6 +854,17 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "fancy-regex" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72cf461f865c862bb7dc573f643dd6a2b6842f7c30b07882b56bd148cc2761b8" +dependencies = [ + "bit-set", + "regex-automata", + "regex-syntax", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -684,6 +907,28 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "float8" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719a903cc23e4a89e87962c2a80fdb45cdaad0983a89bd150bb57b4c8571a7d5" +dependencies = [ + "cudarc 0.19.4", + "half", + "num-traits", + "rand", + "rand_distr", +] + +[[package]] +name = "float8" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c2d1f04709a8ac06e8e8042875a3c466cc4832d3c1a18dbcb9dba3c6e83046bc" +dependencies = [ + "half", +] + [[package]] name = "fnv" version = "1.0.7" @@ -696,20 +941,38 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "foreign-types" -version = "0.3.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ + "foreign-types-macros", "foreign-types-shared", ] +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "foreign-types-shared" -version = "0.1.1" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" [[package]] name = "form_urlencoded" @@ -817,6 +1080,244 @@ dependencies = [ "slab", ] +[[package]] +name = "gemm" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" +dependencies = [ + "dyn-stack", + "gemm-c32 0.18.2", + "gemm-c64 0.18.2", + "gemm-common 0.18.2", + "gemm-f16 0.18.2", + "gemm-f32 0.18.2", + "gemm-f64 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name 
= "gemm" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa0673db364b12263d103b68337a68fbecc541d6f6b61ba72fe438654709eacb" +dependencies = [ + "dyn-stack", + "gemm-c32 0.19.0", + "gemm-c64 0.19.0", + "gemm-common 0.19.0", + "gemm-f16 0.19.0", + "gemm-f32 0.19.0", + "gemm-f64 0.19.0", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" +dependencies = [ + "dyn-stack", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "086936dbdcb99e37aad81d320f98f670e53c1e55a98bee70573e83f95beb128c" +dependencies = [ + "dyn-stack", + "gemm-common 0.19.0", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf" +dependencies = [ + "dyn-stack", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20c8aeeeec425959bda4d9827664029ba1501a90a0d1e6228e48bef741db3a3f" +dependencies = [ + "dyn-stack", + "gemm-common 0.19.0", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" +dependencies = [ + "bytemuck", + "dyn-stack", + 
"half", + "libm", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp 0.21.5", + "raw-cpuid", + "rayon", + "seq-macro", + "sysctl", +] + +[[package]] +name = "gemm-common" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88027625910cc9b1085aaaa1c4bc46bb3a36aad323452b33c25b5e4e7c8e2a3e" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "libm", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp 0.22.2", + "raw-cpuid", + "rayon", + "seq-macro", + "sysctl", +] + +[[package]] +name = "gemm-f16" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" +dependencies = [ + "dyn-stack", + "gemm-common 0.18.2", + "gemm-f32 0.18.2", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f16" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3df7a55202e6cd6739d82ae3399c8e0c7e1402859b30e4cb780e61525d9486e" +dependencies = [ + "dyn-stack", + "gemm-common 0.19.0", + "gemm-f32 0.19.0", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864" +dependencies = [ + "dyn-stack", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0b8c9da1fbec6e3e3ab2ce6bc259ef18eb5f6f0d3e4edf54b75f9fd41a81c" +dependencies = [ + "dyn-stack", + "gemm-common 0.19.0", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = 
"gemm-f64" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd" +dependencies = [ + "dyn-stack", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "056131e8f2a521bfab322f804ccd652520c79700d81209e9d9275bbdecaadc6a" +dependencies = [ + "dyn-stack", + "gemm-common 0.19.0", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -863,6 +1364,12 @@ dependencies = [ "wasip3", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "globset" version = "0.4.18" @@ -876,6 +1383,21 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", + "zerocopy 0.8.42", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -891,7 +1413,7 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -899,6 +1421,13 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", + "serde", + "serde_core", +] [[package]] name = 
"hashlink" @@ -916,26 +1445,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] -name = "hmac-sha256" -version = "1.1.14" +name = "hermit-abi" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec9d92d097f4749b64e8cc33d924d9f40a2d4eb91402b458014b781f5733d60f" - -[[package]] -name = "http" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" -dependencies = [ - "bytes", - "itoa", -] - -[[package]] -name = "httparse" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" [[package]] name = "iana-time-zone" @@ -969,7 +1482,7 @@ checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" dependencies = [ "displaydoc", "potential_utf", - "yoke", + "yoke 0.8.1", "zerofrom", "zerovec", ] @@ -1036,7 +1549,7 @@ dependencies = [ "displaydoc", "icu_locale_core", "writeable", - "yoke", + "yoke 0.8.1", "zerofrom", "zerotrie", "zerovec", @@ -1214,6 +1727,32 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link", +] + 
+[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + [[package]] name = "libredox" version = "0.1.14" @@ -1255,12 +1794,6 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" -[[package]] -name = "lzma-rust2" -version = "0.15.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69" - [[package]] name = "macro_rules_attribute" version = "0.2.2" @@ -1277,6 +1810,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "matchers" version = "0.2.0" @@ -1287,20 +1829,35 @@ dependencies = [ ] [[package]] -name = "matrixmultiply" -version = "0.3.10" +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "memmap2" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +checksum = "714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3" dependencies = [ - "autocfg", - "rawpointer", + "libc", + "stable_deref_trait", ] [[package]] -name = "memchr" -version = "2.8.0" +name = "metal" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" +dependencies = [ + "bitflags 2.11.0", + "block", + "core-graphics-types", + "foreign-types", + "log", + "objc", + "paste", +] [[package]] name = "minimal-lexical" @@ -1352,38 +1909,6 @@ dependencies = [ "syn", ] -[[package]] -name = "native-tls" -version = "0.2.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - -[[package]] -name = "ndarray" -version = "0.17.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" -dependencies = [ - "matrixmultiply", - "num-complex", - "num-integer", - "num-traits", - "portable-atomic", - "portable-atomic-util", - "rawpointer", -] - [[package]] name = "nom" version = "7.1.3" @@ -1444,12 +1969,37 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-complex" version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ + "bytemuck", "num-traits", ] @@ 
-1468,6 +2018,28 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1475,99 +2047,104 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] -name = "number_prefix" -version = "0.4.0" +name = "num_cpus" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] [[package]] -name = "once_cell" -version = "1.21.4" +name = "number_prefix" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] -name = "once_cell_polyfill" -version = "1.70.2" +name = "objc" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] [[package]] -name = "openssl" -version = "0.10.76" +name = "objc2" +version = "0.6.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" +checksum = "3a12a8ed07aefc768292f076dc3ac8c48f3781c8f2d5851dd3d98950e8c5a89f" dependencies = [ - "bitflags 2.11.0", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", + "objc2-encode", ] [[package]] -name = "openssl-macros" -version = "0.1.1" +name = "objc2-core-foundation" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" dependencies = [ - "proc-macro2", - "quote", - "syn", + "bitflags 2.11.0", + "dispatch2", + "objc2", ] [[package]] -name = "openssl-probe" -version = "0.2.1" +name = "objc2-encode" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" [[package]] -name = "openssl-sys" -version = "0.9.112" +name = "objc2-foundation" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" +checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" dependencies = [ - "cc", + "bitflags 2.11.0", + "block2", "libc", - "pkg-config", - "vcpkg", + "objc2", + "objc2-core-foundation", ] [[package]] -name = "option-ext" -version = "0.2.0" +name = "objc2-metal" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +checksum = "a0125f776a10d00af4152d74616409f0d4a2053a6f57fa5b7d6aa2854ac04794" +dependencies = [ + "bitflags 2.11.0", + "block2", + "dispatch2", + "objc2", + "objc2-core-foundation", + 
"objc2-foundation", +] [[package]] -name = "ort" -version = "2.0.0-rc.12" +name = "once_cell" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7de3af33d24a745ffb8fab904b13478438d1cd52868e6f17735ef6e1f8bf133" -dependencies = [ - "ndarray", - "ort-sys", - "smallvec", - "tracing", - "ureq 3.2.0", -] +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] -name = "ort-sys" -version = "2.0.0-rc.12" +name = "once_cell_polyfill" +version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7b497d21a8b6fbb4b5a544f8fadb77e801a09ae0add9e411d31c6f89e3c1e90" -dependencies = [ - "hmac-sha256", - "lzma-rust2", - "ureq 3.2.0", -] +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "paste" @@ -1581,15 +2158,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - [[package]] name = "percent-encoding" version = "2.3.2" @@ -1620,15 +2188,6 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" -[[package]] -name = "portable-atomic-util" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" -dependencies = [ - "portable-atomic", -] - [[package]] name = "potential_utf" 
version = "0.1.4" @@ -1672,6 +2231,43 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "pulp" +version = "0.21.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b86df24f0a7ddd5e4b95c94fc9ed8a98f1ca94d3b01bdce2824097e7835907" +dependencies = [ + "bytemuck", + "cfg-if", + "libm", + "num-complex", + "reborrow", + "version_check", +] + +[[package]] +name = "pulp" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e205bb30d5b916c55e584c22201771bcf2bad9aabd5d4127f38387140c38632" +dependencies = [ + "bytemuck", + "cfg-if", + "libm", + "num-complex", + "paste", + "pulp-wasm-simd-flag", + "raw-cpuid", + "reborrow", + "version_check", +] + +[[package]] +name = "pulp-wasm-simd-flag" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40e24eee682d89fb193496edf918a7f407d30175b2e785fe057e4392dfd182e0" + [[package]] name = "quote" version = "1.0.45" @@ -1723,10 +2319,23 @@ dependencies = [ ] [[package]] -name = "rawpointer" -version = "0.2.1" +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "raw-cpuid" +version = "11.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags 2.11.0", +] [[package]] name = "rayon" @@ -1759,6 +2368,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + [[package]] name = "redox_syscall" version = "0.7.3" @@ -1952,21 +2567,33 @@ 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" [[package]] -name = "same-file" -version = "1.0.6" +name = "safetensors" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" dependencies = [ - "winapi-util", + "serde", + "serde_json", ] [[package]] -name = "schannel" -version = "0.1.29" +name = "safetensors" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" dependencies = [ - "windows-sys 0.61.2", + "hashbrown 0.16.1", + "serde", + "serde_json", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", ] [[package]] @@ -1995,35 +2622,18 @@ dependencies = [ "syn", ] -[[package]] -name = "security-framework" -version = "3.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" -dependencies = [ - "bitflags 2.11.0", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "semver" version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +[[package]] +name = "seq-macro" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" + [[package]] name = "serde" version = "1.0.228" @@ -2078,6 +2688,15 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.9" @@ -2131,17 +2750,6 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -[[package]] -name = "socks" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" -dependencies = [ - "byteorder", - "libc", - "winapi", -] - [[package]] name = "spm_precompiled" version = "0.1.4" @@ -2209,6 +2817,20 @@ dependencies = [ "syn", ] +[[package]] +name = "sysctl" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" +dependencies = [ + "bitflags 2.11.0", + "byteorder", + "enum-as-inner", + "libc", + "thiserror 1.0.69", + "walkdir", +] + [[package]] name = "tempfile" version = "3.27.0" @@ -2312,7 +2934,7 @@ dependencies = [ "dary_heap", "derive_builder", "esaxx-rs", - "fancy-regex", + "fancy-regex 0.14.0", "getrandom 0.3.4", "itertools", "log", @@ -2470,12 +3092,66 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "typed-path" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e28f89b80c87b8fb0cf04ab448d5dd0dd0ade2f8891bae878de66a75a28600e" + [[package]] name = 
"typenum" version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "ug" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76b761acf8af3494640d826a8609e2265e19778fb43306c7f15379c78c9b05b0" +dependencies = [ + "gemm 0.18.2", + "half", + "libloading 0.8.9", + "memmap2", + "num", + "num-traits", + "num_cpus", + "rayon", + "safetensors 0.4.5", + "serde", + "thiserror 1.0.69", + "tracing", + "yoke 0.7.5", +] + +[[package]] +name = "ug-cuda" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f0a1fa748f26166778c33b8498255ebb7c6bffb472bcc0a72839e07ebb1d9b5" +dependencies = [ + "cudarc 0.17.8", + "half", + "serde", + "thiserror 1.0.69", + "ug", +] + +[[package]] +name = "ug-metal" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7adf545a99a086d362efc739e7cf4317c18cbeda22706000fd434d70ea3d95" +dependencies = [ + "half", + "metal", + "objc", + "serde", + "thiserror 1.0.69", + "ug", +] + [[package]] name = "unicode-ident" version = "1.0.24" @@ -2537,36 +3213,6 @@ dependencies = [ "webpki-roots 0.26.11", ] -[[package]] -name = "ureq" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc" -dependencies = [ - "base64 0.22.1", - "der", - "log", - "native-tls", - "percent-encoding", - "rustls-pki-types", - "socks", - "ureq-proto", - "utf-8", - "webpki-root-certs", -] - -[[package]] -name = "ureq-proto" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" -dependencies = [ - "base64 0.22.1", - "http", - "httparse", - "log", -] - [[package]] name = "url" version = "2.5.8" @@ -2579,12 +3225,6 @@ 
dependencies = [ "serde", ] -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -2738,15 +3378,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "webpki-root-certs" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "webpki-roots" version = "0.26.11" @@ -2765,22 +3396,6 @@ dependencies = [ "rustls-pki-types", ] -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - [[package]] name = "winapi-util" version = "0.1.11" @@ -2790,12 +3405,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-core" version = "0.62.2" @@ -3189,6 +3798,18 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + 
"stable_deref_trait", + "yoke-derive 0.7.5", + "zerofrom", +] + [[package]] name = "yoke" version = "0.8.1" @@ -3196,10 +3817,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" dependencies = [ "stable_deref_trait", - "yoke-derive", + "yoke-derive 0.8.1", "zerofrom", ] +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "yoke-derive" version = "0.8.1" @@ -3287,7 +3920,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" dependencies = [ "displaydoc", - "yoke", + "yoke 0.8.1", "zerofrom", ] @@ -3297,7 +3930,7 @@ version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" dependencies = [ - "yoke", + "yoke 0.8.1", "zerofrom", "zerovec-derive", ] @@ -3313,6 +3946,18 @@ dependencies = [ "syn", ] +[[package]] +name = "zip" +version = "7.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42e33efc22a0650c311c2ef19115ce232583abbe80850bc8b66509ebef02de0" +dependencies = [ + "crc32fast", + "indexmap", + "memchr", + "typed-path", +] + [[package]] name = "zmij" version = "1.0.21" diff --git a/Cargo.toml b/Cargo.toml index a93aaa2..93a2bc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "engraph" -version = "0.7.0" +version = "1.0.0" edition = "2024" description = "Local knowledge graph for AI agents. Hybrid search + MCP server for Obsidian vaults." 
license = "MIT" @@ -20,12 +20,10 @@ anyhow = "1" rusqlite = { version = "0.32", features = ["bundled"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -ort = { version = "2.0.0-rc.12", features = ["ndarray"] } tokenizers = { version = "0.22", default-features = false, features = ["fancy-regex"] } sha2 = "0.10" ureq = "2.12" indicatif = "0.17" -ndarray = "0.17" sqlite-vec = "0.1.8-alpha.1" zerocopy = { version = "0.7", features = ["derive"] } rayon = "1" @@ -36,6 +34,14 @@ rmcp = { version = "1.2", features = ["transport-io"] } tokio = { version = "1", features = ["macros", "rt-multi-thread"] } notify = "7.0" notify-debouncer-full = "0.4" +candle-core = "0.9" +candle-nn = "0.9" +candle-transformers = "0.9" + +[features] +default = [] +metal = ["candle-core/metal"] +cuda = ["candle-core/cuda"] [dev-dependencies] tempfile = "3" diff --git a/assets/demo-search.tape b/assets/demo-search.tape new file mode 100644 index 0000000..6a66bb6 --- /dev/null +++ b/assets/demo-search.tape @@ -0,0 +1,42 @@ +Output assets/demo-search.gif +Set Shell zsh +Set FontSize 14 +Set Width 1000 +Set Height 600 +Set Padding 20 +Set Theme "Catppuccin Mocha" +Set TypingSpeed 40ms + +Type "# Index an Obsidian vault" +Enter +Sleep 500ms + +Type "engraph index ~/vault" +Enter +Sleep 3s + +Type "" +Enter +Sleep 300ms + +Type "# Search across notes — 3-lane hybrid (semantic + keyword + graph)" +Enter +Sleep 500ms + +Type "engraph search 'how does authentication work' --explain" +Enter +Sleep 4s + +Type "" +Enter +Sleep 300ms + +Type "# Get rich context for an AI agent" +Enter +Sleep 500ms + +Type "engraph context who 'Steve Barbera'" +Enter +Sleep 3s + +Sleep 2s diff --git a/src/config.rs b/src/config.rs index 4ceaf83..676cecf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,9 +1,21 @@ use anyhow::{Context, Result}; -use serde::Deserialize; -use std::path::PathBuf; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; + +/// Model override 
configuration. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(default)] +pub struct ModelConfig { + /// Override embedding model URI (e.g., "hf:repo/file.gguf"). + pub embed: Option, + /// Override reranker model URI. + pub rerank: Option, + /// Override expansion/orchestrator model URI. + pub expand: Option, +} /// Application configuration, loaded from `~/.engraph/config.toml` with CLI overrides. -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct Config { /// Path to the Obsidian vault to index. @@ -14,6 +26,10 @@ pub struct Config { pub exclude: Vec, /// Number of files to process per embedding batch. pub batch_size: usize, + /// Whether intelligence features are enabled. None = not yet configured. + pub intelligence: Option, + /// Model override URIs. + pub models: ModelConfig, } impl Default for Config { @@ -23,6 +39,8 @@ impl Default for Config { top_n: 5, exclude: vec![".obsidian/".to_string()], batch_size: 64, + intelligence: None, + models: ModelConfig::default(), } } } @@ -68,6 +86,34 @@ impl Config { let dir = Self::data_dir()?; crate::profile::load_vault_toml(&dir) } + + /// Whether intelligence is enabled (defaults to false if not configured). + pub fn intelligence_enabled(&self) -> bool { + self.intelligence.unwrap_or(false) + } + + /// Save config to a specific path. + pub fn save_to(&self, path: &Path) -> Result<()> { + let content = toml::to_string_pretty(self).context("serializing config")?; + std::fs::write(path, content).with_context(|| format!("writing {}", path.display()))?; + Ok(()) + } + + /// Load config from a specific path. 
+ pub fn load_from(path: &Path) -> Result { + let contents = + std::fs::read_to_string(path).with_context(|| format!("reading {}", path.display()))?; + let config: Config = + toml::from_str(&contents).with_context(|| format!("parsing {}", path.display()))?; + Ok(config) + } + + /// Save to the default config path (`~/.engraph/config.toml`). + pub fn save(&self) -> Result<()> { + let path = Self::data_dir()?.join("config.toml"); + std::fs::create_dir_all(path.parent().unwrap())?; + self.save_to(&path) + } } #[cfg(test)] @@ -138,4 +184,54 @@ batch_size = 128 let cfg = Config::load().unwrap(); assert_eq!(cfg.batch_size, 64); } + + #[test] + fn parse_intelligence_config() { + let toml_str = r#" +intelligence = true + +[models] +embed = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf" +rerank = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf" +"#; + let cfg: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(cfg.intelligence, Some(true)); + assert!(cfg.models.embed.is_some()); + assert!(cfg.models.rerank.is_some()); + assert!(cfg.models.expand.is_none()); + } + + #[test] + fn intelligence_defaults_to_none() { + let cfg = Config::default(); + assert!(cfg.intelligence.is_none()); + assert!(cfg.models.embed.is_none()); + } + + #[test] + fn intelligence_false_disables_features() { + let toml_str = r#"intelligence = false"#; + let cfg: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(cfg.intelligence, Some(false)); + assert!(!cfg.intelligence_enabled()); + } + + #[test] + fn test_config_roundtrip_with_intelligence() { + let dir = tempfile::tempdir().unwrap(); + let config_path = dir.path().join("config.toml"); + + let mut cfg = Config::default(); + cfg.intelligence = Some(true); + cfg.models.embed = Some("hf:custom/model/embed.gguf".into()); + + cfg.save_to(&config_path).unwrap(); + + let loaded = Config::load_from(&config_path).unwrap(); + assert_eq!(loaded.intelligence, Some(true)); + assert_eq!( + 
loaded.models.embed, + Some("hf:custom/model/embed.gguf".into()) + ); + } } diff --git a/src/context.rs b/src/context.rs index 1dfcda8..ab17024 100644 --- a/src/context.rs +++ b/src/context.rs @@ -633,7 +633,7 @@ pub fn context_topic_with_search( params: &ContextParams, topic: &str, max_chars: usize, - embedder: &mut crate::embedder::Embedder, + embedder: &mut impl crate::llm::EmbedModel, ) -> Result { let search_output = crate::search::search_internal(topic, 5, params.store, embedder)?; context_topic_from_results(params, topic, &search_output.results, max_chars) diff --git a/src/embedder.rs b/src/embedder.rs deleted file mode 100644 index 5770904..0000000 --- a/src/embedder.rs +++ /dev/null @@ -1,339 +0,0 @@ -use std::io::Read; -use std::path::Path; - -use anyhow::{Context, Result, bail}; -use indicatif::{ProgressBar, ProgressStyle}; -use ndarray::Array2; -use ort::session::Session; -use ort::value::Tensor; -use sha2::{Digest, Sha256}; -use tokenizers::Tokenizer; -use tracing::info; - -const MODEL_URL: &str = - "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx"; -const TOKENIZER_URL: &str = - "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer.json"; -/// SHA-256 of the ONNX model file. Set to empty string to skip verification -/// until we can compute the real hash from a download. -const MODEL_SHA256: &str = "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452"; -pub const EMBEDDING_DIM: usize = 384; - -pub struct Embedder { - session: Session, - tokenizer: Tokenizer, -} - -impl Embedder { - /// Create a new Embedder, downloading the model and tokenizer into - /// `models_dir` if they are not already present. 
- pub fn new(models_dir: &Path) -> Result { - std::fs::create_dir_all(models_dir) - .with_context(|| format!("creating models dir {}", models_dir.display()))?; - - let model_path = models_dir.join("model.onnx"); - let tokenizer_path = models_dir.join("tokenizer.json"); - - // Download model if missing. - if !model_path.exists() { - download_file(MODEL_URL, &model_path, Some(MODEL_SHA256))?; - } - - // Download tokenizer if missing. - if !tokenizer_path.exists() { - download_file(TOKENIZER_URL, &tokenizer_path, None)?; - } - - // Verify model hash. - verify_sha256(&model_path, MODEL_SHA256)?; - - let session = Session::builder() - .with_context(|| "creating ONNX session builder")? - .commit_from_file(&model_path) - .with_context(|| format!("loading ONNX model from {}", model_path.display()))?; - - let tokenizer = Tokenizer::from_file(&tokenizer_path) - .map_err(|e| anyhow::anyhow!("loading tokenizer: {e}"))?; - - info!("embedder loaded from {}", models_dir.display()); - - Ok(Self { session, tokenizer }) - } - - /// Embed a batch of texts, returning one L2-normalised vector per text. - pub fn embed_batch(&mut self, texts: &[&str]) -> Result>> { - if texts.is_empty() { - return Ok(Vec::new()); - } - - let encodings = self - .tokenizer - .encode_batch(texts.to_vec(), true) - .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?; - - let batch_size = encodings.len(); - let max_len = encodings - .iter() - .map(|e| e.get_ids().len()) - .max() - .unwrap_or(0); - - // Build padded input arrays. 
- let mut input_ids_vec = vec![0i64; batch_size * max_len]; - let mut attention_mask_vec = vec![0i64; batch_size * max_len]; - let mut token_type_ids_vec = vec![0i64; batch_size * max_len]; - - for (i, enc) in encodings.iter().enumerate() { - let ids = enc.get_ids(); - let mask = enc.get_attention_mask(); - let type_ids = enc.get_type_ids(); - for (j, (&id, &m)) in ids.iter().zip(mask.iter()).enumerate() { - input_ids_vec[i * max_len + j] = id as i64; - attention_mask_vec[i * max_len + j] = m as i64; - if j < type_ids.len() { - token_type_ids_vec[i * max_len + j] = type_ids[j] as i64; - } - } - } - - let input_ids = Array2::from_shape_vec((batch_size, max_len), input_ids_vec)?; - let attention_mask = Array2::from_shape_vec((batch_size, max_len), attention_mask_vec)?; - let token_type_ids = Array2::from_shape_vec((batch_size, max_len), token_type_ids_vec)?; - - let input_ids_tensor = Tensor::from_array(input_ids)?; - let attention_mask_tensor = Tensor::from_array(attention_mask.clone())?; - let token_type_ids_tensor = Tensor::from_array(token_type_ids)?; - - let outputs = self.session.run(ort::inputs![ - "input_ids" => input_ids_tensor, - "attention_mask" => attention_mask_tensor, - "token_type_ids" => token_type_ids_tensor, - ])?; - - // The model outputs token_embeddings of shape (batch, seq_len, 384). - // We need mean pooling with the attention mask. - let token_embeddings = outputs[0].try_extract_array::()?; - let token_embeddings = token_embeddings.to_owned(); // into owned array - - // Mean pooling: sum embeddings where attention_mask == 1, divide by count. 
- let mut results = Vec::with_capacity(batch_size); - for i in 0..batch_size { - let mut sum = vec![0f32; EMBEDDING_DIM]; - let mut count = 0f32; - for j in 0..max_len { - if attention_mask[[i, j]] == 1 { - count += 1.0; - for k in 0..EMBEDDING_DIM { - sum[k] += token_embeddings[[i, j, k]]; - } - } - } - if count > 0.0 { - for v in &mut sum { - *v /= count; - } - } - let normalized = normalize_vector(&sum); - results.push(normalized); - } - - Ok(results) - } - - /// Embed a single text. - pub fn embed_one(&mut self, text: &str) -> Result> { - let mut batch = self.embed_batch(&[text])?; - batch - .pop() - .ok_or_else(|| anyhow::anyhow!("empty result from embed_batch")) - } - - /// Return the number of tokens in a text string. - pub fn token_count(&self, text: &str) -> usize { - self.tokenizer - .encode(text, false) - .map(|e| e.get_ids().len()) - .unwrap_or(0) - } -} - -impl crate::model::ModelBackend for Embedder { - fn embed_batch(&mut self, texts: &[&str]) -> Result>> { - self.embed_batch(texts) - } - - fn embed_one(&mut self, text: &str) -> Result> { - self.embed_one(text) - } - - fn token_count(&self, text: &str) -> usize { - self.token_count(text) - } - - fn dim(&self) -> usize { - EMBEDDING_DIM - } - - fn name(&self) -> &str { - "onnx:all-MiniLM-L6-v2" - } -} - -/// L2-normalize a vector. Returns a zero vector if input norm is zero. -fn normalize_vector(v: &[f32]) -> Vec { - let norm: f32 = v.iter().map(|x| x * x).sum::().sqrt(); - if norm < f32::EPSILON { - return vec![0.0; v.len()]; - } - v.iter().map(|x| x / norm).collect() -} - -/// Compute SHA-256 hex digest of a file. -fn sha256_file(path: &Path) -> Result { - let mut file = std::fs::File::open(path)?; - let mut hasher = Sha256::new(); - let mut buffer = [0u8; 8192]; - loop { - let n = file.read(&mut buffer)?; - if n == 0 { - break; - } - hasher.update(&buffer[..n]); - } - Ok(format!("{:x}", hasher.finalize())) -} - -/// Verify that a file matches an expected SHA-256 hash. 
-fn verify_sha256(path: &Path, expected: &str) -> Result<()> { - let actual = sha256_file(path)?; - if actual != expected { - bail!( - "SHA-256 mismatch for {}: expected {expected}, got {actual}", - path.display() - ); - } - Ok(()) -} - -/// Compute SHA-256 hex digest of a byte slice. -#[cfg(test)] -fn sha256_bytes(data: &[u8]) -> String { - let mut hasher = Sha256::new(); - hasher.update(data); - format!("{:x}", hasher.finalize()) -} - -/// Download a file from `url` to `dest`, optionally verifying SHA-256. -/// Retries once on failure. -fn download_file(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> { - fn try_download(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> { - info!("downloading {} -> {}", url, dest.display()); - - let resp = ureq::get(url) - .call() - .with_context(|| format!("HTTP GET {url}"))?; - - let total_size: u64 = resp - .header("Content-Length") - .and_then(|s| s.parse().ok()) - .unwrap_or(0); - - let pb = ProgressBar::new(total_size); - pb.set_style( - ProgressStyle::with_template( - "{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})", - ) - .unwrap() - .progress_chars("=>-"), - ); - pb.set_message( - dest.file_name() - .unwrap_or_default() - .to_string_lossy() - .into_owned(), - ); - - let mut reader = resp.into_reader(); - let mut file = std::fs::File::create(dest)?; - let mut buffer = [0u8; 8192]; - loop { - let n = reader.read(&mut buffer)?; - if n == 0 { - break; - } - std::io::Write::write_all(&mut file, &buffer[..n])?; - pb.inc(n as u64); - } - pb.finish_with_message("done"); - - // Verify hash if provided. - if let Some(expected) = expected_sha256 { - let actual = sha256_file(dest)?; - if actual != expected { - let _ = std::fs::remove_file(dest); - bail!( - "SHA-256 mismatch for {}: expected {expected}, got {actual}", - dest.display() - ); - } - } - - Ok(()) - } - - // Try once, retry on failure. 
- match try_download(url, dest, expected_sha256) { - Ok(()) => Ok(()), - Err(first_err) => { - tracing::warn!("download failed, retrying: {first_err:#}"); - let _ = std::fs::remove_file(dest); - try_download(url, dest, expected_sha256) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sha256_verification() { - let data = b"hello world"; - let hash = sha256_bytes(data); - assert_eq!( - hash, - "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9" - ); - } - - #[test] - fn test_normalize_vector() { - let v = vec![3.0, 4.0]; - let n = normalize_vector(&v); - assert_eq!(n.len(), 2); - // Should be [0.6, 0.8]. - assert!((n[0] - 0.6).abs() < 1e-6); - assert!((n[1] - 0.8).abs() < 1e-6); - // L2 norm should be ~1.0. - let norm: f32 = n.iter().map(|x| x * x).sum::().sqrt(); - assert!((norm - 1.0).abs() < 1e-6); - } - - #[test] - fn test_normalize_zero_vector() { - let v = vec![0.0, 0.0, 0.0]; - let n = normalize_vector(&v); - assert!(n.iter().all(|x| *x == 0.0)); - } - - #[test] - #[ignore] - fn test_embed_smoke() { - let dir = tempfile::tempdir().unwrap(); - let mut embedder = Embedder::new(dir.path()).unwrap(); - let vec = embedder.embed_one("hello world").unwrap(); - assert_eq!(vec.len(), EMBEDDING_DIM); - let norm: f32 = vec.iter().map(|x| x * x).sum::(); - assert!((norm - 1.0).abs() < 0.01); - } -} diff --git a/src/indexer.rs b/src/indexer.rs index ed35ec9..d382945 100644 --- a/src/indexer.rs +++ b/src/indexer.rs @@ -10,8 +10,8 @@ use tracing::info; use crate::chunker::{chunk_markdown, split_oversized_chunks}; use crate::config::Config; use crate::docid::generate_docid; -use crate::embedder::Embedder; use crate::graph::extract_wikilink_targets; +use crate::llm::EmbedModel; use crate::store::{FileRecord, Store}; /// Summary of an indexing run. 
@@ -276,7 +276,7 @@ pub fn index_file( content: &str, content_hash: &str, store: &Store, - embedder: &mut Embedder, + embedder: &mut impl EmbedModel, vault_path: &Path, config: &Config, ) -> Result { @@ -440,7 +440,19 @@ pub fn run_index(vault_path: &Path, config: &Config, rebuild: bool) -> Result Result { run_index_inner(vault_path, config, store, embedder, rebuild) @@ -464,7 +476,7 @@ fn run_index_inner( vault_path: &Path, config: &Config, store: &Store, - embedder: &mut Embedder, + embedder: &mut impl EmbedModel, rebuild: bool, ) -> Result { let start = Instant::now(); @@ -622,7 +634,7 @@ fn run_index_inner( if vectors.is_empty() { continue; } - let dim = 384; + let dim = embedder.dim(); let mut centroid = vec![0.0f32; dim]; for v in vectors { for (i, val) in v.iter().enumerate() { diff --git a/src/lib.rs b/src/lib.rs index bb22f42..5ab1ab8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,13 +2,12 @@ pub mod chunker; pub mod config; pub mod context; pub mod docid; -pub mod embedder; pub mod fts; pub mod fusion; pub mod graph; pub mod indexer; pub mod links; -pub mod model; +pub mod llm; pub mod placement; pub mod profile; pub mod search; diff --git a/src/llm.rs b/src/llm.rs new file mode 100644 index 0000000..f4c818f --- /dev/null +++ b/src/llm.rs @@ -0,0 +1,1838 @@ +use std::io::Read; +use std::path::{Path, PathBuf}; + +use anyhow::{Result, bail}; +use indicatif::{ProgressBar, ProgressStyle}; +use sha2::{Digest, Sha256}; + +use anyhow::Context as _; +use candle_core::{D, DType, Device, IndexOp, Tensor}; +use candle_nn::{Embedding, Module}; + +// ── Device selection ───────────────────────────────────────────────────────── + +/// Select best available device: Metal on macOS (with `metal` feature), CPU elsewhere. 
/// Prompt template family applied to text before it is handed to an
/// embedding model.
#[derive(Debug, Clone)]
pub enum PromptFormat {
    /// Google embeddinggemma family: `search_query:` / `search_document:` prefixes.
    EmbeddingGemma,
    /// Qwen embedding family: `Instruct:` / `Query:` layout.
    QwenEmbedding,
    /// Text is passed through without any template.
    Raw,
}

impl PromptFormat {
    /// Choose a prompt format from a GGUF filename.
    ///
    /// Matching is done on the lower-cased name: anything containing
    /// `embeddinggemma` maps to [`PromptFormat::EmbeddingGemma`], anything
    /// containing both `qwen` and `embed` maps to
    /// [`PromptFormat::QwenEmbedding`], and everything else is
    /// [`PromptFormat::Raw`].
    pub fn detect(filename: &str) -> Self {
        let name = filename.to_lowercase();
        let mentions = |needle: &str| name.contains(needle);

        if mentions("embeddinggemma") {
            return Self::EmbeddingGemma;
        }
        if mentions("qwen") && mentions("embed") {
            return Self::QwenEmbedding;
        }
        Self::Raw
    }

    /// Wrap a search query in this family's query template.
    pub fn format_query(&self, query: &str) -> String {
        match self {
            Self::Raw => query.to_owned(),
            Self::EmbeddingGemma => format!("search_query: {query}"),
            Self::QwenEmbedding => {
                format!("Instruct: Retrieve relevant passages\nQuery: {query}")
            }
        }
    }

    /// Wrap a document (title plus body text) in this family's document template.
    ///
    /// Only the embeddinggemma family has a dedicated document prefix; the
    /// other two join title and text with a newline.
    pub fn format_document(&self, title: &str, text: &str) -> String {
        match self {
            Self::Raw | Self::QwenEmbedding => format!("{title}\n{text}"),
            Self::EmbeddingGemma => format!("search_document: {title} {text}"),
        }
    }
}

/// What kind of answer an incoming search query is looking for.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryIntent {
    /// A precise fact or term match.
    Exact,
    /// Related ideas and concepts.
    Conceptual,
    /// Connections between entities.
    Relationship,
    /// Browsing without a clear target.
    Exploratory,
}

/// Search plan produced by an orchestrator model for one query.
#[derive(Debug, Clone)]
pub struct OrchestrationResult {
    /// Intent the query was classified as.
    pub intent: QueryIntent,
    /// Query strings to actually run (the original plus any expansions).
    pub expansions: Vec<String>,
}

/// Per-lane multipliers used by the RRF fusion step.
#[derive(Debug, Clone)]
pub struct LaneWeights {
    pub semantic: f64,
    pub fts: f64,
    pub graph: f64,
    pub rerank: f64,
}

impl LaneWeights {
    /// Recommended lane weights for a classified query intent.
    pub fn from_intent(intent: &QueryIntent) -> Self {
        // Tuple order: (semantic, fts, graph, rerank).
        let (semantic, fts, graph, rerank) = match intent {
            QueryIntent::Exact => (0.6, 1.5, 0.6, 0.8),
            QueryIntent::Conceptual => (1.2, 0.8, 1.0, 1.2),
            QueryIntent::Relationship => (0.8, 0.8, 1.5, 1.0),
            QueryIntent::Exploratory => (1.0, 1.0, 0.8, 1.0),
        };
        Self {
            semantic,
            fts,
            graph,
            rerank,
        }
    }

    /// Weights for legacy mode, when no intelligence layer is available.
    /// The rerank lane is switched off entirely.
    pub fn default_no_intelligence() -> Self {
        Self {
            rerank: 0.0,
            graph: 0.8,
            fts: 1.0,
            semantic: 1.0,
        }
    }
}
+ fn dim(&self) -> usize; +} + +// Blanket impl: `Box` itself implements `EmbedModel`. +// This lets `Arc>>` callers pass +// `&mut *guard` (which is `&mut Box`) to any +// function taking `&mut impl EmbedModel`. +impl EmbedModel for Box { + fn embed_batch(&mut self, texts: &[&str]) -> Result>> { + (**self).embed_batch(texts) + } + + fn embed_one(&mut self, text: &str) -> Result> { + (**self).embed_one(text) + } + + fn token_count(&self, text: &str) -> usize { + (**self).token_count(text) + } + + fn dim(&self) -> usize { + (**self).dim() + } +} + +/// Cross-encoder reranker — scores a (query, document) pair. +pub trait RerankModel: Send { + /// Return a relevance score in [0.0, 1.0]. + fn rerank_score(&mut self, query: &str, document: &str) -> Result; +} + +/// Orchestrator — interprets a query and produces an enriched search plan. +pub trait OrchestratorModel: Send { + fn orchestrate(&mut self, query: &str) -> Result; +} + +// ── MockLlm ────────────────────────────────────────────────────────────────── + +/// Deterministic in-process implementation of all three traits. +/// Suitable for unit tests and CI runs — no model files required. +pub struct MockLlm { + dim: usize, +} + +impl MockLlm { + pub fn new(dim: usize) -> Self { + Self { dim } + } + + /// Produce a deterministic L2-normalised vector from `text` via SHA-256. + pub fn hash_to_vector(&self, text: &str) -> Vec { + let mut raw: Vec = Vec::with_capacity(self.dim); + // Seed the first hash from the text itself, then chain hashes to fill + // vectors wider than 32 bytes (8 f32s per 256-bit hash). + let mut seed = text.to_owned(); + while raw.len() < self.dim { + let mut hasher = Sha256::new(); + hasher.update(seed.as_bytes()); + let hash = hasher.finalize(); + // Each hash gives 32 bytes → 8 f32 values. 
+ for chunk in hash.chunks(4) { + if raw.len() >= self.dim { + break; + } + let bytes: [u8; 4] = chunk.try_into().expect("chunk is always 4 bytes"); + // Map u32 → [-1.0, 1.0] for a reasonable spread before normalisation. + let u = u32::from_le_bytes(bytes); + let f = (u as f32 / u32::MAX as f32) * 2.0 - 1.0; + raw.push(f); + } + // Next round: hash the previous hash digest (as hex) so values differ. + seed = format!("{:x}", { + let mut h2 = Sha256::new(); + h2.update(hash); + h2.finalize() + }); + } + + // L2-normalise so the mock behaves like a real embedding model. + let norm: f32 = raw.iter().map(|x| x * x).sum::().sqrt(); + if norm > 0.0 { + raw.iter_mut().for_each(|x| *x /= norm); + } + raw + } +} + +impl EmbedModel for MockLlm { + fn embed_batch(&mut self, texts: &[&str]) -> Result>> { + Ok(texts.iter().map(|t| self.hash_to_vector(t)).collect()) + } + + fn embed_one(&mut self, text: &str) -> Result> { + Ok(self.hash_to_vector(text)) + } + + fn token_count(&self, text: &str) -> usize { + text.len() / 4 + 1 + } + + fn dim(&self) -> usize { + self.dim + } +} + +impl RerankModel for MockLlm { + fn rerank_score(&mut self, query: &str, document: &str) -> Result { + // Deterministic score: Jaccard overlap of character 4-grams, clamped to [0,1]. 
+ let ngrams = |s: &str| -> std::collections::HashSet { + s.chars() + .collect::>() + .windows(4) + .map(|w| w.iter().collect()) + .collect() + }; + + let q_set = ngrams(&query.to_lowercase()); + let d_set = ngrams(&document.to_lowercase()); + + if q_set.is_empty() && d_set.is_empty() { + return Ok(0.5); + } + + let intersection = q_set.intersection(&d_set).count(); + let union = q_set.union(&d_set).count(); + + let score = intersection as f32 / union as f32; + Ok(score.clamp(0.0, 1.0)) + } +} + +impl OrchestratorModel for MockLlm { + fn orchestrate(&mut self, query: &str) -> Result { + Ok(OrchestrationResult { + intent: QueryIntent::Exploratory, + expansions: vec![query.to_owned()], + }) + } +} + +// ── HuggingFace model download infrastructure ───────────────────────────────── + +/// Parsed HuggingFace model URI: "hf:org/repo/filename.gguf" +#[derive(Debug, Clone)] +pub struct HfModelUri { + pub repo: String, + pub filename: String, +} + +impl HfModelUri { + pub fn parse(uri: &str) -> Result { + let rest = uri + .strip_prefix("hf:") + .ok_or_else(|| anyhow::anyhow!("model URI must start with 'hf:', got: {uri}"))?; + let last_slash = rest.rfind('/').ok_or_else(|| { + anyhow::anyhow!("model URI must be 'hf:org/repo/file.gguf', got: {uri}") + })?; + let repo = &rest[..last_slash]; + let filename = &rest[last_slash + 1..]; + if repo.is_empty() || filename.is_empty() || !repo.contains('/') { + bail!("invalid model URI format: {uri}"); + } + Ok(Self { + repo: repo.to_string(), + filename: filename.to_string(), + }) + } + + pub fn download_url(&self) -> String { + format!( + "https://huggingface.co/{}/resolve/main/{}", + self.repo, self.filename + ) + } + + /// Local cache path: models_dir/repo--filename (slashes replaced with --) + pub fn cache_path(&self, models_dir: &Path) -> PathBuf { + let safe_name = format!("{}--{}", self.repo.replace('/', "--"), self.filename); + models_dir.join(safe_name) + } +} + +/// Download a file with progress bar and optional SHA256 
verification. Retries once on failure. +pub fn download_model(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> { + fn try_download(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> { + tracing::info!("downloading {} -> {}", url, dest.display()); + + let resp = ureq::get(url) + .call() + .map_err(|e| anyhow::anyhow!("HTTP GET {url}: {e}"))?; + + let total_size: u64 = resp + .header("Content-Length") + .and_then(|v| v.parse().ok()) + .unwrap_or(0); + + let pb = ProgressBar::new(total_size); + pb.set_style( + ProgressStyle::with_template( + "{msg} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})", + ) + .unwrap_or_else(|_| ProgressStyle::default_bar()) + .progress_chars("=>-"), + ); + pb.set_message(format!( + "downloading {}", + dest.file_name().and_then(|n| n.to_str()).unwrap_or("model") + )); + + // Write to a temp file alongside dest, then rename for crash safety. + let tmp_path = dest.with_extension("tmp"); + { + let mut file = std::fs::File::create(&tmp_path) + .map_err(|e| anyhow::anyhow!("creating {}: {e}", tmp_path.display()))?; + let mut reader = resp.into_reader(); + let mut buffer = [0u8; 8192]; + loop { + let n = reader.read(&mut buffer)?; + if n == 0 { + break; + } + std::io::Write::write_all(&mut file, &buffer[..n])?; + pb.inc(n as u64); + } + } + pb.finish_with_message("done"); + + // Verify hash if provided. + if let Some(expected) = expected_sha256 { + let actual = sha256_file(&tmp_path)?; + if actual != expected { + let _ = std::fs::remove_file(&tmp_path); + bail!( + "SHA-256 mismatch for {}: expected {expected}, got {actual}", + dest.display() + ); + } + } + + std::fs::rename(&tmp_path, dest).map_err(|e| anyhow::anyhow!("renaming temp file: {e}"))?; + + Ok(()) + } + + // Try once, retry on failure. 
+ match try_download(url, dest, expected_sha256) { + Ok(()) => Ok(()), + Err(first_err) => { + tracing::warn!("download failed, retrying: {first_err:#}"); + let _ = std::fs::remove_file(dest); + try_download(url, dest, expected_sha256) + } + } +} + +/// Compute SHA-256 hex digest of a file. +fn sha256_file(path: &Path) -> Result { + let mut file = std::fs::File::open(path)?; + let mut hasher = Sha256::new(); + let mut buffer = [0u8; 8192]; + loop { + let n = file.read(&mut buffer)?; + if n == 0 { + break; + } + hasher.update(&buffer[..n]); + } + Ok(format!("{:x}", hasher.finalize())) +} + +/// Ensure a model is present locally, downloading if not cached. +pub fn ensure_model(uri: &HfModelUri, models_dir: &Path) -> Result { + let path = uri.cache_path(models_dir); + if !path.exists() { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + download_model(&uri.download_url(), &path, None)?; + } + Ok(path) +} + +/// Default model URIs for the intelligence layer. +pub struct ModelDefaults { + pub embed_uri: String, + pub embed_dim: usize, + pub rerank_uri: String, + pub expand_uri: String, +} + +impl Default for ModelDefaults { + fn default() -> Self { + Self { + embed_uri: "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf".into(), + embed_dim: 256, + rerank_uri: "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf" + .into(), + expand_uri: "hf:Qwen/Qwen3-0.6B-GGUF/qwen3-0.6b-q8_0.gguf".into(), + } + } +} + +// ── CandleEmbed — GGUF embedding model via candle ────────────────────────── + +/// Quantized matrix multiplication wrapper (mirrors candle-transformers pattern). 
+#[derive(Debug, Clone)] +struct CandleQMatMul { + inner: candle_core::quantized::QMatMul, +} + +impl CandleQMatMul { + fn from_qtensor(qtensor: candle_core::quantized::QTensor) -> candle_core::Result { + let inner = candle_core::quantized::QMatMul::from_qtensor(qtensor)?; + Ok(Self { inner }) + } + + fn forward(&self, xs: &Tensor) -> candle_core::Result { + self.inner.forward(xs) + } +} + +/// Single transformer layer for the embedding model. +#[derive(Debug, Clone)] +struct EmbedLayer { + attention_wq: CandleQMatMul, + attention_wk: CandleQMatMul, + attention_wv: CandleQMatMul, + attention_wo: CandleQMatMul, + attention_q_norm: candle_transformers::quantized_nn::RmsNorm, + attention_k_norm: candle_transformers::quantized_nn::RmsNorm, + attention_norm: candle_transformers::quantized_nn::RmsNorm, + post_attention_norm: candle_transformers::quantized_nn::RmsNorm, + ffn_norm: candle_transformers::quantized_nn::RmsNorm, + post_ffn_norm: candle_transformers::quantized_nn::RmsNorm, + ffn_gate: CandleQMatMul, + ffn_up: CandleQMatMul, + ffn_down: CandleQMatMul, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + q_dim: usize, + rotary_sin: Tensor, + rotary_cos: Tensor, +} + +impl EmbedLayer { + /// Bidirectional forward pass — no causal mask, no KV cache. + fn forward(&self, x: &Tensor) -> candle_core::Result { + let (b_sz, seq_len, _) = x.dims3()?; + + // --- Attention block --- + let residual = x; + let x = self.attention_norm.forward(x)?; + + let q = self.attention_wq.forward(&x)?; + let k = self.attention_wk.forward(&x)?; + let v = self.attention_wv.forward(&x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? 
+ .transpose(1, 2)?; + + let q = self.attention_q_norm.forward(&q.contiguous()?)?; + let k = self.attention_k_norm.forward(&k.contiguous()?)?; + + // Apply rotary embeddings (truncated to seq_len). + let q = Self::apply_rotary(&q, &self.rotary_cos, &self.rotary_sin, seq_len)?; + let k = Self::apply_rotary(&k, &self.rotary_cos, &self.rotary_sin, seq_len)?; + + // Repeat KV heads for GQA. + let n_rep = self.n_head / self.n_kv_head; + let k = candle_transformers::utils::repeat_kv(k, n_rep)?; + let v = candle_transformers::utils::repeat_kv(v, n_rep)?; + + // Scaled dot-product attention — BIDIRECTIONAL (no mask). + let scale = 1.0 / (self.head_dim as f64).sqrt(); + let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&v)?; + + let attn_output = attn_output + .transpose(1, 2)? + .reshape((b_sz, seq_len, self.q_dim))?; + let attn_output = self.attention_wo.forward(&attn_output)?; + let x = self.post_attention_norm.forward(&attn_output)?; + let x = (x + residual)?; + + // --- FFN block --- + let residual = &x; + let h = self.ffn_norm.forward(&x)?; + let gate = self.ffn_gate.forward(&h)?; + let up = self.ffn_up.forward(&h)?; + let h = (candle_nn::ops::silu(&gate)? * up)?; + let h = self.ffn_down.forward(&h)?; + let h = self.post_ffn_norm.forward(&h)?; + h + residual + } + + /// Apply rotary embeddings to a [batch, heads, seq, dim] tensor. + fn apply_rotary( + x: &Tensor, + cos: &Tensor, + sin: &Tensor, + seq_len: usize, + ) -> candle_core::Result { + let cos = cos.i(..seq_len)?.unsqueeze(0)?.unsqueeze(0)?; + let sin = sin.i(..seq_len)?.unsqueeze(0)?.unsqueeze(0)?; + let dim = x.dim(D::Minus1)?; + let half = dim / 2; + let x1 = x.narrow(D::Minus1, 0, half)?; + let x2 = x.narrow(D::Minus1, half, half)?; + let rotated = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + let out = (x.broadcast_mul(&cos)? 
+ rotated.broadcast_mul(&sin)?)?; + Ok(out) + } +} + +/// GGUF embedding model loaded via candle. +/// +/// Loads a quantized Gemma-family embedding model (e.g., embeddinggemma-300M) +/// from a GGUF file and produces dense float vectors via bidirectional attention +/// + mean pooling + L2 normalization. +pub struct CandleEmbed { + layers: Vec, + tok_embeddings: Embedding, + norm: candle_transformers::quantized_nn::RmsNorm, + embedding_length: usize, + tokenizer: tokenizers::Tokenizer, + device: Device, + dim: usize, + prompt_format: PromptFormat, +} + +impl std::fmt::Debug for CandleEmbed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CandleEmbed") + .field("dim", &self.dim) + .field("embedding_length", &self.embedding_length) + .field("num_layers", &self.layers.len()) + .field("prompt_format", &self.prompt_format) + .finish() + } +} + +impl CandleEmbed { + /// Load a GGUF embedding model from `models_dir`. + /// + /// Steps: + /// 1. Resolve model URI (from config override or `ModelDefaults`) + /// 2. `ensure_model()` to download if needed + /// 3. Load tokenizer (try same repo's tokenizer.json, then repo without -GGUF suffix) + /// 4. Load GGUF and build layer structs for bidirectional embedding + /// 5. Detect prompt format from filename + pub fn new(models_dir: &Path, config: &crate::config::Config) -> Result { + let defaults = ModelDefaults::default(); + let uri_str = config + .models + .embed + .as_deref() + .unwrap_or(&defaults.embed_uri); + let uri = HfModelUri::parse(uri_str)?; + let model_path = ensure_model(&uri, models_dir)?; + + // Load tokenizer: try from the same HF repo, then from the non-GGUF variant. + let tokenizer = Self::load_tokenizer(&uri, models_dir)?; + + // Detect prompt format from filename. + let prompt_format = PromptFormat::detect(&uri.filename); + + // Target output dimensionality. + let dim = defaults.embed_dim; + + // Load GGUF and build model. 
+ let device = select_device()?; + let (layers, tok_embeddings, norm, embedding_length) = + Self::load_gguf(&model_path, &device)?; + + tracing::info!( + "loaded CandleEmbed: {} layers, embedding_length={}, target_dim={}, device={:?}", + layers.len(), + embedding_length, + dim, + device + ); + + Ok(Self { + layers, + tok_embeddings, + norm, + embedding_length, + tokenizer, + device, + dim, + prompt_format, + }) + } + + /// Try to load tokenizer.json from the same HF repo, or from repo without "-GGUF" suffix. + fn load_tokenizer(uri: &HfModelUri, models_dir: &Path) -> Result { + // Try 1: tokenizer.json from the same repo. + let tok_uri = HfModelUri { + repo: uri.repo.clone(), + filename: "tokenizer.json".to_string(), + }; + let tok_path = tok_uri.cache_path(models_dir); + if tok_path.exists() { + return tokenizers::Tokenizer::from_file(&tok_path).map_err(|e| { + anyhow::anyhow!("loading tokenizer from {}: {e}", tok_path.display()) + }); + } + + // Try 2: download from the same repo. + if let Ok(p) = ensure_model(&tok_uri, models_dir) { + return tokenizers::Tokenizer::from_file(&p) + .map_err(|e| anyhow::anyhow!("loading tokenizer from {}: {e}", p.display())); + } + + // Try 3: non-GGUF variant of the repo (e.g., "org/model-GGUF" -> "org/model"). + let base_repo = uri.repo.trim_end_matches("-GGUF").to_string(); + if base_repo != uri.repo { + let base_tok_uri = HfModelUri { + repo: base_repo, + filename: "tokenizer.json".to_string(), + }; + if let Ok(p) = ensure_model(&base_tok_uri, models_dir) { + return tokenizers::Tokenizer::from_file(&p) + .map_err(|e| anyhow::anyhow!("loading tokenizer from {}: {e}", p.display())); + } + } + + bail!( + "could not find or download tokenizer for model repo '{}'", + uri.repo + ); + } + + /// Load GGUF file and construct layer structs for bidirectional embedding. 
+ fn load_gguf( + path: &Path, + device: &Device, + ) -> Result<( + Vec, + Embedding, + candle_transformers::quantized_nn::RmsNorm, + usize, + )> { + use candle_core::quantized::gguf_file; + + let mut file = std::fs::File::open(path) + .map_err(|e| anyhow::anyhow!("opening GGUF {}: {e}", path.display()))?; + let ct = gguf_file::Content::read(&mut file) + .map_err(|e| anyhow::anyhow!("reading GGUF {}: {e}", path.display()))?; + + // Detect architecture prefix (same probe as candle-transformers quantized_gemma3). + let prefix = ["gemma3", "gemma2", "gemma", "gemma-embedding"] + .iter() + .find(|p| { + ct.metadata + .contains_key(&format!("{}.attention.head_count", p)) + }) + .copied() + .unwrap_or("gemma3"); + + let md_get = |s: &str| -> Result<&gguf_file::Value> { + let key = format!("{prefix}.{s}"); + ct.metadata + .get(&key) + .ok_or_else(|| anyhow::anyhow!("cannot find {key} in GGUF metadata")) + }; + + let head_count = md_get("attention.head_count")? + .to_u32() + .map_err(|e| anyhow::anyhow!("{e}"))? as usize; + let head_count_kv = md_get("attention.head_count_kv")? + .to_u32() + .map_err(|e| anyhow::anyhow!("{e}"))? as usize; + let block_count = md_get("block_count")? + .to_u32() + .map_err(|e| anyhow::anyhow!("{e}"))? as usize; + let embedding_length = md_get("embedding_length")? + .to_u32() + .map_err(|e| anyhow::anyhow!("{e}"))? as usize; + let key_length = md_get("attention.key_length")? + .to_u32() + .map_err(|e| anyhow::anyhow!("{e}"))? as usize; + let rms_norm_eps = md_get("attention.layer_norm_rms_epsilon")? + .to_f32() + .map_err(|e| anyhow::anyhow!("{e}"))? as f64; + let rope_freq_base = md_get("rope.freq_base") + .and_then(|v| v.to_f32().map_err(|e| anyhow::anyhow!("{e}"))) + .unwrap_or(10_000.0); + + let q_dim = head_count * key_length; + + // Build rotary embedding tables (shared by all layers for the base freq). + let max_seq_len: usize = 8192; // Sufficient for embedding inputs. 
+ let (rotary_sin, rotary_cos) = + Self::build_rotary_tables(key_length, rope_freq_base, max_seq_len, device)?; + + // Load token embeddings. + let tok_embd = ct + .tensor(&mut file, "token_embd.weight", device) + .map_err(|e| anyhow::anyhow!("loading token_embd.weight: {e}"))?; + let tok_embd_deq = tok_embd + .dequantize(device) + .map_err(|e| anyhow::anyhow!("dequantizing token_embd: {e}"))?; + let tok_embeddings = Embedding::new(tok_embd_deq, embedding_length); + + // Final norm. + let norm_qt = ct + .tensor(&mut file, "output_norm.weight", device) + .map_err(|e| anyhow::anyhow!("loading output_norm.weight: {e}"))?; + let norm = candle_transformers::quantized_nn::RmsNorm::from_qtensor(norm_qt, rms_norm_eps) + .map_err(|e| anyhow::anyhow!("creating RmsNorm: {e}"))?; + + // Load transformer layers. + let mut layers = Vec::with_capacity(block_count); + for idx in 0..block_count { + let p = format!("blk.{idx}"); + + // Helper: load a quantized weight tensor as QMatMul. + macro_rules! load_q { + ($name:expr) => {{ + let full = format!("{}.{}", p, $name); + let qt = ct + .tensor(&mut file, &full, device) + .map_err(|e| anyhow::anyhow!("loading {full}: {e}"))?; + CandleQMatMul::from_qtensor(qt) + .map_err(|e| anyhow::anyhow!("QMatMul for {full}: {e}"))? + }}; + } + + // Helper: load a norm weight tensor as RmsNorm. + macro_rules! load_norm { + ($name:expr) => {{ + let full = format!("{}.{}", p, $name); + let qt = ct + .tensor(&mut file, &full, device) + .map_err(|e| anyhow::anyhow!("loading {full}: {e}"))?; + candle_transformers::quantized_nn::RmsNorm::from_qtensor(qt, rms_norm_eps) + .map_err(|e| anyhow::anyhow!("RmsNorm for {full}: {e}"))? 
+ }}; + } + + layers.push(EmbedLayer { + attention_wq: load_q!("attn_q.weight"), + attention_wk: load_q!("attn_k.weight"), + attention_wv: load_q!("attn_v.weight"), + attention_wo: load_q!("attn_output.weight"), + attention_q_norm: load_norm!("attn_q_norm.weight"), + attention_k_norm: load_norm!("attn_k_norm.weight"), + attention_norm: load_norm!("attn_norm.weight"), + post_attention_norm: load_norm!("post_attention_norm.weight"), + ffn_norm: load_norm!("ffn_norm.weight"), + post_ffn_norm: load_norm!("post_ffw_norm.weight"), + ffn_gate: load_q!("ffn_gate.weight"), + ffn_up: load_q!("ffn_up.weight"), + ffn_down: load_q!("ffn_down.weight"), + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: key_length, + q_dim, + rotary_sin: rotary_sin.clone(), + rotary_cos: rotary_cos.clone(), + }); + } + + Ok((layers, tok_embeddings, norm, embedding_length)) + } + + /// Build sin/cos rotary embedding tables of shape [max_seq_len, head_dim]. + fn build_rotary_tables( + head_dim: usize, + freq_base: f32, + max_seq_len: usize, + device: &Device, + ) -> Result<(Tensor, Tensor)> { + let half = head_dim / 2; + let theta: Vec = (0..half) + .map(|i| 1.0 / freq_base.powf(i as f32 / half as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device) + .map_err(|e| anyhow::anyhow!("rotary theta: {e}"))?; + let positions = Tensor::arange(0, max_seq_len as u32, device) + .map_err(|e| anyhow::anyhow!("rotary positions: {e}"))? + .to_dtype(DType::F32) + .map_err(|e| anyhow::anyhow!("rotary positions dtype: {e}"))?; + // [max_seq_len, half] + let freqs = positions + .unsqueeze(1) + .map_err(|e| anyhow::anyhow!("rotary unsqueeze: {e}"))? + .broadcast_mul(&theta.unsqueeze(0).map_err(|e| anyhow::anyhow!("{e}"))?) + .map_err(|e| anyhow::anyhow!("rotary freqs: {e}"))?; + // Duplicate to [max_seq_len, head_dim] to match x1,x2 concatenation. 
+ let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1) + .map_err(|e| anyhow::anyhow!("rotary cat: {e}"))?; + let sin = freqs + .sin() + .map_err(|e| anyhow::anyhow!("rotary sin: {e}"))?; + let cos = freqs + .cos() + .map_err(|e| anyhow::anyhow!("rotary cos: {e}"))?; + Ok((sin, cos)) + } + + /// Run a bidirectional forward pass and return the mean-pooled, truncated, + /// L2-normalized embedding. + fn embed_text(&self, text: &str) -> Result> { + let encoding = self + .tokenizer + .encode(text, true) + .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?; + let token_ids = encoding.get_ids(); + if token_ids.is_empty() { + bail!("tokenizer returned empty token sequence"); + } + + let input = Tensor::new(token_ids, &self.device) + .map_err(|e| anyhow::anyhow!("creating input tensor: {e}"))? + .unsqueeze(0) + .map_err(|e| anyhow::anyhow!("unsqueeze: {e}"))?; + + // Token embeddings, scaled by sqrt(embedding_length) (Gemma convention). + let mut hidden = self + .tok_embeddings + .forward(&input) + .map_err(|e| anyhow::anyhow!("token embedding forward: {e}"))?; + hidden = (hidden * (self.embedding_length as f64).sqrt()) + .map_err(|e| anyhow::anyhow!("scaling embeddings: {e}"))?; + + // Forward through all transformer layers (bidirectional — no causal mask). + for layer in &self.layers { + hidden = layer + .forward(&hidden) + .map_err(|e| anyhow::anyhow!("layer forward: {e}"))?; + } + + // Final layer norm. + hidden = self + .norm + .forward(&hidden) + .map_err(|e| anyhow::anyhow!("final norm: {e}"))?; + + // Mean pool across sequence dimension: [1, seq_len, hidden] -> [1, hidden]. + let seq_len = hidden + .dim(1) + .map_err(|e| anyhow::anyhow!("getting seq dim: {e}"))?; + let pooled = (hidden.sum(1).map_err(|e| anyhow::anyhow!("sum: {e}"))? / (seq_len as f64)) + .map_err(|e| anyhow::anyhow!("mean div: {e}"))?; + + // Squeeze batch dimension: [1, hidden] -> [hidden]. 
+ let pooled = pooled + .squeeze(0) + .map_err(|e| anyhow::anyhow!("squeeze: {e}"))?; + + // Truncate to target dimensionality. + let full_dim = pooled + .dim(0) + .map_err(|e| anyhow::anyhow!("dim check: {e}"))?; + let truncated = if full_dim > self.dim { + pooled + .narrow(0, 0, self.dim) + .map_err(|e| anyhow::anyhow!("truncate: {e}"))? + } else { + pooled + }; + + // L2 normalize. + let norm_val = truncated + .sqr() + .map_err(|e| anyhow::anyhow!("sqr: {e}"))? + .sum_all() + .map_err(|e| anyhow::anyhow!("sum_all: {e}"))? + .sqrt() + .map_err(|e| anyhow::anyhow!("sqrt: {e}"))?; + let norm_scalar: f32 = norm_val + .to_scalar() + .map_err(|e| anyhow::anyhow!("norm scalar: {e}"))?; + + let normalized = if norm_scalar > 0.0 { + (truncated / norm_scalar as f64).map_err(|e| anyhow::anyhow!("normalize: {e}"))? + } else { + truncated + }; + + let vec: Vec = normalized + .to_vec1() + .map_err(|e| anyhow::anyhow!("to_vec1: {e}"))?; + Ok(vec) + } +} + +impl EmbedModel for CandleEmbed { + fn embed_batch(&mut self, texts: &[&str]) -> Result>> { + // Process texts sequentially — candle quantized ops are single-threaded. + texts.iter().map(|t| self.embed_text(t)).collect() + } + + fn embed_one(&mut self, text: &str) -> Result> { + self.embed_text(text) + } + + fn token_count(&self, text: &str) -> usize { + self.tokenizer + .encode(text, false) + .map(|enc| enc.get_ids().len()) + .unwrap_or(text.len() / 4 + 1) + } + + fn dim(&self) -> usize { + self.dim + } +} + +// ── Heuristic orchestrator ─────────────────────────────────────────────────── + +/// Heuristic orchestrator — no LLM, fast path when intelligence is off. 
+pub fn heuristic_orchestrate(query: &str) -> OrchestrationResult { + let trimmed = query.trim(); + + // Exact: docids (#abc123) or ticket IDs (ABC-1234) + if trimmed.starts_with('#') && trimmed.len() <= 8 { + return OrchestrationResult { + intent: QueryIntent::Exact, + expansions: vec![trimmed.to_string()], + }; + } + // Ticket ID pattern: PREFIX-1234 + if trimmed.contains('-') + && let Some(prefix) = trimmed.split('-').next() + && prefix.chars().all(|c| c.is_ascii_uppercase()) + { + let after = trimmed.split('-').nth(1).unwrap_or(""); + if after.chars().all(|c| c.is_ascii_digit()) && !after.is_empty() { + return OrchestrationResult { + intent: QueryIntent::Exact, + expansions: vec![trimmed.to_string()], + }; + } + } + + // Relationship: "who" queries + let lower = trimmed.to_lowercase(); + if lower.starts_with("who ") || lower.contains(" who ") { + return OrchestrationResult { + intent: QueryIntent::Relationship, + expansions: vec![trimmed.to_string()], + }; + } + + // Default: exploratory with word splitting for multi-word queries + let words: Vec<&str> = trimmed.split_whitespace().collect(); + let mut expansions = vec![trimmed.to_string()]; + if words.len() > 2 { + let stopwords = [ + "how", "does", "the", "a", "an", "is", "are", "was", "to", "in", "on", "for", "with", + "what", "when", "where", + ]; + for word in &words { + if word.len() > 2 && !stopwords.contains(&word.to_lowercase().as_str()) { + expansions.push(word.to_string()); + } + } + } + + OrchestrationResult { + intent: QueryIntent::Exploratory, + expansions, + } +} + +// ── Orchestration JSON parsing ──────────────────────────────────────────────── + +/// Parse orchestration JSON from LLM output. +/// Handles: raw JSON, JSON embedded in text, and partial/malformed responses. 
+pub fn parse_orchestration_json(text: &str) -> Result { + let json_str = extract_json_object(text) + .ok_or_else(|| anyhow::anyhow!("no JSON object found in LLM response"))?; + + let parsed: serde_json::Value = + serde_json::from_str(json_str).with_context(|| "parsing orchestration JSON")?; + + let intent_str = parsed["intent"].as_str().unwrap_or("exploratory"); + let intent = match intent_str { + "exact" => QueryIntent::Exact, + "conceptual" => QueryIntent::Conceptual, + "relationship" => QueryIntent::Relationship, + _ => QueryIntent::Exploratory, + }; + + let expansions: Vec = parsed["expansions"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); + + if expansions.is_empty() { + anyhow::bail!("no expansions in orchestration response"); + } + + Ok(OrchestrationResult { intent, expansions }) +} + +/// Extract the first JSON object ({...}) from text, handling nested braces. +fn extract_json_object(text: &str) -> Option<&str> { + let start = text.find('{')?; + let mut depth = 0; + for (i, b) in text[start..].bytes().enumerate() { + match b { + b'{' => depth += 1, + b'}' => { + depth -= 1; + if depth == 0 { + return Some(&text[start..start + i + 1]); + } + } + _ => {} + } + } + None +} + +// ── CandleOrchestrator — GGUF text generation via candle ───────────────────── + +const ORCHESTRATOR_SYSTEM_PROMPT: &str = r#"You are a search query analyzer. Given a user's search query, classify it and expand it. + +Return JSON with: +- "intent": one of "exact", "conceptual", "relationship", "exploratory" +- "expansions": 2-4 alternative phrasings (always include the original query first) + +Be concise. Only return the JSON object."#; + +/// Quantized Qwen3 model for query orchestration and expansion. +/// +/// Loads a Qwen3 GGUF model and performs autoregressive generation to classify +/// queries and produce expansions. 
Falls back to `heuristic_orchestrate` if +/// generation or JSON parsing fails. +pub struct CandleOrchestrator { + model: candle_transformers::models::quantized_qwen3::ModelWeights, + tokenizer: tokenizers::Tokenizer, + device: Device, +} + +impl std::fmt::Debug for CandleOrchestrator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CandleOrchestrator") + .field("device", &self.device) + .finish() + } +} + +impl CandleOrchestrator { + /// Load a Qwen3 GGUF model for orchestration from `models_dir`. + /// + /// Steps: + /// 1. Resolve model URI (from config override or `ModelDefaults`) + /// 2. `ensure_model()` to download if needed + /// 3. Load tokenizer from the model repo (or the non-GGUF base repo) + /// 4. Load GGUF via `ModelWeights::from_gguf()` + pub fn new(models_dir: &Path, config: &crate::config::Config) -> Result { + let defaults = ModelDefaults::default(); + let uri_str = config + .models + .expand + .as_deref() + .unwrap_or(&defaults.expand_uri); + let uri = HfModelUri::parse(uri_str)?; + let model_path = ensure_model(&uri, models_dir)?; + + // Load tokenizer (same strategy as CandleEmbed). + let tokenizer = Self::load_tokenizer(&uri, models_dir)?; + + let device = select_device()?; + + // Load GGUF model. + let mut file = std::fs::File::open(&model_path) + .map_err(|e| anyhow::anyhow!("opening GGUF {}: {e}", model_path.display()))?; + let ct = candle_core::quantized::gguf_file::Content::read(&mut file) + .map_err(|e| anyhow::anyhow!("reading GGUF {}: {e}", model_path.display()))?; + let model = candle_transformers::models::quantized_qwen3::ModelWeights::from_gguf( + ct, &mut file, &device, + ) + .map_err(|e| anyhow::anyhow!("loading Qwen3 model weights: {e}"))?; + + tracing::info!( + "loaded CandleOrchestrator from {}, device={:?}", + uri_str, + device + ); + + Ok(Self { + model, + tokenizer, + device, + }) + } + + /// Try to load tokenizer.json from the same HF repo, or from the non-GGUF base repo. 
+ fn load_tokenizer(uri: &HfModelUri, models_dir: &Path) -> Result { + // Try 1: tokenizer.json from the same repo. + let tok_uri = HfModelUri { + repo: uri.repo.clone(), + filename: "tokenizer.json".to_string(), + }; + let tok_path = tok_uri.cache_path(models_dir); + if tok_path.exists() { + return tokenizers::Tokenizer::from_file(&tok_path).map_err(|e| { + anyhow::anyhow!("loading tokenizer from {}: {e}", tok_path.display()) + }); + } + + // Try 2: download from the same repo. + if let Ok(p) = ensure_model(&tok_uri, models_dir) { + return tokenizers::Tokenizer::from_file(&p) + .map_err(|e| anyhow::anyhow!("loading tokenizer from {}: {e}", p.display())); + } + + // Try 3: non-GGUF variant of the repo (e.g., "Qwen/Qwen3-0.6B-GGUF" -> "Qwen/Qwen3-0.6B"). + let base_repo = uri.repo.trim_end_matches("-GGUF").to_string(); + if base_repo != uri.repo { + let base_tok_uri = HfModelUri { + repo: base_repo, + filename: "tokenizer.json".to_string(), + }; + if let Ok(p) = ensure_model(&base_tok_uri, models_dir) { + return tokenizers::Tokenizer::from_file(&p) + .map_err(|e| anyhow::anyhow!("loading tokenizer from {}: {e}", p.display())); + } + } + + bail!( + "could not find or download tokenizer for model repo '{}'", + uri.repo + ); + } + + /// Format a chat prompt in Qwen3 ChatML format. + fn format_prompt(query: &str) -> String { + format!( + "<|im_start|>system\n{ORCHESTRATOR_SYSTEM_PROMPT}<|im_end|>\n\ + <|im_start|>user\n{query}<|im_end|>\n\ + <|im_start|>assistant\n" + ) + } + + /// Run autoregressive generation (greedy decode) up to `max_tokens`. + /// Returns the generated text (excluding the prompt). 
+ fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result { + self.model.clear_kv_cache(); + + let encoding = self + .tokenizer + .encode(prompt, true) + .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?; + let prompt_tokens = encoding.get_ids(); + if prompt_tokens.is_empty() { + bail!("tokenizer returned empty token sequence"); + } + + // Determine EOS token ID. + let eos_token_id = self + .tokenizer + .token_to_id("<|im_end|>") + .or_else(|| self.tokenizer.token_to_id("<|endoftext|>")) + .unwrap_or(151643); // Qwen3 default EOS + + // Process the prompt in a single forward pass. + let input = Tensor::new(prompt_tokens, &self.device)? + .unsqueeze(0) + .map_err(|e| anyhow::anyhow!("unsqueeze prompt: {e}"))?; + let logits = self + .model + .forward(&input, 0) + .map_err(|e| anyhow::anyhow!("forward pass (prompt): {e}"))?; + + // Get the last token's logits and pick argmax. + let logits = logits + .to_dtype(DType::F32) + .map_err(|e| anyhow::anyhow!("logits dtype: {e}"))?; + let next_token = logits + .i(0)? + .argmax(D::Minus1) + .map_err(|e| anyhow::anyhow!("argmax: {e}"))? + .to_scalar::() + .map_err(|e| anyhow::anyhow!("scalar: {e}"))?; + + let mut generated_tokens: Vec = vec![next_token]; + let mut offset = prompt_tokens.len(); + + if next_token == eos_token_id { + // Model produced EOS immediately. + return Ok(String::new()); + } + + // Autoregressive loop. + for _ in 1..max_tokens { + let input = Tensor::new(&[*generated_tokens.last().unwrap()], &self.device)? + .unsqueeze(0) + .map_err(|e| anyhow::anyhow!("unsqueeze step: {e}"))?; + let logits = self + .model + .forward(&input, offset) + .map_err(|e| anyhow::anyhow!("forward pass (step): {e}"))?; + offset += 1; + + let logits = logits + .to_dtype(DType::F32) + .map_err(|e| anyhow::anyhow!("logits dtype: {e}"))?; + let token = logits + .i(0)? + .argmax(D::Minus1) + .map_err(|e| anyhow::anyhow!("argmax: {e}"))? 
+ .to_scalar::() + .map_err(|e| anyhow::anyhow!("scalar: {e}"))?; + + if token == eos_token_id { + break; + } + generated_tokens.push(token); + } + + let text = self + .tokenizer + .decode(&generated_tokens, true) + .map_err(|e| anyhow::anyhow!("decoding generated tokens: {e}"))?; + Ok(text) + } +} + +impl OrchestratorModel for CandleOrchestrator { + fn orchestrate(&mut self, query: &str) -> Result { + let prompt = Self::format_prompt(query); + + match self.generate(&prompt, 256) { + Ok(text) => match parse_orchestration_json(&text) { + Ok(result) => Ok(result), + Err(e) => { + tracing::warn!( + "orchestrator JSON parse failed, falling back to heuristic: {e:#}" + ); + Ok(heuristic_orchestrate(query)) + } + }, + Err(e) => { + tracing::warn!("orchestrator generation failed, falling back to heuristic: {e:#}"); + Ok(heuristic_orchestrate(query)) + } + } + } +} + +// ── CandleRerank — GGUF cross-encoder reranker via candle ───────────────────── + +/// Format query+document for cross-encoder reranking. +pub fn format_reranker_input(query: &str, document: &str) -> String { + format!( + "<|im_start|>system\nJudge whether the document is relevant to the search query. \ + Respond only with \"Yes\" or \"No\".<|im_end|>\n\ + <|im_start|>user\nSearch query: {query}\nDocument: {document}<|im_end|>\n\ + <|im_start|>assistant\n" + ) +} + +/// Quantized Qwen3 cross-encoder for reranking search results. +/// +/// Loads a Qwen3-Reranker GGUF model and scores (query, document) pairs by +/// running a single forward pass and extracting Yes/No logit probabilities. +/// Unlike `CandleOrchestrator`, this does NOT do autoregressive generation — +/// just one pass through the full input to get logits at the last position. 
+pub struct CandleRerank { + model: candle_transformers::models::quantized_qwen3::ModelWeights, + tokenizer: tokenizers::Tokenizer, + device: Device, + yes_token_id: u32, + no_token_id: u32, +} + +impl std::fmt::Debug for CandleRerank { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CandleRerank") + .field("device", &self.device) + .field("yes_token_id", &self.yes_token_id) + .field("no_token_id", &self.no_token_id) + .finish() + } +} + +impl CandleRerank { + /// Load a Qwen3-Reranker GGUF model from `models_dir`. + /// + /// Steps: + /// 1. Resolve model URI (from config override or `ModelDefaults::default().rerank_uri`) + /// 2. `ensure_model()` to download if needed + /// 3. Load tokenizer from the model repo (or the non-GGUF base repo) + /// 4. Load GGUF via `ModelWeights::from_gguf()` + /// 5. Look up "Yes" and "No" token IDs from the tokenizer + pub fn new(models_dir: &Path, config: &crate::config::Config) -> Result { + let defaults = ModelDefaults::default(); + let uri_str = config + .models + .rerank + .as_deref() + .unwrap_or(&defaults.rerank_uri); + let uri = HfModelUri::parse(uri_str)?; + let model_path = ensure_model(&uri, models_dir)?; + + // Load tokenizer (same strategy as CandleOrchestrator). + let tokenizer = Self::load_tokenizer(&uri, models_dir)?; + + // Look up Yes/No token IDs. + let yes_token_id = tokenizer + .token_to_id("Yes") + .ok_or_else(|| anyhow::anyhow!("tokenizer has no 'Yes' token"))?; + let no_token_id = tokenizer + .token_to_id("No") + .ok_or_else(|| anyhow::anyhow!("tokenizer has no 'No' token"))?; + + let device = select_device()?; + + // Load GGUF model. 
+ let mut file = std::fs::File::open(&model_path) + .map_err(|e| anyhow::anyhow!("opening GGUF {}: {e}", model_path.display()))?; + let ct = candle_core::quantized::gguf_file::Content::read(&mut file) + .map_err(|e| anyhow::anyhow!("reading GGUF {}: {e}", model_path.display()))?; + let model = candle_transformers::models::quantized_qwen3::ModelWeights::from_gguf( + ct, &mut file, &device, + ) + .map_err(|e| anyhow::anyhow!("loading Qwen3 reranker model weights: {e}"))?; + + tracing::info!( + "loaded CandleRerank from {}, device={:?}, yes_id={}, no_id={}", + uri_str, + device, + yes_token_id, + no_token_id + ); + + Ok(Self { + model, + tokenizer, + device, + yes_token_id, + no_token_id, + }) + } + + /// Try to load tokenizer.json from the same HF repo, or from the non-GGUF base repo. + fn load_tokenizer(uri: &HfModelUri, models_dir: &Path) -> Result { + // Try 1: tokenizer.json from the same repo. + let tok_uri = HfModelUri { + repo: uri.repo.clone(), + filename: "tokenizer.json".to_string(), + }; + let tok_path = tok_uri.cache_path(models_dir); + if tok_path.exists() { + return tokenizers::Tokenizer::from_file(&tok_path).map_err(|e| { + anyhow::anyhow!("loading tokenizer from {}: {e}", tok_path.display()) + }); + } + + // Try 2: download from the same repo. + if let Ok(p) = ensure_model(&tok_uri, models_dir) { + return tokenizers::Tokenizer::from_file(&p) + .map_err(|e| anyhow::anyhow!("loading tokenizer from {}: {e}", p.display())); + } + + // Try 3: non-GGUF variant of the repo. 
+ let base_repo = uri.repo.trim_end_matches("-GGUF").to_string(); + if base_repo != uri.repo { + let base_tok_uri = HfModelUri { + repo: base_repo, + filename: "tokenizer.json".to_string(), + }; + if let Ok(p) = ensure_model(&base_tok_uri, models_dir) { + return tokenizers::Tokenizer::from_file(&p) + .map_err(|e| anyhow::anyhow!("loading tokenizer from {}: {e}", p.display())); + } + } + + bail!( + "could not find or download tokenizer for model repo '{}'", + uri.repo + ); + } +} + +impl RerankModel for CandleRerank { + fn rerank_score(&mut self, query: &str, document: &str) -> Result { + self.model.clear_kv_cache(); + + let input_text = format_reranker_input(query, document); + + let encoding = self + .tokenizer + .encode(input_text.as_str(), true) + .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?; + let token_ids = encoding.get_ids(); + if token_ids.is_empty() { + bail!("tokenizer returned empty token sequence"); + } + + // Single forward pass through the full input (no autoregressive generation). + let input = Tensor::new(token_ids, &self.device)? + .unsqueeze(0) + .map_err(|e| anyhow::anyhow!("unsqueeze input: {e}"))?; + let logits = self + .model + .forward(&input, 0) + .map_err(|e| anyhow::anyhow!("forward pass: {e}"))?; + + // logits shape: [1, seq_len, vocab_size] or [1, vocab_size] (last position). + // Extract logits for the last position. + let logits = logits + .to_dtype(DType::F32) + .map_err(|e| anyhow::anyhow!("logits dtype: {e}"))?; + let last_logits = logits + .i(0) + .map_err(|e| anyhow::anyhow!("batch index: {e}"))?; + + // Extract Yes/No logits. + let yes_logit: f32 = last_logits + .i(self.yes_token_id as usize) + .map_err(|e| anyhow::anyhow!("yes logit index: {e}"))? + .to_scalar() + .map_err(|e| anyhow::anyhow!("yes logit scalar: {e}"))?; + let no_logit: f32 = last_logits + .i(self.no_token_id as usize) + .map_err(|e| anyhow::anyhow!("no logit index: {e}"))? 
+ .to_scalar() + .map_err(|e| anyhow::anyhow!("no logit scalar: {e}"))?; + + // Softmax over Yes/No to get probability. + let max_logit = yes_logit.max(no_logit); + let yes_exp = (yes_logit - max_logit).exp(); + let no_exp = (no_logit - max_logit).exp(); + let score = yes_exp / (yes_exp + no_exp); + + Ok(score) + } +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mock_embed_deterministic() { + let mut mock = MockLlm::new(256); + let v1 = mock.embed_one("hello").unwrap(); + let v2 = mock.embed_one("hello").unwrap(); + assert_eq!(v1.len(), 256); + assert_eq!(v1, v2, "same input must produce same output"); + } + + #[test] + fn test_mock_embed_different_inputs() { + let mut mock = MockLlm::new(256); + let v1 = mock.embed_one("hello").unwrap(); + let v2 = mock.embed_one("world").unwrap(); + assert_ne!(v1, v2, "different inputs should produce different vectors"); + } + + #[test] + fn test_mock_embed_batch() { + let mut mock = MockLlm::new(256); + let vecs = mock.embed_batch(&["a", "b", "c"]).unwrap(); + assert_eq!(vecs.len(), 3); + assert!(vecs.iter().all(|v| v.len() == 256)); + } + + #[test] + fn test_mock_embed_normalized() { + let mut mock = MockLlm::new(256); + let v = mock.embed_one("test").unwrap(); + let norm: f32 = v.iter().map(|x| x * x).sum::().sqrt(); + assert!( + (norm - 1.0).abs() < 0.01, + "mock vectors should be L2-normalized" + ); + } + + #[test] + fn test_mock_rerank() { + let mut mock = MockLlm::new(256); + let score = mock.rerank_score("query", "document text").unwrap(); + assert!((0.0..=1.0).contains(&score)); + } + + #[test] + fn test_mock_orchestrate() { + let mut mock = MockLlm::new(256); + let result = mock.orchestrate("how does auth work").unwrap(); + assert_eq!(result.intent, QueryIntent::Exploratory); + assert!(!result.expansions.is_empty()); + assert_eq!(result.expansions[0], "how does auth work"); + } + + #[test] + fn 
test_mock_rerank_empty_query() { + let mut mock = MockLlm::new(256); + let score = mock.rerank_score("", "document text").unwrap(); + assert_eq!(score, 0.0, "empty query should score 0.0"); + } + + #[test] + fn test_lane_weights_from_intent() { + let exact = LaneWeights::from_intent(&QueryIntent::Exact); + assert!(exact.fts > exact.semantic, "exact intent should favor FTS"); + + let conceptual = LaneWeights::from_intent(&QueryIntent::Conceptual); + assert!( + conceptual.semantic > conceptual.fts, + "conceptual should favor semantic" + ); + + let relationship = LaneWeights::from_intent(&QueryIntent::Relationship); + assert!( + relationship.graph > relationship.semantic, + "relationship should favor graph" + ); + } + + #[test] + fn test_parse_hf_uri() { + let uri = "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf"; + let parsed = HfModelUri::parse(uri).unwrap(); + assert_eq!(parsed.repo, "ggml-org/embeddinggemma-300M-GGUF"); + assert_eq!(parsed.filename, "embeddinggemma-300M-Q8_0.gguf"); + assert_eq!( + parsed.download_url(), + "https://huggingface.co/ggml-org/embeddinggemma-300M-GGUF/resolve/main/embeddinggemma-300M-Q8_0.gguf" + ); + } + + #[test] + fn test_parse_hf_uri_invalid() { + assert!(HfModelUri::parse("not-a-hf-uri").is_err()); + assert!(HfModelUri::parse("hf:only-repo").is_err()); + } + + #[test] + fn test_model_cache_path() { + let uri = HfModelUri::parse("hf:ggml-org/embeddinggemma-300M-GGUF/model.gguf").unwrap(); + let cache_dir = std::path::Path::new("/tmp/models"); + let path = uri.cache_path(cache_dir); + assert_eq!( + path, + cache_dir.join("ggml-org--embeddinggemma-300M-GGUF--model.gguf") + ); + } + + #[test] + fn test_model_defaults() { + let defaults = ModelDefaults::default(); + assert!(defaults.embed_uri.starts_with("hf:")); + assert_eq!(defaults.embed_dim, 256); + } + + // ── CandleEmbed / PromptFormat tests ──────────────────────────────────── + + #[test] + fn test_candle_embed_struct_exists() { + fn assert_embed_model(_e: 
&E) {} + let mock = MockLlm::new(256); + assert_embed_model(&mock); + // CandleEmbed also implements EmbedModel — verified at compile time. + // We can't instantiate CandleEmbed without a real GGUF model, + // but the trait bound compiles. + } + + #[test] + fn test_prompt_format_embeddinggemma_query() { + let fmt = PromptFormat::detect("embeddinggemma-300M-Q8_0.gguf"); + let formatted = fmt.format_query("how does auth work"); + assert!(formatted.contains("search_query")); + assert!(formatted.contains("how does auth work")); + } + + #[test] + fn test_prompt_format_embeddinggemma_document() { + let fmt = PromptFormat::detect("embeddinggemma-300M-Q8_0.gguf"); + let formatted = fmt.format_document("Note Title", "some content"); + assert!(formatted.contains("Note Title")); + assert!(formatted.contains("some content")); + assert!(formatted.contains("search_document")); + } + + #[test] + fn test_prompt_format_unknown_model() { + let fmt = PromptFormat::detect("unknown-model.gguf"); + let formatted = fmt.format_query("test query"); + assert_eq!(formatted, "test query"); + } + + #[test] + fn test_prompt_format_qwen_embedding() { + let fmt = PromptFormat::detect("qwen-embed-v2.gguf"); + let formatted = fmt.format_query("find me something"); + assert!(formatted.contains("Instruct:")); + assert!(formatted.contains("Query:")); + assert!(formatted.contains("find me something")); + } + + #[test] + fn test_prompt_format_qwen_document() { + let fmt = PromptFormat::detect("qwen-embed-v2.gguf"); + let formatted = fmt.format_document("Title", "Body text"); + assert_eq!(formatted, "Title\nBody text"); + } + + #[test] + fn test_prompt_format_raw_document() { + let fmt = PromptFormat::detect("random-model.gguf"); + let formatted = fmt.format_document("Title", "Body"); + assert_eq!(formatted, "Title\nBody"); + } + + #[test] + fn test_select_device_returns_cpu_by_default() { + // Without the `metal` feature, select_device should return CPU. 
+ let device = select_device().unwrap(); + // On CI/test without metal feature, this should be CPU. + // With metal feature on macOS, it could be Metal — both are valid. + let _ = device; // Just verify it doesn't error. + } + + // ── heuristic_orchestrate tests ────────────────────────────────────────── + + #[test] + fn test_heuristic_orchestrate_single_word() { + let result = heuristic_orchestrate("auth"); + assert_eq!(result.intent, QueryIntent::Exploratory); + assert_eq!(result.expansions, vec!["auth"]); + } + + #[test] + fn test_heuristic_orchestrate_multi_word() { + let result = heuristic_orchestrate("how does auth work"); + assert_eq!(result.intent, QueryIntent::Exploratory); + assert!( + result + .expansions + .contains(&"how does auth work".to_string()) + ); + assert!(result.expansions.len() > 1); + } + + #[test] + fn test_heuristic_orchestrate_docid() { + let result = heuristic_orchestrate("#ab12cd"); + assert_eq!(result.intent, QueryIntent::Exact); + } + + #[test] + fn test_heuristic_orchestrate_ticket_id() { + let result = heuristic_orchestrate("BRE-1234"); + assert_eq!(result.intent, QueryIntent::Exact); + } + + #[test] + fn test_heuristic_orchestrate_who_query() { + let result = heuristic_orchestrate("who works on checkout"); + assert_eq!(result.intent, QueryIntent::Relationship); + } + + // ── parse_orchestration_json tests ─────────────────────────────────────── + + #[test] + fn test_parse_orchestration_json_valid() { + let json = + r#"{"intent": "conceptual", "expansions": ["auth work", "authentication design"]}"#; + let result = parse_orchestration_json(json).unwrap(); + assert_eq!(result.intent, QueryIntent::Conceptual); + assert_eq!(result.expansions.len(), 2); + } + + #[test] + fn test_parse_orchestration_json_with_surrounding_text() { + let text = + "Here is the analysis:\n{\"intent\": \"exact\", \"expansions\": [\"BRE-1234\"]}\nDone."; + let result = parse_orchestration_json(text).unwrap(); + assert_eq!(result.intent, QueryIntent::Exact); + } 
+ + #[test] + fn test_parse_orchestration_json_invalid() { + let bad = "not json at all"; + assert!(parse_orchestration_json(bad).is_err()); + } + + #[test] + fn test_parse_orchestration_json_unknown_intent() { + let json = r#"{"intent": "unknown_type", "expansions": ["query"]}"#; + let result = parse_orchestration_json(json).unwrap(); + assert_eq!(result.intent, QueryIntent::Exploratory); + } + + #[test] + fn test_extract_json_object_nested() { + let text = r#"prefix {"a": {"b": 1}} suffix"#; + let extracted = extract_json_object(text).unwrap(); + assert_eq!(extracted, r#"{"a": {"b": 1}}"#); + } + + #[test] + fn test_extract_json_object_none() { + assert!(extract_json_object("no braces here").is_none()); + } + + #[test] + fn test_extract_json_object_unclosed() { + assert!(extract_json_object("{ open but never closed").is_none()); + } + + #[test] + fn test_parse_orchestration_json_empty_expansions() { + let json = r#"{"intent": "exact", "expansions": []}"#; + assert!(parse_orchestration_json(json).is_err()); + } + + #[test] + fn test_parse_orchestration_json_missing_expansions() { + let json = r#"{"intent": "exact"}"#; + assert!(parse_orchestration_json(json).is_err()); + } + + // ── CandleOrchestrator tests ───────────────────────────────────────────── + + #[test] + fn test_candle_orchestrator_format_prompt() { + let prompt = CandleOrchestrator::format_prompt("how does auth work"); + assert!(prompt.contains("<|im_start|>system")); + assert!(prompt.contains("<|im_end|>")); + assert!(prompt.contains("<|im_start|>user")); + assert!(prompt.contains("how does auth work")); + assert!(prompt.contains("<|im_start|>assistant")); + } + + #[test] + fn test_candle_orchestrator_implements_trait() { + // Compile-time check: CandleOrchestrator implements OrchestratorModel. 
+ fn assert_orchestrator() {} + assert_orchestrator::(); + } + + // ── CandleRerank tests ────────────────────────────────────────────────── + + #[test] + fn test_format_reranker_input() { + let formatted = format_reranker_input("auth system", "The auth module handles OAuth"); + assert!(formatted.contains("auth system")); + assert!(formatted.contains("The auth module handles OAuth")); + assert!(formatted.contains("Respond only with")); + } + + #[test] + fn test_candle_rerank_trait_compliance() { + // Verify MockLlm still satisfies RerankModel. + fn assert_rerank(_r: &R) {} + let mock = MockLlm::new(256); + assert_rerank(&mock); + } +} diff --git a/src/main.rs b/src/main.rs index 2df5985..7d2872b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ use engraph::config; use engraph::indexer; -use engraph::model; use engraph::profile; use engraph::search; use engraph::store; @@ -73,8 +72,20 @@ enum Command { path: Option, }, - /// Interactively configure vault profile. - Configure, + /// Configure engraph settings. + Configure { + /// Enable intelligence features. + #[arg(long, conflicts_with = "disable_intelligence")] + enable_intelligence: bool, + + /// Disable intelligence features. + #[arg(long, conflicts_with = "enable_intelligence")] + disable_intelligence: bool, + + /// Override a model: --model embed|rerank|expand + #[arg(long, num_args = 2, value_names = &["TYPE", "URI"])] + model: Option>, + }, /// Manage embedding models. Models { @@ -207,6 +218,38 @@ enum ModelsAction { Info { name: String }, } +/// Prompt user to enable intelligence, download models if yes. +fn prompt_intelligence(data_dir: &std::path::Path) -> Result { + eprint!( + "\nEnable AI-powered search intelligence?\n\n\ + This downloads ~1.3GB of additional models for:\n\ + \x20 - Query expansion (rewrites your search into multiple variations)\n\ + \x20 - Result reranking (LLM scores each result for relevance)\n\n\ + Enable now? 
[y/N] " + ); + io::stderr().flush()?; + let mut answer = String::new(); + io::stdin().lock().read_line(&mut answer)?; + let enable = answer.trim().eq_ignore_ascii_case("y"); + + if enable { + let models_dir = data_dir.join("models"); + let defaults = engraph::llm::ModelDefaults::default(); + println!("Downloading intelligence models (~1.3GB)..."); + let rerank_uri = engraph::llm::HfModelUri::parse(&defaults.rerank_uri)?; + engraph::llm::ensure_model(&rerank_uri, &models_dir)?; + let expand_uri = engraph::llm::HfModelUri::parse(&defaults.expand_uri)?; + engraph::llm::ensure_model(&expand_uri, &models_dir)?; + println!("Done."); + } else { + println!( + "Intelligence disabled. You can enable later with: engraph configure --enable-intelligence" + ); + } + + Ok(enable) +} + /// Check whether an index has been built by looking for engraph.db in data_dir. fn index_exists(data_dir: &std::path::Path) -> bool { data_dir.join("engraph.db").exists() @@ -295,6 +338,13 @@ async fn main() -> Result<()> { } } + // First-run intelligence prompt (only if not yet configured) + if cfg.intelligence.is_none() { + let enable = prompt_intelligence(&data_dir)?; + cfg.intelligence = Some(enable); + cfg.save()?; + } + let result = indexer::run_index(&vault_path, &cfg, rebuild)?; println!( @@ -319,7 +369,7 @@ async fn main() -> Result<()> { std::process::exit(1); } - search::run_search(&query, cfg.top_n, cli.json, explain, &data_dir)?; + search::run_search(&query, cfg.top_n, cli.json, explain, &data_dir, &cfg)?; } Command::Status => { @@ -414,11 +464,74 @@ async fn main() -> Result<()> { println!(); println!("Wrote {}", data_dir.join("vault.toml").display()); + + // Intelligence onboarding (only if not yet configured) + if cfg.intelligence.is_none() { + let enable = prompt_intelligence(&data_dir)?; + cfg.intelligence = Some(enable); + cfg.save()?; + } } - Command::Configure => { + Command::Configure { + enable_intelligence, + disable_intelligence, + model, + } => { + let mut cfg = 
Config::load()?; + + if enable_intelligence { + cfg.intelligence = Some(true); + println!("Intelligence enabled. Models will be downloaded on first search."); + let models_dir = data_dir.join("models"); + let defaults = engraph::llm::ModelDefaults::default(); + println!("Downloading intelligence models (~1.3GB)..."); + let rerank_uri = engraph::llm::HfModelUri::parse( + cfg.models.rerank.as_deref().unwrap_or(&defaults.rerank_uri), + )?; + engraph::llm::ensure_model(&rerank_uri, &models_dir)?; + let expand_uri = engraph::llm::HfModelUri::parse( + cfg.models.expand.as_deref().unwrap_or(&defaults.expand_uri), + )?; + engraph::llm::ensure_model(&expand_uri, &models_dir)?; + println!("Done."); + } else if disable_intelligence { + cfg.intelligence = Some(false); + println!("Intelligence disabled. Models remain cached."); + } + + if let Some(parts) = model + && parts.len() == 2 + { + let model_type = &parts[0]; + let uri = &parts[1]; + engraph::llm::HfModelUri::parse(uri)?; + match model_type.as_str() { + "embed" => { + cfg.models.embed = Some(uri.clone()); + println!("Embedding model set to: {uri}"); + println!("Warning: Next 'engraph index' will re-embed your entire vault."); + } + "rerank" => { + cfg.models.rerank = Some(uri.clone()); + println!("Reranker model set to: {uri}"); + } + "expand" => { + cfg.models.expand = Some(uri.clone()); + println!("Expansion model set to: {uri}"); + } + other => { + anyhow::bail!( + "Unknown model type: {other}. Use: embed, rerank, or expand." + ); + } + } + } + + cfg.save()?; println!( - "Interactive configuration not yet implemented. Run 'engraph init' for auto-detection." 
+ "Configuration saved to {}", + data_dir.join("config.toml").display() ); } @@ -716,7 +829,7 @@ async fn main() -> Result<()> { } ContextAction::Topic { query, budget } => { let models_dir = data_dir.join("models"); - let mut embedder = engraph::embedder::Embedder::new(&models_dir)?; + let mut embedder = engraph::llm::CandleEmbed::new(&models_dir, &cfg)?; let bundle = engraph::context::context_topic_with_search( ¶ms, @@ -769,7 +882,7 @@ async fn main() -> Result<()> { .ok_or_else(|| anyhow::anyhow!("No vault path in index."))?; let vault_path = PathBuf::from(&vault_path_str); let models_dir = data_dir.join("models"); - let mut embedder = engraph::embedder::Embedder::new(&models_dir)?; + let mut embedder = engraph::llm::CandleEmbed::new(&models_dir, &cfg)?; let profile = config::Config::load_vault_profile().ok().flatten(); match action { @@ -866,23 +979,23 @@ async fn main() -> Result<()> { } Command::Models { action } => { - let registry = model::ModelRegistry::default(); + let defaults = engraph::llm::ModelDefaults::default(); match action { ModelsAction::List => { println!("{:<30} {:>5} DESCRIPTION", "NAME", "DIM"); println!("{}", "-".repeat(70)); - for entry in ®istry.entries { - println!("{:<30} {:>5} {}", entry.name, entry.dim, entry.description); - } + let desc = "Default embedding model (GGUF)"; + println!( + "{:<30} {:>5} {}", + defaults.embed_uri, defaults.embed_dim, desc + ); } ModelsAction::Info { name } => { - if let Some(entry) = registry.get(&name) { - println!("Name: {}", entry.name); - println!("Format: {:?}", entry.format); - println!("Dimensions: {}", entry.dim); - println!("SHA-256: {}", entry.sha256); - println!("URL: {}", entry.url); - println!("Description: {}", entry.description); + if name == defaults.embed_uri { + println!("Name: {}", defaults.embed_uri); + println!("Format: GGUF"); + println!("Dimensions: {}", defaults.embed_dim); + println!("Description: Default embedding model (GGUF)"); } else { eprintln!("Unknown model: {name}"); 
eprintln!("Run 'engraph models list' to see available models."); diff --git a/src/model.rs b/src/model.rs deleted file mode 100644 index b8c966c..0000000 --- a/src/model.rs +++ /dev/null @@ -1,141 +0,0 @@ -use anyhow::Result; -use serde::{Deserialize, Serialize}; - -/// Trait for embedding backends. Any model that can embed text implements this. -pub trait ModelBackend { - fn embed_batch(&mut self, texts: &[&str]) -> Result>>; - fn embed_one(&mut self, text: &str) -> Result>; - fn token_count(&self, text: &str) -> usize; - fn dim(&self) -> usize; - fn name(&self) -> &str; -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum ModelFormat { - Onnx, - Gguf, - File, -} - -#[derive(Debug, Clone)] -pub struct ModelSpec { - pub format: ModelFormat, - pub name: String, - pub path: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelRegistryEntry { - pub name: String, - pub format: ModelFormat, - pub url: String, - pub sha256: String, - pub dim: usize, - pub description: String, -} - -pub struct ModelRegistry { - pub entries: Vec, -} - -impl Default for ModelRegistry { - fn default() -> Self { - Self { - entries: vec![ModelRegistryEntry { - name: "onnx:all-MiniLM-L6-v2".to_string(), - format: ModelFormat::Onnx, - url: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".to_string(), - sha256: "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452".to_string(), - dim: 384, - description: "Lightweight general-purpose sentence embeddings".to_string(), - }], - } - } -} - -impl ModelRegistry { - pub fn get(&self, name: &str) -> Option<&ModelRegistryEntry> { - self.entries.iter().find(|e| e.name == name) - } -} - -pub fn parse_model_spec(spec: &str) -> ModelSpec { - if let Some(path) = spec.strip_prefix("file:") { - return ModelSpec { - format: ModelFormat::File, - name: spec.to_string(), - path: path.to_string(), - }; - } - if let Some((format_str, name)) = spec.split_once(':') { - 
let format = match format_str { - "onnx" => ModelFormat::Onnx, - "gguf" => ModelFormat::Gguf, - _ => ModelFormat::Onnx, - }; - ModelSpec { - format, - name: name.to_string(), - path: String::new(), - } - } else { - ModelSpec { - format: ModelFormat::Onnx, - name: spec.to_string(), - path: String::new(), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_model_registry_default() { - let registry = ModelRegistry::default(); - assert_eq!(registry.entries.len(), 1); - let entry = ®istry.entries[0]; - assert_eq!(entry.name, "onnx:all-MiniLM-L6-v2"); - assert_eq!(entry.dim, 384); - assert_eq!(entry.format, ModelFormat::Onnx); - } - - #[test] - fn test_parse_model_spec_onnx() { - let spec = parse_model_spec("onnx:all-MiniLM-L6-v2"); - assert_eq!(spec.format, ModelFormat::Onnx); - assert_eq!(spec.name, "all-MiniLM-L6-v2"); - assert!(spec.path.is_empty()); - } - - #[test] - fn test_parse_model_spec_file() { - let spec = parse_model_spec("file:/path/to/model.onnx"); - assert_eq!(spec.format, ModelFormat::File); - assert_eq!(spec.name, "file:/path/to/model.onnx"); - assert_eq!(spec.path, "/path/to/model.onnx"); - } - - #[test] - fn test_parse_model_spec_bare() { - let spec = parse_model_spec("my-custom-model"); - assert_eq!(spec.format, ModelFormat::Onnx); - assert_eq!(spec.name, "my-custom-model"); - assert!(spec.path.is_empty()); - } - - #[test] - fn test_registry_get_existing() { - let registry = ModelRegistry::default(); - let entry = registry.get("onnx:all-MiniLM-L6-v2"); - assert!(entry.is_some()); - assert_eq!(entry.unwrap().dim, 384); - } - - #[test] - fn test_registry_get_missing() { - let registry = ModelRegistry::default(); - assert!(registry.get("nonexistent-model").is_none()); - } -} diff --git a/src/placement.rs b/src/placement.rs index 6bc6e89..8051029 100644 --- a/src/placement.rs +++ b/src/placement.rs @@ -1,6 +1,6 @@ use anyhow::Result; -use crate::embedder::Embedder; +use crate::llm::EmbedModel; use crate::profile::VaultProfile; 
use crate::store::Store; use crate::writer::split_frontmatter; @@ -46,7 +46,7 @@ pub fn place_note( hints: &PlacementHints, profile: Option<&VaultProfile>, store: &Store, - embedder: Option<&mut Embedder>, + embedder: Option<&mut impl EmbedModel>, ) -> Result { // Strategy A: Type-based rules if let Some(result) = try_type_rules(content, hints, profile) { @@ -234,7 +234,7 @@ fn looks_like_meeting_note(content: &str) -> bool { fn try_semantic_placement( content: &str, store: &Store, - embedder: &mut Embedder, + embedder: &mut impl EmbedModel, ) -> Result> { let centroids = store.get_folder_centroids()?; if centroids.is_empty() { @@ -383,6 +383,7 @@ pub fn strip_placement_frontmatter(content: &str) -> String { #[cfg(test)] mod tests { use super::*; + use crate::llm::MockLlm; use crate::profile::{ FolderMap, StructureDetection, StructureMethod, VaultProfile, VaultStats, }; @@ -408,7 +409,14 @@ mod tests { type_hint: None, tags: vec![], }; - let result = place_note("Some random note.", &hints, None, &store, None).unwrap(); + let result = place_note( + "Some random note.", + &hints, + None, + &store, + None::<&mut MockLlm>, + ) + .unwrap(); assert_eq!(result.strategy, PlacementStrategy::InboxFallback); assert_eq!(result.folder, "00-Inbox"); } @@ -421,7 +429,7 @@ mod tests { type_hint: Some("person".into()), tags: vec![], }; - let result = place_note("# John Doe", &hints, None, &store, None).unwrap(); + let result = place_note("# John Doe", &hints, None, &store, None::<&mut MockLlm>).unwrap(); assert_eq!(result.strategy, PlacementStrategy::InboxFallback); } @@ -437,7 +445,14 @@ mod tests { type_hint: Some("person".into()), tags: vec![], }; - let result = place_note("# John Doe", &hints, Some(&profile), &store, None).unwrap(); + let result = place_note( + "# John Doe", + &hints, + Some(&profile), + &store, + None::<&mut MockLlm>, + ) + .unwrap(); assert_eq!(result.strategy, PlacementStrategy::TypeRule); assert_eq!(result.folder, "03-Resources/People"); 
assert!(result.confidence > 0.9); @@ -455,7 +470,14 @@ mod tests { type_hint: Some("daily".into()), tags: vec![], }; - let result = place_note("Today's notes", &hints, Some(&profile), &store, None).unwrap(); + let result = place_note( + "Today's notes", + &hints, + Some(&profile), + &store, + None::<&mut MockLlm>, + ) + .unwrap(); assert_eq!(result.strategy, PlacementStrategy::TypeRule); assert_eq!(result.folder, "07-Daily"); } @@ -472,7 +494,14 @@ mod tests { type_hint: Some("workout".into()), tags: vec![], }; - let result = place_note("Leg day workout", &hints, Some(&profile), &store, None).unwrap(); + let result = place_note( + "Leg day workout", + &hints, + Some(&profile), + &store, + None::<&mut MockLlm>, + ) + .unwrap(); assert_eq!(result.strategy, PlacementStrategy::TypeRule); assert_eq!(result.folder, "02-Areas/Health"); } @@ -490,7 +519,14 @@ mod tests { tags: vec![], }; let content = "# Jane Smith\nRole: Engineering Manager\nCompany: Acme Corp"; - let result = place_note(content, &hints, Some(&profile), &store, None).unwrap(); + let result = place_note( + content, + &hints, + Some(&profile), + &store, + None::<&mut MockLlm>, + ) + .unwrap(); assert_eq!(result.strategy, PlacementStrategy::TypeRule); assert_eq!(result.folder, "03-Resources/People"); } @@ -510,7 +546,14 @@ mod tests { }; // Heading with 2 words but no Role: or Company: let content = "# Jane Smith\nJust some notes about a topic."; - let result = place_note(content, &hints, Some(&profile), &store, None).unwrap(); + let result = place_note( + content, + &hints, + Some(&profile), + &store, + None::<&mut MockLlm>, + ) + .unwrap(); assert_eq!(result.strategy, PlacementStrategy::InboxFallback); } @@ -526,7 +569,14 @@ mod tests { type_hint: None, tags: vec![], }; - let result = place_note("Random note", &hints, Some(&profile), &store, None).unwrap(); + let result = place_note( + "Random note", + &hints, + Some(&profile), + &store, + None::<&mut MockLlm>, + ) + .unwrap(); assert_eq!(result.strategy, 
PlacementStrategy::InboxFallback); assert_eq!(result.folder, "Inbox"); } @@ -570,7 +620,7 @@ mod tests { &hints, Some(&profile), &store, - None, + None::<&mut MockLlm>, ) .unwrap(); assert_eq!(result.strategy, PlacementStrategy::TypeRule); diff --git a/src/search.rs b/src/search.rs index f776cff..430d56d 100644 --- a/src/search.rs +++ b/src/search.rs @@ -4,11 +4,19 @@ use std::path::Path; use anyhow::{Context, Result}; use serde_json::json; -use crate::embedder::Embedder; use crate::fusion::{self, RankedResult}; use crate::graph; +use crate::llm::{self, EmbedModel, OrchestratorModel, RerankModel}; use crate::store::{Store, StoreStats}; +/// Compute cache key for orchestration results (SHA256 of query). +#[allow(dead_code)] +fn orchestration_cache_key(query: &str) -> String { + use sha2::{Digest, Sha256}; + let hash = Sha256::digest(query.as_bytes()); + format!("{:x}", hash) +} + /// A single search result with metadata. pub struct SearchResult { pub score: f32, @@ -33,135 +41,199 @@ pub struct InternalSearchResult { pub struct SearchOutput { pub results: Vec, pub fused: Vec, + pub intent: Option, +} + +/// Configuration for the intelligence search pipeline. +pub struct SearchConfig<'a> { + pub orchestrator: Option<&'a mut dyn OrchestratorModel>, + pub reranker: Option<&'a mut dyn RerankModel>, + pub store: &'a Store, + pub rerank_candidates: usize, } /// Run hybrid search and return structured results (no I/O). /// Used by both `run_search` (CLI) and context engine. +/// +/// Thin wrapper around `search_with_intelligence` with no intelligence models, +/// preserving the existing heuristic-only behavior. pub fn search_internal( query: &str, top_n: usize, store: &Store, - embedder: &mut Embedder, + embedder: &mut impl EmbedModel, +) -> Result { + let mut config = SearchConfig { + orchestrator: None, + reranker: None, + store, + rerank_candidates: 30, + }; + search_with_intelligence(query, top_n, embedder, &mut config) +} + +/// Full intelligence search pipeline. 
+/// +/// 1. Orchestrate (intent + expansions + weights) — LLM if available, else heuristic. +/// 2. 3-lane retrieval per expanded query (semantic, FTS, graph). +/// 3. RRF Pass 1 with top candidates. +/// 4. Reranker scores each candidate (4th lane) if available. +/// 5. RRF Pass 2 with all 4 lanes for final ranking. +pub fn search_with_intelligence( + query: &str, + top_n: usize, + embedder: &mut impl EmbedModel, + config: &mut SearchConfig<'_>, ) -> Result { - // --- Semantic lane --- - let query_vec = embedder.embed_one(query).context("embedding query")?; - let tombstones = std::collections::HashSet::new(); - - // Request extra results to account for file-level dedup. - let raw_results = store.search_vec(&query_vec, top_n * 3, &tombstones)?; - - // Group semantic results by file_path, keeping best per file. - let mut sem_by_file: HashMap = HashMap::new(); - for (vector_id, distance) in raw_results { - if let Some(chunk) = store.get_chunk_by_vector_id(vector_id)? { - let (file_path, docid) = match store.get_file_by_id(chunk.file_id)? { + // --- Step 1: Orchestrate --- + let orchestration = match &mut config.orchestrator { + Some(orch) => orch.orchestrate(query)?, + None => llm::heuristic_orchestrate(query), + }; + let weights = llm::LaneWeights::from_intent(&orchestration.intent); + + // --- Step 2: Run 3-lane retrieval for EACH expanded query --- + let mut all_semantic: Vec = Vec::new(); + let mut all_fts: Vec = Vec::new(); + + for expanded_query in &orchestration.expansions { + // Semantic lane + let query_vec = embedder + .embed_one(expanded_query) + .context("embedding query")?; + let tombstones = std::collections::HashSet::new(); + let raw_results = config + .store + .search_vec(&query_vec, top_n * 3, &tombstones)?; + + // Group semantic results by file_path, keeping best per file. + let mut sem_by_file: HashMap = HashMap::new(); + for (vector_id, distance) in raw_results { + if let Some(chunk) = config.store.get_chunk_by_vector_id(vector_id)? 
{ + let (file_path, docid) = match config.store.get_file_by_id(chunk.file_id)? { + Some(f) => (f.path, f.docid), + None => ("".to_string(), None), + }; + let score = (1.0 - distance) as f64; + let heading = if chunk.heading.is_empty() { + None + } else { + Some(chunk.heading) + }; + + let better = match sem_by_file.get(&file_path) { + Some(existing) => score > existing.score, + None => true, + }; + if better { + sem_by_file.insert( + file_path.clone(), + RankedResult { + file_path, + file_id: chunk.file_id, + score, + heading, + snippet: chunk.snippet, + docid, + }, + ); + } + } + } + all_semantic.extend(sem_by_file.into_values()); + + // FTS lane + let fts_raw = config + .store + .fts_search(expanded_query, top_n * 3) + .unwrap_or_default(); + + let mut fts_by_file: HashMap = HashMap::new(); + for fr in fts_raw { + let (file_path, docid) = match config.store.get_file_by_id(fr.file_id)? { Some(f) => (f.path, f.docid), - None => ("".to_string(), None), - }; - let score = (1.0 - distance) as f64; - let heading = if chunk.heading.is_empty() { - None - } else { - Some(chunk.heading) + None => continue, }; - // Keep the best-scoring chunk per file. - let better = match sem_by_file.get(&file_path) { - Some(existing) => score > existing.score, + let better = match fts_by_file.get(&file_path) { + Some(existing) => fr.score > existing.score, None => true, }; if better { - sem_by_file.insert( + fts_by_file.insert( file_path.clone(), RankedResult { file_path, - file_id: chunk.file_id, - score, - heading, - snippet: chunk.snippet, + file_id: fr.file_id, + score: fr.score, + heading: None, // FTS doesn't return headings + snippet: fr.snippet, docid, }, ); } } + all_fts.extend(fts_by_file.into_values()); } - // Sort semantic results by score descending for rank assignment. 
- let mut semantic_results: Vec = sem_by_file.into_values().collect(); - semantic_results.sort_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - // --- FTS lane --- - let fts_raw = store.fts_search(query, top_n * 3).unwrap_or_default(); - - // Group FTS results by file_path, keeping best per file. - let mut fts_by_file: HashMap = HashMap::new(); - for fr in fts_raw { - let (file_path, docid) = match store.get_file_by_id(fr.file_id)? { - Some(f) => (f.path, f.docid), - None => continue, - }; - - let better = match fts_by_file.get(&file_path) { - Some(existing) => fr.score > existing.score, - None => true, - }; - if better { - fts_by_file.insert( - file_path.clone(), - RankedResult { - file_path, - file_id: fr.file_id, - score: fr.score, - heading: None, // FTS doesn't return headings - snippet: fr.snippet, - docid, - }, - ); - } - } - - let mut fts_results: Vec = fts_by_file.into_values().collect(); - fts_results.sort_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - // --- Graph lane --- - // Combine seeds from semantic + FTS (deduplicated by file_path, take higher score) - let combined_seeds: Vec = { - let mut by_file: HashMap = HashMap::new(); - for r in semantic_results.iter().chain(fts_results.iter()) { - match by_file.get(&r.file_path) { - Some(existing) if r.score <= existing.score => {} - _ => { - by_file.insert(r.file_path.clone(), r.clone()); - } - } - } - by_file.into_values().collect() - }; + // Deduplicate across expanded queries (keep best score per file) + let semantic_results = dedup_by_file(all_semantic); + let fts_results = dedup_by_file(all_fts); + // --- Graph lane from combined seeds --- + let combined_seeds = merge_seeds(&semantic_results, &fts_results); let graph_results = - graph::graph_expand(store, &combined_seeds, query, 2, 20).unwrap_or_default(); + graph::graph_expand(config.store, &combined_seeds, query, 2, 20).unwrap_or_default(); - // --- RRF 
Fusion --- + // --- Step 3: RRF Pass 1 (3-lane) --- const RRF_K: usize = 60; - let fused = fusion::rrf_fuse( + let fused_pass1 = fusion::rrf_fuse( &[ - ("semantic", &semantic_results, 1.0), - ("fts", &fts_results, 1.0), - ("graph", &graph_results, 0.8), + ("semantic", &semantic_results, weights.semantic), + ("fts", &fts_results, weights.fts), + ("graph", &graph_results, weights.graph), ], RRF_K, ); + // --- Step 4: Reranker (4th lane) if available --- + let final_fused = if let Some(reranker) = &mut config.reranker { + let mut rerank_results: Vec = Vec::new(); + for candidate in fused_pass1.iter().take(config.rerank_candidates) { + let score = reranker + .rerank_score(query, &candidate.snippet) + .unwrap_or(0.0) as f64; + rerank_results.push(RankedResult { + file_path: candidate.file_path.clone(), + file_id: candidate.file_id, + score, + heading: candidate.heading.clone(), + snippet: candidate.snippet.clone(), + docid: candidate.docid.clone(), + }); + } + rerank_results.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // RRF Pass 2 (4-lane) + fusion::rrf_fuse( + &[ + ("semantic", &semantic_results, weights.semantic), + ("fts", &fts_results, weights.fts), + ("graph", &graph_results, weights.graph), + ("rerank", &rerank_results, weights.rerank), + ], + RRF_K, + ) + } else { + fused_pass1 + }; + // Convert fused results to InternalSearchResult, taking top_n. - let results: Vec = fused + let results: Vec = final_fused .iter() .take(top_n) .map(|f| InternalSearchResult { @@ -174,7 +246,45 @@ pub fn search_internal( }) .collect(); - Ok(SearchOutput { results, fused }) + Ok(SearchOutput { + results, + fused: final_fused, + intent: Some(orchestration.intent), + }) +} + +/// Deduplicate ranked results by file path, keeping the highest score per file. 
+fn dedup_by_file(results: Vec) -> Vec { + let mut by_file: HashMap = HashMap::new(); + for r in results { + let dominated = by_file + .get(&r.file_path) + .is_some_and(|existing| existing.score >= r.score); + if !dominated { + by_file.insert(r.file_path.clone(), r); + } + } + let mut deduped: Vec = by_file.into_values().collect(); + deduped.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + deduped +} + +/// Merge semantic and FTS seed results, keeping the highest score per file. +fn merge_seeds(semantic: &[RankedResult], fts: &[RankedResult]) -> Vec { + let mut by_file: HashMap = HashMap::new(); + for r in semantic.iter().chain(fts.iter()) { + let dominated = by_file + .get(&r.file_path) + .is_some_and(|existing| existing.score >= r.score); + if !dominated { + by_file.insert(r.file_path.clone(), r.clone()); + } + } + by_file.into_values().collect() } /// Run a search query and print results. @@ -188,9 +298,11 @@ pub fn run_search( json: bool, explain: bool, data_dir: &Path, + config: &crate::config::Config, ) -> Result<()> { let models_dir = data_dir.join("models"); - let mut embedder = Embedder::new(&models_dir).context("loading embedder")?; + let mut embedder = + crate::llm::CandleEmbed::new(&models_dir, config).context("loading embedder")?; let db_path = data_dir.join("engraph.db"); let store = Store::open(&db_path).context("opening store")?; @@ -212,7 +324,11 @@ pub fn run_search( let mut out = format_results(&results, json); if explain && !json { - let mut explain_out = String::from("\n--- Explain ---\n"); + let mut explain_out = String::new(); + if let Some(ref intent) = output.intent { + explain_out.push_str(&format!("Intent: {:?}\n\n", intent)); + } + explain_out.push_str("--- Explain ---\n"); for f in output.fused.iter().take(top_n) { explain_out.push_str(&format!("{}\n", f.file_path)); explain_out.push_str(&fusion::format_explain(f)); @@ -235,7 +351,14 @@ pub fn run_status(json: bool, data_dir: &Path) -> 
Result<()> { let model_name = "all-MiniLM-L6-v2"; - let output = format_status(&stats, index_size, model_name, json); + let config = crate::config::Config::load().unwrap_or_default(); + let intelligence = if config.intelligence_enabled() { + "enabled" + } else { + "disabled" + }; + + let output = format_status(&stats, index_size, model_name, intelligence, json); print!("{output}"); Ok(()) } @@ -295,7 +418,13 @@ pub fn format_results(results: &[SearchResult], json: bool) -> String { } /// Format status information for display (pure function, no I/O). -pub fn format_status(stats: &StoreStats, index_size: u64, model_name: &str, json: bool) -> String { +pub fn format_status( + stats: &StoreStats, + index_size: u64, + model_name: &str, + intelligence: &str, + json: bool, +) -> String { let vault = stats.vault_path.as_deref().unwrap_or(""); let last_indexed = stats.last_indexed_at.as_deref().unwrap_or("never"); @@ -308,6 +437,7 @@ pub fn format_status(stats: &StoreStats, index_size: u64, model_name: &str, json "last_indexed": last_indexed, "index_size": index_size, "model": model_name, + "intelligence": intelligence, }); if let (Some(edges), Some(wl), Some(mn)) = (stats.edge_count, stats.wikilink_count, stats.mention_count) @@ -336,11 +466,13 @@ pub fn format_status(stats: &StoreStats, index_size: u64, model_name: &str, json "Tombstones: {} (pending cleanup)\n\ Last index: {}\n\ Index size: {}\n\ - Model: {}\n", + Model: {}\n\ + Intelligence: {}\n", stats.tombstone_count, last_indexed, format_bytes(index_size), model_name, + intelligence, )); out } @@ -451,7 +583,7 @@ mod tests { wikilink_count: None, mention_count: None, }; - let output = format_status(&stats, 2_516_582, "all-MiniLM-L6-v2", false); + let output = format_status(&stats, 2_516_582, "all-MiniLM-L6-v2", "disabled", false); assert!(output.contains("/path/to/vault"), "missing vault path"); assert!(output.contains("42"), "missing file count"); @@ -460,6 +592,7 @@ mod tests { assert!(output.contains("2026-03-19 
14:30:00"), "missing last index"); assert!(output.contains("2.4 MB"), "missing index size"); assert!(output.contains("all-MiniLM-L6-v2"), "missing model"); + assert!(output.contains("disabled"), "missing intelligence"); } #[test] @@ -474,7 +607,7 @@ mod tests { wikilink_count: None, mention_count: None, }; - let output = format_status(&stats, 2_516_582, "all-MiniLM-L6-v2", true); + let output = format_status(&stats, 2_516_582, "all-MiniLM-L6-v2", "enabled", true); let parsed: serde_json::Value = serde_json::from_str(&output).unwrap(); assert_eq!(parsed["vault"], "/path/to/vault"); @@ -484,6 +617,7 @@ mod tests { assert_eq!(parsed["last_indexed"], "2026-03-19 14:30:00"); assert_eq!(parsed["index_size"], 2_516_582); assert_eq!(parsed["model"], "all-MiniLM-L6-v2"); + assert_eq!(parsed["intelligence"], "enabled"); } #[test] @@ -505,4 +639,112 @@ mod tests { assert_eq!(format_bytes(1024 * 1024), "1.0 MB"); assert_eq!(format_bytes(2_516_582), "2.4 MB"); } + + #[test] + fn test_cache_key_deterministic() { + let key1 = super::orchestration_cache_key("how does auth work"); + let key2 = super::orchestration_cache_key("how does auth work"); + assert_eq!(key1, key2); + + let key3 = super::orchestration_cache_key("different query"); + assert_ne!(key1, key3); + } + + #[test] + fn test_search_output_has_intent() { + let output = SearchOutput { + results: vec![], + fused: vec![], + intent: Some(crate::llm::QueryIntent::Conceptual), + }; + assert_eq!(output.intent, Some(crate::llm::QueryIntent::Conceptual)); + } + + #[test] + fn test_search_output_intent_none() { + let output = SearchOutput { + results: vec![], + fused: vec![], + intent: None, + }; + assert!(output.intent.is_none()); + } + + #[test] + fn test_dedup_by_file_keeps_best() { + let results = vec![ + RankedResult { + file_path: "a.md".to_string(), + file_id: 1, + score: 0.5, + heading: None, + snippet: "low".to_string(), + docid: None, + }, + RankedResult { + file_path: "a.md".to_string(), + file_id: 1, + score: 0.9, + 
heading: None, + snippet: "high".to_string(), + docid: None, + }, + RankedResult { + file_path: "b.md".to_string(), + file_id: 2, + score: 0.7, + heading: None, + snippet: "only".to_string(), + docid: None, + }, + ]; + let deduped = dedup_by_file(results); + assert_eq!(deduped.len(), 2); + // Sorted by score descending + assert_eq!(deduped[0].file_path, "a.md"); + assert!((deduped[0].score - 0.9).abs() < 1e-10); + assert_eq!(deduped[0].snippet, "high"); + assert_eq!(deduped[1].file_path, "b.md"); + } + + #[test] + fn test_dedup_by_file_empty() { + let deduped = dedup_by_file(vec![]); + assert!(deduped.is_empty()); + } + + #[test] + fn test_merge_seeds_deduplicates() { + let semantic = vec![RankedResult { + file_path: "shared.md".to_string(), + file_id: 1, + score: 0.8, + heading: None, + snippet: "sem".to_string(), + docid: None, + }]; + let fts = vec![ + RankedResult { + file_path: "shared.md".to_string(), + file_id: 1, + score: 0.9, + heading: None, + snippet: "fts".to_string(), + docid: None, + }, + RankedResult { + file_path: "fts_only.md".to_string(), + file_id: 2, + score: 0.6, + heading: None, + snippet: "fts only".to_string(), + docid: None, + }, + ]; + let merged = merge_seeds(&semantic, &fts); + assert_eq!(merged.len(), 2); + // "shared.md" should have the FTS score (0.9 > 0.8) + let shared = merged.iter().find(|r| r.file_path == "shared.md").unwrap(); + assert!((shared.score - 0.9).abs() < 1e-10); + } } diff --git a/src/serve.rs b/src/serve.rs index 81c9cd0..48ba85b 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -13,7 +13,7 @@ use tokio::sync::Mutex; use crate::config::Config; use crate::context::{self, ContextParams}; -use crate::embedder::Embedder; +use crate::llm::{EmbedModel, OrchestratorModel, RerankModel}; use crate::profile::VaultProfile; use crate::search; use crate::store::Store; @@ -127,10 +127,16 @@ pub struct UnarchiveParams { #[derive(Clone)] pub struct EngraphServer { store: Arc>, - embedder: Arc>, + embedder: Arc>>, vault_path: Arc, 
profile: Arc>, tool_router: ToolRouter, + /// Query expansion orchestrator (None when intelligence is disabled or failed to load). + #[allow(dead_code)] + orchestrator: Option>>>, + /// Result reranker (None when intelligence is disabled or failed to load). + #[allow(dead_code)] + reranker: Option>>>, } fn mcp_err(e: &anyhow::Error) -> McpError { @@ -162,7 +168,7 @@ impl EngraphServer { let top_n = params.0.top_n.unwrap_or(10); let store = self.store.lock().await; let mut embedder = self.embedder.lock().await; - let output = search::search_internal(¶ms.0.query, top_n, &store, &mut embedder) + let output = search::search_internal(¶ms.0.query, top_n, &store, &mut *embedder) .map_err(|e| mcp_err(&e))?; to_json_result(&output.results) } @@ -268,7 +274,7 @@ impl EngraphServer { profile: self.profile.as_ref().as_ref(), }; let bundle = - context::context_topic_with_search(&ctx, ¶ms.0.topic, budget, &mut embedder) + context::context_topic_with_search(&ctx, ¶ms.0.topic, budget, &mut *embedder) .map_err(|e| mcp_err(&e))?; to_json_result(&bundle) } @@ -291,7 +297,7 @@ impl EngraphServer { let result = crate::writer::create_note( input, &store, - &mut embedder, + &mut *embedder, &self.vault_path, self.profile.as_ref().as_ref(), ) @@ -311,7 +317,7 @@ impl EngraphServer { content: params.0.content, modified_by: "claude-code".into(), }; - let result = crate::writer::append_to_note(input, &store, &mut embedder, &self.vault_path) + let result = crate::writer::append_to_note(input, &store, &mut *embedder, &self.vault_path) .map_err(|e| mcp_err(&e))?; to_json_result(&result) } @@ -382,7 +388,7 @@ impl EngraphServer { let store = self.store.lock().await; let mut embedder = self.embedder.lock().await; let result = - crate::writer::unarchive_note(¶ms.0.file, &store, &mut embedder, &self.vault_path) + crate::writer::unarchive_note(¶ms.0.file, &store, &mut *embedder, &self.vault_path) .map_err(|e| mcp_err(&e))?; to_json_result(&result) } @@ -409,7 +415,8 @@ pub async fn 
run_serve(data_dir: &Path) -> Result<()> { let models_dir = data_dir.join("models"); let store = Store::open(&db_path)?; - let embedder = Embedder::new(&models_dir)?; + let config = Config::load()?; + let embedder = crate::llm::CandleEmbed::new(&models_dir, &config)?; let vault_path_str = store.get_meta("vault_path")?.ok_or_else(|| { anyhow::anyhow!("No vault path in index. Run 'engraph index ' first.") @@ -431,13 +438,44 @@ pub async fn run_serve(data_dir: &Path) -> Result<()> { let profile = Config::load_vault_profile().ok().flatten(); + // Load intelligence models if enabled + let orchestrator: Option>>> = + if config.intelligence_enabled() { + match crate::llm::CandleOrchestrator::new(&models_dir, &config) { + Ok(orch) => Some(Arc::new(Mutex::new( + Box::new(orch) as Box + ))), + Err(e) => { + tracing::warn!("failed to load orchestrator: {e}, intelligence disabled"); + None + } + } + } else { + None + }; + + let reranker: Option>>> = if config.intelligence_enabled() + { + match crate::llm::CandleRerank::new(&models_dir, &config) { + Ok(rerank) => Some(Arc::new(Mutex::new( + Box::new(rerank) as Box + ))), + Err(e) => { + tracing::warn!("failed to load reranker: {e}, reranking disabled"); + None + } + } + } else { + None + }; + let store_arc = Arc::new(Mutex::new(store)); - let embedder_arc = Arc::new(Mutex::new(embedder)); + let embedder_arc: Arc>> = + Arc::new(Mutex::new(Box::new(embedder) as Box)); let vault_path_arc = Arc::new(vault_path); let profile_arc = Arc::new(profile); // Start file watcher for real-time index updates - let config = Config::load()?; let mut exclude = config.exclude.clone(); if let Some(ref prof) = *profile_arc && let Some(ref archive) = prof.structure.folders.archive @@ -462,6 +500,8 @@ pub async fn run_serve(data_dir: &Path) -> Result<()> { vault_path: vault_path_arc, profile: profile_arc, tool_router: EngraphServer::tool_router(), + orchestrator, + reranker, }; eprintln!("engraph MCP server starting..."); diff --git a/src/store.rs 
b/src/store.rs index 9425f66..4485bf3 100644 --- a/src/store.rs +++ b/src/store.rs @@ -1,5 +1,5 @@ use anyhow::{Context, Result}; -use rusqlite::{Connection, params}; +use rusqlite::{Connection, OptionalExtension, params}; use std::collections::HashSet; use std::path::Path; @@ -102,6 +102,13 @@ CREATE TABLE IF NOT EXISTS tombstones ( vector_id INTEGER UNIQUE NOT NULL, created_at TEXT NOT NULL ); + +CREATE TABLE IF NOT EXISTS llm_cache ( + query_hash TEXT PRIMARY KEY, + result TEXT NOT NULL, + model TEXT NOT NULL, + created_at TEXT NOT NULL +); "#; pub struct Store { @@ -134,7 +141,7 @@ impl Store { .context("failed to initialize schema")?; self.migrate()?; self.ensure_fts_table()?; - crate::vecstore::init_vec_table(&self.conn)?; + crate::vecstore::init_vec_table(&self.conn, 256)?; self.migrate_vectors_to_vec0()?; Ok(()) } @@ -292,6 +299,29 @@ impl Store { } } + // ── LLM Cache ─────────────────────────────────────────────── + + /// Cache an LLM orchestration result by query hash. + pub fn set_llm_cache(&self, query_hash: &str, result: &str, model: &str) -> Result<()> { + self.conn.execute( + "INSERT OR REPLACE INTO llm_cache (query_hash, result, model, created_at) + VALUES (?1, ?2, ?3, datetime('now'))", + params![query_hash, result, model], + )?; + Ok(()) + } + + /// Retrieve a cached LLM result by query hash. + pub fn get_llm_cache(&self, query_hash: &str) -> Result> { + let mut stmt = self + .conn + .prepare("SELECT result FROM llm_cache WHERE query_hash = ?1")?; + let result = stmt + .query_row(params![query_hash], |row| row.get::<_, String>(0)) + .optional()?; + Ok(result) + } + // ── Files ─────────────────────────────────────────────────── pub fn insert_file( @@ -1124,6 +1154,25 @@ impl Store { crate::vecstore::clear_vec(&self.conn) } + /// Check if the stored embedding dimension differs from the model's dimension. + pub fn has_dimension_mismatch(&self, model_dim: usize) -> Result { + match self.get_meta("embedding_dim")? 
{ + Some(stored) => { + let stored_dim: usize = stored.parse().unwrap_or(0); + Ok(stored_dim != model_dim) + } + None => Ok(false), // First run, no stored dimension + } + } + + /// Drop the vec table and all chunk records. Used during dimension migration. + pub fn reset_for_reindex(&self, new_dim: usize) -> Result<()> { + self.conn.execute("DROP TABLE IF EXISTS chunks_vec", [])?; + crate::vecstore::init_vec_table(&self.conn, new_dim)?; + self.conn.execute("DELETE FROM chunks", [])?; + Ok(()) + } + // ── Transactions ──────────────────────────────────────────── pub fn begin_transaction(&self) -> Result<()> { @@ -2209,7 +2258,7 @@ mod tests { #[test] fn test_store_vec_roundtrip() { let store = Store::open_memory().unwrap(); - let vector: Vec = (0..384).map(|i| (i as f32) / 384.0).collect(); + let vector: Vec = (0..256).map(|i| (i as f32) / 256.0).collect(); store.insert_vec(0, &vector).unwrap(); let results = store @@ -2227,7 +2276,7 @@ mod tests { let file_id = store .insert_file("test.md", "hash123", 0, &[], "abc123", None) .unwrap(); - let vector: Vec = (0..384).map(|i| (i as f32) / 384.0).collect(); + let vector: Vec = (0..256).map(|i| (i as f32) / 256.0).collect(); store .insert_chunk_with_vector(file_id, "heading", "snippet", 0, 100, &vector) .unwrap(); @@ -2524,4 +2573,57 @@ mod tests { assert_eq!(corrections[0].file_path, "notes/second.md"); assert_eq!(corrections[1].file_path, "notes/first.md"); } + + // ── LLM cache tests ──────────────────────────────────────── + + #[test] + fn test_llm_cache_roundtrip() { + let store = Store::open_memory().unwrap(); + store + .set_llm_cache("abc123", r#"{"intent":"exact"}"#, "qwen3-0.6B") + .unwrap(); + let result = store.get_llm_cache("abc123").unwrap(); + assert_eq!(result, Some(r#"{"intent":"exact"}"#.to_string())); + } + + #[test] + fn test_llm_cache_miss() { + let store = Store::open_memory().unwrap(); + let result = store.get_llm_cache("nonexistent").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn 
test_llm_cache_overwrite() { + let store = Store::open_memory().unwrap(); + store.set_llm_cache("key1", "old", "model1").unwrap(); + store.set_llm_cache("key1", "new", "model1").unwrap(); + let result = store.get_llm_cache("key1").unwrap(); + assert_eq!(result, Some("new".to_string())); + } + + #[test] + fn test_embedding_dim_meta() { + let store = Store::open_memory().unwrap(); + assert!(store.get_meta("embedding_dim").unwrap().is_none()); + store.set_meta("embedding_dim", "256").unwrap(); + assert_eq!( + store.get_meta("embedding_dim").unwrap(), + Some("256".to_string()) + ); + } + + #[test] + fn test_detect_dimension_mismatch() { + let store = Store::open_memory().unwrap(); + store.set_meta("embedding_dim", "384").unwrap(); + assert!(store.has_dimension_mismatch(256).unwrap()); + assert!(!store.has_dimension_mismatch(384).unwrap()); + } + + #[test] + fn test_no_mismatch_when_unset() { + let store = Store::open_memory().unwrap(); + assert!(!store.has_dimension_mismatch(256).unwrap()); + } } diff --git a/src/vecstore.rs b/src/vecstore.rs index a47065c..7963f10 100644 --- a/src/vecstore.rs +++ b/src/vecstore.rs @@ -18,11 +18,13 @@ pub fn init_sqlite_vec() { } /// Create the `chunks_vec` virtual table if it doesn't already exist. 
-pub fn init_vec_table(conn: &Connection) -> Result<()> { +pub fn init_vec_table(conn: &Connection, dim: usize) -> Result<()> { conn.execute( - "CREATE VIRTUAL TABLE IF NOT EXISTS chunks_vec USING vec0( - embedding float[384] distance_metric=cosine - )", + &format!( + "CREATE VIRTUAL TABLE IF NOT EXISTS chunks_vec USING vec0( + embedding float[{dim}] distance_metric=cosine + )" + ), [], )?; Ok(()) @@ -103,13 +105,13 @@ mod tests { fn setup_conn() -> Connection { init_sqlite_vec(); let conn = Connection::open_in_memory().unwrap(); - init_vec_table(&conn).unwrap(); + init_vec_table(&conn, 384).unwrap(); conn } - fn random_vector(seed: u64) -> Vec { + fn random_vector(seed: u64, dim: usize) -> Vec { let mut state = seed.wrapping_mul(6364136223846793005).wrapping_add(1); - (0..384) + (0..dim) .map(|_| { state = state.wrapping_mul(6364136223846793005).wrapping_add(1); ((state >> 33) as f32) / (u32::MAX as f32) * 2.0 - 1.0 @@ -134,7 +136,7 @@ mod tests { #[test] fn test_insert_and_search() { let conn = setup_conn(); - let vectors: Vec> = (0..10).map(random_vector).collect(); + let vectors: Vec> = (0..10).map(|i| random_vector(i, 384)).collect(); for (i, v) in vectors.iter().enumerate() { insert_vec(&conn, i as u64, v).unwrap(); @@ -156,7 +158,7 @@ mod tests { #[test] fn test_search_with_tombstones() { let conn = setup_conn(); - let vectors: Vec> = (0..5).map(|i| random_vector(i + 100)).collect(); + let vectors: Vec> = (0..5).map(|i| random_vector(i + 100, 384)).collect(); for (i, v) in vectors.iter().enumerate() { insert_vec(&conn, i as u64, v).unwrap(); @@ -174,7 +176,7 @@ mod tests { #[test] fn test_delete_vec() { let conn = setup_conn(); - insert_vec(&conn, 1, &random_vector(42)).unwrap(); + insert_vec(&conn, 1, &random_vector(42, 384)).unwrap(); let count_before: i64 = conn .query_row("SELECT count(*) FROM chunks_vec", [], |row| row.get(0)) @@ -192,8 +194,31 @@ mod tests { #[test] fn test_empty_search() { let conn = setup_conn(); - let query = random_vector(999); + 
let query = random_vector(999, 384); let results = search_vec(&conn, &query, 5, &HashSet::new()).unwrap(); assert!(results.is_empty(), "empty table should return no results"); } + + #[test] + fn test_init_vec_table_custom_dim() { + init_sqlite_vec(); + let conn = Connection::open_in_memory().unwrap(); + init_vec_table(&conn, 256).unwrap(); + + let count: i64 = conn + .query_row( + "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='chunks_vec'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 1); + + // Insert and search with 256-dim vector + let vec256: Vec = (0..256).map(|i| (i as f32) / 256.0).collect(); + insert_vec(&conn, 1, &vec256).unwrap(); + let results = search_vec(&conn, &vec256, 1, &HashSet::new()).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].0, 1); + } } diff --git a/src/watcher.rs b/src/watcher.rs index 915f3e1..298cc91 100644 --- a/src/watcher.rs +++ b/src/watcher.rs @@ -10,8 +10,8 @@ use tokio::sync::mpsc; use tokio::sync::oneshot; use crate::config::Config; -use crate::embedder::Embedder; use crate::indexer; +use crate::llm::EmbedModel; use crate::placement; use crate::profile::VaultProfile; use crate::store::Store; @@ -22,7 +22,7 @@ use crate::store::Store; /// real-time file changes. 
pub fn start_watcher( store: Arc>, - embedder: Arc>, + embedder: Arc>>, vault_path: Arc, profile: Arc>, config: Config, @@ -49,7 +49,7 @@ pub fn start_watcher( &vault_clone, &config_clone, &store_lock, - &mut embedder_lock, + &mut *embedder_lock, false, ) { tracing::warn!("Startup reconciliation failed: {:#}", e); @@ -278,7 +278,7 @@ fn detect_moves(events: &mut Vec, store: &Store, vault_path: &Path) pub async fn run_consumer( mut rx: mpsc::Receiver>, store: Arc>, - embedder: Arc>, + embedder: Arc>>, vault_path: Arc, _profile: Arc>, config: Config, @@ -333,7 +333,7 @@ pub async fn run_consumer( &content, &content_hash, &store_guard, - &mut embedder_guard, + &mut *embedder_guard, &vault_path, &config, ) { @@ -594,7 +594,7 @@ pub async fn run_consumer( &vault_path, &config, &store_guard, - &mut embedder_guard, + &mut *embedder_guard, false, ) { Ok(result) => { diff --git a/src/writer.rs b/src/writer.rs index f431cc2..556df4c 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -7,9 +7,9 @@ use time::OffsetDateTime; use crate::chunker::{chunk_markdown, split_oversized_chunks}; use crate::docid::generate_docid; -use crate::embedder::Embedder; use crate::indexer::build_edges_for_file; use crate::links; +use crate::llm::EmbedModel; use crate::placement::{self, PlacementHints}; use crate::profile::VaultProfile; use crate::store::Store; @@ -197,7 +197,7 @@ fn file_mtime(path: &Path) -> Result { type ChunkData = (String, String, Vec, i64); // (heading, snippet, vector, token_count) /// Chunk content, embed, and return pre-computed data ready for store insertion. 
-fn precompute_chunks(content: &str, embedder: &mut Embedder) -> Result> { +fn precompute_chunks(content: &str, embedder: &mut impl EmbedModel) -> Result> { let parsed = chunk_markdown(content); let chunks = split_oversized_chunks(parsed.chunks, &|s| s.split_whitespace().count(), 512, 50); @@ -258,7 +258,7 @@ pub fn cleanup_temp_files(vault_path: &Path) -> Result { pub fn create_note( input: CreateNoteInput, store: &Store, - embedder: &mut Embedder, + embedder: &mut impl EmbedModel, vault_path: &Path, profile: Option<&VaultProfile>, ) -> Result { @@ -494,7 +494,7 @@ pub fn create_note( pub fn append_to_note( input: AppendInput, store: &Store, - embedder: &mut Embedder, + embedder: &mut impl EmbedModel, vault_path: &Path, ) -> Result { // Step 1: Resolve file @@ -870,7 +870,7 @@ pub fn archive_note( pub fn unarchive_note( file: &str, store: &Store, - embedder: &mut Embedder, + embedder: &mut impl EmbedModel, vault_path: &Path, ) -> Result { // Resolve — the file may not be in the index (archived notes are excluded).