diff --git a/.gitignore b/.gitignore index a29b902..5a0f0d4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.swp *.swo .DS_Store +.worktrees/ diff --git a/CLAUDE.md b/CLAUDE.md index cb094f6..44697a5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,28 +4,35 @@ Local semantic search CLI for Obsidian vaults. Rust, MIT licensed. ## Architecture -Single binary with 7 modules behind a lib crate: - -- `config.rs` — loads `~/.engraph/config.toml`, merges CLI args, provides `data_dir()` -- `chunker.rs` — splits markdown by `##` headings, strips YAML frontmatter, extracts tags. `split_oversized_chunks()` handles token-aware sub-splitting with overlap -- `embedder.rs` — downloads and runs `all-MiniLM-L6-v2` ONNX model (384-dim). SHA256-verified on download. Uses `ort` for inference, `tokenizers` for tokenization -- `store.rs` — SQLite persistence. Tables: `meta`, `files`, `chunks` (with vector BLOBs), `tombstones`. Handles incremental diffing via content hashes +Single binary with 11 modules behind a lib crate: + +- `config.rs` — loads `~/.engraph/config.toml` and `vault.toml`, merges CLI args, provides `data_dir()` +- `chunker.rs` — smart chunking with break-point scoring algorithm. Finds optimal split points considering headings, code fences, blank lines, and thematic breaks. `split_oversized_chunks()` handles token-aware secondary splitting with overlap +- `docid.rs` — deterministic 6-char hex IDs for files (SHA-256 of path, truncated). Shown in search results for quick reference +- `embedder.rs` — downloads and runs `all-MiniLM-L6-v2` ONNX model (384-dim). SHA256-verified on download. Uses `ort` for inference, `tokenizers` for tokenization. Implements `ModelBackend` trait +- `model.rs` — pluggable `ModelBackend` trait, model registry, and `parse_model_spec()`. Enables future model swapping without changing consumer code +- `fts.rs` — FTS5 full-text search support. Re-exports `FtsResult` from store. BM25-ranked keyword search +- `fusion.rs` — Reciprocal Rank Fusion (RRF) engine. Merges semantic + FTS5 results. Supports lane weighting and `--explain` output +- `profile.rs` — vault profile detection. Auto-detects PARA/Folders/Flat structure, vault type (Obsidian/Logseq/Plain), wikilinks, frontmatter, tags. Writes/loads `vault.toml` +- `store.rs` — SQLite persistence. Tables: `meta`, `files` (with docid), `chunks` (with vector BLOBs), `chunks_fts` (FTS5 virtual table), `tombstones`. Handles incremental diffing via content hashes - `hnsw.rs` — thin wrapper around `hnsw_rs`. **Important:** `hnsw_rs` does not support inserting after `load_hnsw()`. The index is rebuilt from vectors stored in SQLite on every index run -- `indexer.rs` — orchestrates vault walking (via `ignore` crate for `.gitignore` support), diffing, chunking, embedding (Rayon for parallel chunking, serial embedding since `Embedder` is not `Send`), and serial writes to store + HNSW -- `search.rs` — embeds query, searches HNSW with tombstone filtering, formats results (human + JSON). Also handles `status` formatting +- `indexer.rs` — orchestrates vault walking (via `ignore` crate for `.gitignore` support), diffing, chunking, embedding (Rayon for parallel chunking, serial embedding since `Embedder` is not `Send`), and serial writes to store + HNSW + FTS5 -`main.rs` is a thin clap CLI that wires the modules together. +`main.rs` is a thin clap CLI that wires the modules together. Subcommands: `index`, `search` (with `--explain`), `status`, `clear`, `init`, `configure`, `models`. ## Key patterns -- **Incremental indexing:** `diff_vault()` compares file content hashes in SQLite against disk. Changed files have their old chunks deleted (cascade), then are re-embedded as new +- **Hybrid search:** Queries run through two lanes — semantic (HNSW embeddings) and keyword (FTS5 BM25). Results are fused via Reciprocal Rank Fusion (RRF) with configurable lane weights +- **Smart chunking:** Break-point scoring algorithm assigns scores to potential split points (headings 50-100, code fences 80, thematic breaks 60, blank lines 20). Chunks split at the highest-scored break point near the token target. Code fence protection prevents splitting inside code blocks +- **Incremental indexing:** `diff_vault()` compares file content hashes in SQLite against disk. Changed files have their old chunks deleted (cascade), then are re-embedded as new. FTS5 entries are cleaned up alongside vector entries - **HNSW rebuild on every run:** Vectors are stored as BLOBs in the `chunks` table. After SQLite is updated, the full HNSW index is rebuilt from `store.get_all_vectors()`. This is necessary because `hnsw_rs` doesn't support append-after-load -- **Vector IDs:** Assigned sequentially, stored in both SQLite and HNSW. `next_vector_id` is derived from `MAX(vector_id)` in SQLite -- **Tombstones:** Exist in the schema but are largely unused now that we rebuild HNSW each run. Kept for future use if switching to a vector store that supports deletion +- **Docids:** Each file gets a deterministic 6-char hex ID (SHA-256 of relative path). Displayed in search results for quick reference +- **Vault profiles:** `engraph init` auto-detects vault structure and writes `vault.toml` +- **Pluggable models:** `ModelBackend` trait enables future model swapping. Current implementation uses ONNX all-MiniLM-L6-v2 ## Data directory -`~/.engraph/` — hardcoded via `Config::data_dir()` (uses `dirs::home_dir()`). Contains `engraph.db` (SQLite), `hnsw/` (index files), `models/` (ONNX model + tokenizer). +`~/.engraph/` — hardcoded via `Config::data_dir()` (uses `dirs::home_dir()`). Contains `engraph.db` (SQLite with FTS5), `hnsw/` (index files), `models/` (ONNX model + tokenizer), `vault.toml` (vault profile), `config.toml` (user config). Single vault only. Re-indexing a different vault path triggers a confirmation prompt. @@ -35,10 +42,11 @@ Single vault only. Re-indexing a different vault path triggers a confirmation pr - `hnsw_rs` (0.3) — pure Rust HNSW. `Box::leak` used in `load()` to satisfy `'static` lifetime on the loaded index. Read-only after load - `tokenizers` (0.22) — HuggingFace tokenizer. Needs `fancy-regex` feature - `ignore` (0.4) — vault walking with automatic `.gitignore` support +- `rusqlite` (0.32) — bundled SQLite with FTS5 support ## Testing -- Unit tests in each module (`cargo test --lib`) — 44 tests, no network required +- Unit tests in each module (`cargo test --lib`) — 91 tests, no network required - 1 ignored smoke test (`test_embed_smoke`) — downloads ONNX model, verifies embedding - Integration tests (`cargo test --test integration -- --ignored`) — 8 tests, require model download. Use `tempfile` for isolated data dirs diff --git a/Cargo.lock b/Cargo.lock index c5b2675..60cec92 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -560,7 +560,7 @@ checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "engraph" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "clap", diff --git a/Cargo.toml b/Cargo.toml index 0b0b40c..63f03c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "engraph" -version = "0.1.0" +version = "0.2.0" edition = "2024" description = "Local semantic search for Obsidian vaults" license = "MIT" diff --git a/src/chunker.rs b/src/chunker.rs index 72c9523..833b7e8 100644 --- a/src/chunker.rs +++ b/src/chunker.rs @@ -1,6 +1,6 @@ /// Represents a single semantic chunk extracted from a markdown file. pub struct Chunk { - /// The `## ` heading line, if any. + /// The heading line (any `#` level), if any. pub heading: Option, /// Full chunk text (without frontmatter). pub text: String, @@ -12,65 +12,324 @@ pub struct Chunk { pub struct ParsedMarkdown { /// Tags extracted from YAML frontmatter. pub tags: Vec, - /// Semantic chunks split on `## ` headings. + /// Semantic chunks produced by smart break-point scoring. pub chunks: Vec, } -/// Parse markdown content into frontmatter tags and heading-based chunks. +/// A scored candidate position where a chunk boundary could be placed. +pub struct BreakPoint { + pub byte_offset: usize, + pub line_number: usize, + pub score: u32, + pub inside_code_fence: bool, +} + +/// Scan content line by line and assign break-point scores. /// -/// 1. Strip YAML frontmatter (between `---` at start), parse `tags` if present. -/// 2. Split on lines starting with `## ` — deeper headings stay in parent chunk. -/// 3. Content before first `## ` becomes a chunk with `heading: None`. -/// 4. Skip empty (whitespace-only) chunks. -/// 5. Snippet = first 200 chars of text, `"..."` appended if truncated. -pub fn chunk_markdown(content: &str) -> ParsedMarkdown { - let (tags, body) = parse_frontmatter(content); +/// Scoring rules: +/// - `# ` heading: 100 +/// - `## ` heading: 90 +/// - `### ` heading: 80 +/// - `#### ` heading: 70 +/// - `##### ` heading: 60 +/// - `###### ` heading: 50 +/// - `---`/`***`/`___` (thematic breaks): 60 +/// - Code fence boundaries (`` ``` ``): 80 +/// - Empty lines: 20 +/// - List items (`- `, `* `, digit prefix): 5 +/// - Other non-empty lines: 1 (excluded from results) +pub fn find_break_points(content: &str) -> Vec { + let mut break_points = Vec::new(); + let mut inside_code_fence = false; + let mut byte_offset = 0; - let mut chunks = Vec::new(); - let mut current_heading: Option = None; - let mut current_lines: Vec<&str> = Vec::new(); - - for line in body.lines() { - if line.starts_with("## ") { - // Flush previous chunk - flush_chunk(&mut chunks, current_heading.take(), ¤t_lines); - current_heading = Some(line.to_string()); - current_lines.clear(); + for (line_number, line) in content.lines().enumerate() { + let trimmed = line.trim(); + let score = if trimmed.starts_with("```") { + // Toggle fence state; the fence boundary itself is NOT "inside" + inside_code_fence = !inside_code_fence; + // Mark as not inside — fence boundaries are valid break points + let bp_inside = false; + break_points.push(BreakPoint { + byte_offset, + line_number, + score: 80, + inside_code_fence: bp_inside, + }); + byte_offset += line.len() + + if byte_offset + line.len() < content.len() { + 1 + } else { + 0 + }; + continue; + } else if inside_code_fence { + // Lines inside code fences: push with inside_code_fence = true + // so callers can inspect the field; smart_chunk filters them out. + break_points.push(BreakPoint { + byte_offset, + line_number, + score: 1, + inside_code_fence: true, + }); + byte_offset += line.len() + + if byte_offset + line.len() < content.len() { + 1 + } else { + 0 + }; + continue; + } else if trimmed.starts_with("# ") && !trimmed.starts_with("## ") { + 100 + } else if trimmed.starts_with("## ") && !trimmed.starts_with("### ") { + 90 + } else if trimmed.starts_with("### ") && !trimmed.starts_with("#### ") { + 80 + } else if trimmed.starts_with("#### ") && !trimmed.starts_with("##### ") { + 70 + } else if trimmed.starts_with("##### ") && !trimmed.starts_with("###### ") { + 60 + } else if trimmed.starts_with("###### ") { + 50 + } else if is_thematic_break(trimmed) { + 60 + } else if trimmed.is_empty() { + 20 + } else if is_list_item(trimmed) { + 5 } else { - current_lines.push(line); + 1 + }; + + if score > 1 { + break_points.push(BreakPoint { + byte_offset, + line_number, + score, + inside_code_fence, + }); } + + byte_offset += line.len() + + if byte_offset + line.len() < content.len() { + 1 + } else { + 0 + }; } - // Flush last chunk - flush_chunk(&mut chunks, current_heading, ¤t_lines); - ParsedMarkdown { tags, chunks } + break_points } -fn flush_chunk(chunks: &mut Vec, heading: Option, lines: &[&str]) { - let text = lines.join("\n").trim().to_string(); - if text.is_empty() && heading.is_none() { - return; +/// Check if a line is a thematic break (`---`, `***`, `___` with 3+ chars, optional spaces). +fn is_thematic_break(trimmed: &str) -> bool { + if trimmed.len() < 3 { + return false; } - // Build full text including the heading line - let full_text = match &heading { - Some(h) => { - if text.is_empty() { - h.clone() - } else { - format!("{h}\n{text}") + let chars: Vec = trimmed.chars().collect(); + let first = chars[0]; + if first != '-' && first != '*' && first != '_' { + return false; + } + chars.iter().all(|&c| c == first || c == ' ') + && chars.iter().filter(|&&c| c == first).count() >= 3 +} + +/// Check if a line starts as a list item. +fn is_list_item(trimmed: &str) -> bool { + if trimmed.starts_with("- ") || trimmed.starts_with("* ") { + return true; + } + // Check for ordered list: digit(s) followed by `. ` or `) ` + let mut chars = trimmed.chars(); + if let Some(first) = chars.next() + && first.is_ascii_digit() + { + for c in chars { + if c.is_ascii_digit() { + continue; + } + if c == '.' || c == ')' { + return true; + } + break; + } + } + false +} + +/// Approximate token count: ~4 chars per token. +fn approx_tokens(text: &str) -> usize { + text.len().div_ceil(4) +} + +/// Snap a byte offset to the nearest valid UTF-8 char boundary (forward). +fn snap_to_char_boundary(s: &str, offset: usize) -> usize { + let offset = offset.min(s.len()); + let mut pos = offset; + while pos < s.len() && !s.is_char_boundary(pos) { + pos += 1; + } + pos +} + +/// Extract the first heading line from text (any `#` level). +fn extract_heading(text: &str) -> Option { + for line in text.lines() { + let trimmed = line.trim(); + if trimmed.starts_with('#') && trimmed.contains(' ') { + return Some(line.to_string()); + } + } + None +} + +/// Smart chunk splitting using scored break points. +/// +/// - `target_tokens`: desired chunk size in approximate tokens (~4 chars/token) +/// - `overlap_pct`: percentage of target_tokens to overlap between chunks (e.g. 15 = 15%) +/// +/// Never splits inside code fences. Finds the best break point near the token +/// target using a weighted score that considers both inherent score and distance. +pub fn smart_chunk(content: &str, target_tokens: usize, overlap_pct: usize) -> Vec { + if content.trim().is_empty() { + return Vec::new(); + } + + let break_points = find_break_points(content); + let target_chars = target_tokens * 4; + let overlap_chars = (target_chars * overlap_pct) / 100; + + // If the content fits in one chunk, return it as-is + if approx_tokens(content) <= target_tokens { + let heading = extract_heading(content); + let snippet = make_snippet(content.trim()); + return vec![Chunk { + heading, + text: content.trim().to_string(), + snippet, + }]; + } + + let mut chunks = Vec::new(); + let mut start_offset = 0; + + while start_offset < content.len() { + start_offset = snap_to_char_boundary(content, start_offset); + if start_offset >= content.len() { + break; + } + let remaining = &content[start_offset..]; + if remaining.trim().is_empty() { + break; + } + + // If remaining content fits in one chunk, take it all + if approx_tokens(remaining) <= target_tokens { + let text = remaining.trim().to_string(); + if !text.is_empty() { + let heading = extract_heading(&text); + let snippet = make_snippet(&text); + chunks.push(Chunk { + heading, + text, + snippet, + }); + } + break; + } + + // Find the ideal cut point: target_chars from start_offset + let ideal_end = start_offset + target_chars; + + // Find the best break point near ideal_end + // Filter to break points that are: + // 1. After start_offset + // 2. Not inside code fences + // 3. Within a reasonable range of ideal_end + let best_bp = break_points + .iter() + .filter(|bp| { + bp.byte_offset > start_offset + && !bp.inside_code_fence + && bp.byte_offset <= start_offset + target_chars * 2 + }) + .max_by(|a, b| { + let score_a = weighted_score(a, ideal_end); + let score_b = weighted_score(b, ideal_end); + score_a + .partial_cmp(&score_b) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let cut_offset = match best_bp { + Some(bp) => bp.byte_offset, + None => { + // No good break point found; cut at target + let cut = snap_to_char_boundary( + content, + (start_offset + target_chars).min(content.len()), + ); + // Try to find a newline near the cut + let fallback = if let Some(nl) = content[start_offset..cut.min(content.len())] + .rfind('\n') + .map(|p| start_offset + p + 1) + { + if nl > start_offset { nl } else { cut } + } else { + cut + }; + // Guard: always advance by at least one byte to prevent infinite loops + fallback.max(start_offset + 1).min(content.len()) } + }; + + let cut_offset = snap_to_char_boundary(content, cut_offset); + let chunk_text = content[start_offset..cut_offset].trim().to_string(); + if !chunk_text.is_empty() { + let heading = extract_heading(&chunk_text); + let snippet = make_snippet(&chunk_text); + chunks.push(Chunk { + heading, + text: chunk_text, + snippet, + }); } - None => text.clone(), - }; - if full_text.trim().is_empty() { - return; - } - let snippet = make_snippet(&full_text); - chunks.push(Chunk { - heading, - text: full_text, - snippet, - }); + + // Move start forward, applying overlap + if cut_offset >= content.len() { + break; + } + start_offset = if overlap_chars > 0 && cut_offset > overlap_chars { + (cut_offset - overlap_chars).max(start_offset + 1) + } else { + cut_offset + }; + } + + chunks +} + +/// Compute a weighted score that balances break-point quality with proximity to target. +fn weighted_score(bp: &BreakPoint, ideal_offset: usize) -> f64 { + let distance = (bp.byte_offset as f64 - ideal_offset as f64).abs(); + // Normalize distance: closer to ideal = higher score multiplier + // At distance 0, multiplier = 1.0; at distance = ideal_offset, multiplier ~= 0 + let distance_factor = 1.0 / (1.0 + distance / 500.0); + bp.score as f64 * distance_factor +} + +/// Parse markdown content into frontmatter tags and smart-chunked pieces. +/// +/// 1. Strip YAML frontmatter (between `---` at start), parse `tags` if present. +/// 2. Run `smart_chunk` on the body with target 512 tokens, 15% overlap. +/// 3. Return `ParsedMarkdown { tags, chunks }`. +pub fn chunk_markdown(content: &str) -> ParsedMarkdown { + let (tags, body) = parse_frontmatter(content); + + let chunks = smart_chunk(body, 512, 15); + + ParsedMarkdown { tags, chunks } } /// Split oversized chunks into sub-chunks that fit within `max_tokens`. @@ -271,15 +530,232 @@ fn parse_tags_from_yaml(yaml: &str) -> Vec { mod tests { use super::*; + // ── Break-point detection tests ────────────────────────────────────── + + #[test] + fn test_find_break_points() { + let content = "# Title\n\nSome text\n\n## Section\nContent\n### Sub\nMore\n\n---\n"; + let bps = find_break_points(content); + + // Collect (line_number, score) pairs for easy assertion + let pairs: Vec<(usize, u32)> = bps.iter().map(|bp| (bp.line_number, bp.score)).collect(); + + // # Title -> 100 + assert!( + pairs.contains(&(0, 100)), + "Expected # heading at line 0 with score 100, got: {:?}", + pairs + ); + // empty line -> 20 + assert!( + pairs.contains(&(1, 20)), + "Expected empty line at line 1 with score 20" + ); + // empty line -> 20 + assert!( + pairs.contains(&(3, 20)), + "Expected empty line at line 3 with score 20" + ); + // ## Section -> 90 + assert!( + pairs.contains(&(4, 90)), + "Expected ## heading at line 4 with score 90" + ); + // ### Sub -> 80 + assert!( + pairs.contains(&(6, 80)), + "Expected ### heading at line 6 with score 80" + ); + // empty line -> 20 + assert!( + pairs.contains(&(8, 20)), + "Expected empty line at line 8 with score 20" + ); + // --- -> 60 + assert!( + pairs.contains(&(9, 60)), + "Expected thematic break at line 9 with score 60" + ); + + // "Some text", "Content", "More" have score 1 and should NOT appear + // (only lines inside code fences get score 1 in results) + for bp in &bps { + assert!( + bp.score > 1 || bp.inside_code_fence, + "Non-fence break points should not include lines with score <= 1" + ); + } + } + + #[test] + fn test_find_break_points_code_fence() { + let content = "Before\n\n```rust\nlet x = 1;\nlet y = 2;\n```\n\nAfter\n"; + let bps = find_break_points(content); + + // The opening ``` should be a break point with score 80, NOT inside fence + let opening = bps.iter().find(|bp| bp.line_number == 2).unwrap(); + assert_eq!(opening.score, 80); + assert!( + !opening.inside_code_fence, + "Opening fence should not be marked as inside" + ); + + // The closing ``` should be a break point with score 80, NOT inside fence + // (it toggles the fence off) + let closing = bps.iter().find(|bp| bp.line_number == 5).unwrap(); + assert_eq!(closing.score, 80); + assert!( + !closing.inside_code_fence, + "Closing fence should not be marked as inside" + ); + + // Lines inside the fence (let x = 1; let y = 2;) SHOULD appear with inside_code_fence = true + let inside_bps: Vec<&BreakPoint> = bps + .iter() + .filter(|bp| bp.line_number == 3 || bp.line_number == 4) + .collect(); + assert_eq!( + inside_bps.len(), + 2, + "Expected 2 break points inside code fence" + ); + for bp in &inside_bps { + assert!( + bp.inside_code_fence, + "Line {} inside fence should have inside_code_fence=true", + bp.line_number + ); + assert_eq!( + bp.score, 1, + "Line {} inside fence should have score 1", + bp.line_number + ); + } + } + + #[test] + fn test_find_break_points_list_items() { + let content = "- item one\n* item two\n1. numbered\nplain text\n"; + let bps = find_break_points(content); + let pairs: Vec<(usize, u32)> = bps.iter().map(|bp| (bp.line_number, bp.score)).collect(); + assert!( + pairs.contains(&(0, 5)), + "Expected list item at line 0 with score 5" + ); + assert!( + pairs.contains(&(1, 5)), + "Expected list item at line 1 with score 5" + ); + assert!( + pairs.contains(&(2, 5)), + "Expected numbered list item at line 2 with score 5" + ); + // "plain text" has score 1, should NOT appear + assert!( + !bps.iter().any(|bp| bp.line_number == 3), + "Plain text should not be a break point" + ); + } + + // ── Smart chunk tests ──────────────────────────────────────────────── + + #[test] + fn test_smart_chunk_single() { + // Short content should produce a single chunk + let content = "# Hello\nSome short content here."; + let chunks = smart_chunk(content, 512, 15); + assert_eq!(chunks.len(), 1); + assert!(chunks[0].text.contains("Hello")); + assert!(chunks[0].text.contains("short content")); + } + + #[test] + fn test_smart_chunk_splits_large_content() { + // Build content larger than 512 tokens (~2048 chars) + let mut content = String::new(); + content.push_str("# Introduction\n\n"); + for i in 0..30 { + content.push_str(&format!( + "## Section {}\nThis is paragraph {} with enough text to take up space. \ + We need each section to have meaningful content so the chunker has \ + good break points to choose from.\n\n", + i, i + )); + } + + let chunks = smart_chunk(&content, 512, 15); + assert!( + chunks.len() > 1, + "Expected multiple chunks for large content, got {}", + chunks.len() + ); + + // Each chunk should have a snippet + for c in &chunks { + assert!(!c.snippet.is_empty()); + } + } + + #[test] + fn test_smart_chunk_empty() { + let chunks = smart_chunk("", 512, 15); + assert!(chunks.is_empty()); + } + + #[test] + fn test_smart_chunk_whitespace_only() { + let chunks = smart_chunk(" \n\n \n", 512, 15); + assert!(chunks.is_empty()); + } + + #[test] + fn test_code_fence_protection() { + // Content with a code block that should NOT be split + let mut content = String::new(); + content.push_str("# Before Code\nSome intro text.\n\n"); + content.push_str("```python\n"); + for i in 0..50 { + content.push_str(&format!("x_{} = compute_value({})\n", i, i)); + } + content.push_str("```\n\n"); + content.push_str("# After Code\nSome conclusion.\n"); + + let bps = find_break_points(&content); + // Verify no break points inside the code fence are eligible (not inside_code_fence) + let fence_start_line = 3; // ```python + let fence_end_line = fence_start_line + 51; // ``` closing + + for bp in &bps { + if bp.line_number > fence_start_line && bp.line_number < fence_end_line { + // These should either not exist or be marked inside_code_fence + assert!( + bp.inside_code_fence || bp.score <= 1, + "Break point at line {} (score {}) should be inside code fence or excluded", + bp.line_number, + bp.score + ); + } + } + } + + // ── Existing tests (updated for smart chunking) ────────────────────── + #[test] fn test_chunk_by_headings() { - let md = "## A\nContent A\n## B\nContent B\n"; + let md = "## A\nContent A\n\n## B\nContent B\n"; let parsed = chunk_markdown(md); - assert_eq!(parsed.chunks.len(), 2); - assert_eq!(parsed.chunks[0].heading.as_deref(), Some("## A")); - assert_eq!(parsed.chunks[1].heading.as_deref(), Some("## B")); - assert!(parsed.chunks[0].text.contains("Content A")); - assert!(parsed.chunks[1].text.contains("Content B")); + // Smart chunking with small content should keep it as one chunk + // since total tokens < 512 + assert!(parsed.chunks.len() >= 1); + // The content should all be present + let all_text: String = parsed + .chunks + .iter() + .map(|c| c.text.clone()) + .collect::>() + .join(" "); + assert!(all_text.contains("Content A")); + assert!(all_text.contains("Content B")); } #[test] @@ -297,29 +773,28 @@ mod tests { let parsed = chunk_markdown(md); assert_eq!(parsed.chunks.len(), 1); assert!(!parsed.chunks[0].text.contains("tags")); - assert!(!parsed.chunks[0].text.contains("---")); + assert!(!parsed.chunks[0].text.contains("---\ntags")); assert!(parsed.chunks[0].text.contains("Body")); } - #[test] - fn test_nested_headings_stay_together() { - let md = "## Parent\nParent content\n### Child\nChild content\n"; - let parsed = chunk_markdown(md); - assert_eq!(parsed.chunks.len(), 1); - assert_eq!(parsed.chunks[0].heading.as_deref(), Some("## Parent")); - assert!(parsed.chunks[0].text.contains("### Child")); - assert!(parsed.chunks[0].text.contains("Child content")); - } - #[test] fn test_snippet_truncation() { let long_text = "a".repeat(300); let md = format!("## Heading\n{long_text}"); let parsed = chunk_markdown(&md); - assert_eq!(parsed.chunks.len(), 1); - assert!(parsed.chunks[0].snippet.ends_with("...")); - // 200 chars + "..." = 203 - assert_eq!(parsed.chunks[0].snippet.len(), 203); + assert!(!parsed.chunks.is_empty()); + // At least one chunk should have a truncated snippet + let has_truncated = parsed.chunks.iter().any(|c| c.snippet.ends_with("...")); + assert!( + has_truncated, + "Expected at least one snippet to be truncated" + ); + // Verify truncation length + for c in &parsed.chunks { + if c.snippet.ends_with("...") { + assert_eq!(c.snippet.len(), 203); + } + } } #[test] @@ -400,4 +875,39 @@ mod tests { assert_eq!(result[0].heading.as_deref(), Some("## Short")); assert_eq!(result[0].text, "## Short\nJust a few words here."); } + + #[test] + fn test_extract_heading() { + assert_eq!( + extract_heading("# Title\nBody text"), + Some("# Title".to_string()) + ); + assert_eq!(extract_heading("## Sub\nBody"), Some("## Sub".to_string())); + assert_eq!(extract_heading("No heading here"), None); + assert_eq!( + extract_heading("Some text\n### Deep heading\nMore"), + Some("### Deep heading".to_string()) + ); + } + + #[test] + fn test_thematic_break_detection() { + assert!(is_thematic_break("---")); + assert!(is_thematic_break("***")); + assert!(is_thematic_break("___")); + assert!(is_thematic_break("- - -")); + assert!(is_thematic_break("----")); + assert!(!is_thematic_break("--")); + assert!(!is_thematic_break("abc")); + } + + #[test] + fn test_list_item_detection() { + assert!(is_list_item("- item")); + assert!(is_list_item("* item")); + assert!(is_list_item("1. item")); + assert!(is_list_item("10. item")); + assert!(!is_list_item("plain text")); + assert!(!is_list_item("")); + } } diff --git a/src/config.rs b/src/config.rs index b1d1adb..4ceaf83 100644 --- a/src/config.rs +++ b/src/config.rs @@ -62,6 +62,12 @@ impl Config { self.top_n = n; } } + + /// Load vault profile from `~/.engraph/vault.toml`, if it exists. + pub fn load_vault_profile() -> Result> { + let dir = Self::data_dir()?; + crate::profile::load_vault_toml(&dir) + } } #[cfg(test)] diff --git a/src/docid.rs b/src/docid.rs new file mode 100644 index 0000000..8c471bd --- /dev/null +++ b/src/docid.rs @@ -0,0 +1,40 @@ +use sha2::{Digest, Sha256}; + +/// Generate a 6-character hex docid from a file path. +/// Deterministic: same path always produces same docid. +pub fn generate_docid(path: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(path.as_bytes()); + let hash = hasher.finalize(); + format!("{:02x}{:02x}{:02x}", hash[0], hash[1], hash[2]) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_docid_length_and_hex() { + let docid = generate_docid("notes/test.md"); + assert_eq!(docid.len(), 6, "docid should be 6 characters"); + assert!( + docid.chars().all(|c| c.is_ascii_hexdigit()), + "docid should be all hex chars, got: {}", + docid + ); + } + + #[test] + fn test_docid_deterministic() { + let a = generate_docid("notes/test.md"); + let b = generate_docid("notes/test.md"); + assert_eq!(a, b, "same path must produce same docid"); + } + + #[test] + fn test_docid_unique() { + let a = generate_docid("notes/a.md"); + let b = generate_docid("notes/b.md"); + assert_ne!(a, b, "different paths should produce different docids"); + } +} diff --git a/src/embedder.rs b/src/embedder.rs index 071c746..5770904 100644 --- a/src/embedder.rs +++ b/src/embedder.rs @@ -157,6 +157,28 @@ impl Embedder { } } +impl crate::model::ModelBackend for Embedder { + fn embed_batch(&mut self, texts: &[&str]) -> Result>> { + self.embed_batch(texts) + } + + fn embed_one(&mut self, text: &str) -> Result> { + self.embed_one(text) + } + + fn token_count(&self, text: &str) -> usize { + self.token_count(text) + } + + fn dim(&self) -> usize { + EMBEDDING_DIM + } + + fn name(&self) -> &str { + "onnx:all-MiniLM-L6-v2" + } +} + /// L2-normalize a vector. Returns a zero vector if input norm is zero. fn normalize_vector(v: &[f32]) -> Vec { let norm: f32 = v.iter().map(|x| x * x).sum::().sqrt(); diff --git a/src/fts.rs b/src/fts.rs new file mode 100644 index 0000000..670d1da --- /dev/null +++ b/src/fts.rs @@ -0,0 +1,136 @@ +// FTS5 search support. +// +// The `FtsResult` struct and `fts_search` method live on `Store` (in store.rs) +// since the store owns the database connection. We re-export `FtsResult` here +// so downstream code can import it from either location. + +pub use crate::store::FtsResult; + +#[cfg(test)] +mod tests { + use crate::docid::generate_docid; + use crate::store::Store; + + fn setup_store() -> Store { + let store = Store::open_memory().unwrap(); + store.ensure_fts_table().unwrap(); + store + } + + #[test] + fn test_fts_exact_match() { + let store = setup_store(); + let file_id = store + .insert_file( + "notes/ticket.md", + "hash1", + 100, + &[], + &generate_docid("notes/ticket.md"), + ) + .unwrap(); + + store + .insert_fts_chunk(file_id, 0, "BRE-2579 delivery date extension for checkout") + .unwrap(); + + let results = store.fts_search("BRE-2579", 10).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].file_id, file_id); + assert_eq!(results[0].chunk_seq, 0); + assert!( + results[0].score > 0.0, + "score should be positive (negated BM25)" + ); + } + + #[test] + fn test_fts_no_match() { + let store = setup_store(); + let file_id = store + .insert_file( + "notes/note.md", + "hash1", + 100, + &[], + &generate_docid("notes/note.md"), + ) + .unwrap(); + + store + .insert_fts_chunk(file_id, 0, "Rust programming language guide") + .unwrap(); + + let results = store.fts_search("kubernetes", 10).unwrap(); + assert_eq!(results.len(), 0); + } + + #[test] + fn test_fts_multiple_results() { + let store = setup_store(); + + let file_id1 = store + .insert_file("notes/a.md", "h1", 100, &[], &generate_docid("notes/a.md")) + .unwrap(); + let file_id2 = store + .insert_file("notes/b.md", "h2", 100, &[], &generate_docid("notes/b.md")) + .unwrap(); + let file_id3 = store + .insert_file("notes/c.md", "h3", 100, &[], &generate_docid("notes/c.md")) + .unwrap(); + + // Chunk with "delivery" appearing multiple times should rank higher. + store + .insert_fts_chunk( + file_id1, + 0, + "delivery date delivery schedule delivery tracking", + ) + .unwrap(); + store + .insert_fts_chunk(file_id2, 0, "delivery date for the checkout page") + .unwrap(); + store + .insert_fts_chunk(file_id3, 0, "unrelated content about Rust and WebAssembly") + .unwrap(); + + let results = store.fts_search("delivery", 10).unwrap(); + assert_eq!(results.len(), 2, "only 2 chunks mention 'delivery'"); + + // Results should be sorted by score descending. + assert!( + results[0].score >= results[1].score, + "results should be ranked by relevance" + ); + } + + #[test] + fn test_fts_delete_chunks_for_file() { + let store = setup_store(); + let file_id = store + .insert_file( + "notes/del.md", + "hash1", + 100, + &[], + &generate_docid("notes/del.md"), + ) + .unwrap(); + + store + .insert_fts_chunk(file_id, 0, "first chunk content") + .unwrap(); + store + .insert_fts_chunk(file_id, 1, "second chunk content") + .unwrap(); + + // Verify they exist. + let results = store.fts_search("chunk", 10).unwrap(); + assert_eq!(results.len(), 2); + + // Delete and verify gone. + store.delete_fts_chunks_for_file(file_id).unwrap(); + let results = store.fts_search("chunk", 10).unwrap(); + assert_eq!(results.len(), 0); + } +} diff --git a/src/fusion.rs b/src/fusion.rs new file mode 100644 index 0000000..f1c7508 --- /dev/null +++ b/src/fusion.rs @@ -0,0 +1,265 @@ +/// Reciprocal Rank Fusion (RRF) engine. +/// +/// Merges ranked results from multiple search lanes (e.g. semantic HNSW +/// and FTS5 keyword search) into a single ranked list using the RRF formula: +/// +/// rrf_score = sum( weight_i / (k + rank_i) ) +/// +/// A ranked result from a single search lane. +pub struct RankedResult { + pub file_path: String, + pub file_id: i64, + pub score: f64, + pub heading: Option, + pub snippet: String, + pub docid: Option, +} + +/// A fused result after RRF merging across lanes. +pub struct FusedResult { + pub file_path: String, + pub file_id: i64, + pub rrf_score: f64, + pub heading: Option, + pub snippet: String, + pub docid: Option, + pub lane_contributions: Vec, +} + +/// Per-lane contribution details for --explain output. +pub struct LaneContribution { + pub lane_name: String, + pub rank: usize, + pub raw_score: f64, + pub weighted_contribution: f64, +} + +use std::collections::HashMap; + +/// Fuse ranked results from multiple search lanes using Reciprocal Rank Fusion. +/// +/// Each lane is a tuple of `(lane_name, results, weight)`. +/// Results are grouped by `file_path` (file-level deduplication). +/// The best snippet/heading per file is kept from the highest-ranked lane. +/// +/// `k` is the RRF constant (typically 60). +pub fn rrf_fuse(lanes: &[(&str, &[RankedResult], f64)], k: usize) -> Vec { + // Track per-file: rrf_score, best snippet info, lane contributions + struct Accumulator { + file_path: String, + file_id: i64, + rrf_score: f64, + heading: Option, + snippet: String, + docid: Option, + best_rank: usize, // lowest rank seen (for picking best snippet) + lane_contributions: Vec, + } + + let mut acc_map: HashMap = HashMap::new(); + + for &(lane_name, results, weight) in lanes { + for (idx, r) in results.iter().enumerate() { + let rank = idx + 1; // 1-based + let contribution = weight / (k as f64 + rank as f64); + + let acc = acc_map + .entry(r.file_path.clone()) + .or_insert_with(|| Accumulator { + file_path: r.file_path.clone(), + file_id: r.file_id, + rrf_score: 0.0, + heading: r.heading.clone(), + snippet: r.snippet.clone(), + docid: r.docid.clone(), + best_rank: rank, + lane_contributions: Vec::new(), + }); + + acc.rrf_score += contribution; + + // Keep snippet from the best-ranked appearance + if rank < acc.best_rank { + acc.best_rank = rank; + acc.heading = r.heading.clone(); + acc.snippet = r.snippet.clone(); + if r.docid.is_some() { + acc.docid = r.docid.clone(); + } + } + + acc.lane_contributions.push(LaneContribution { + lane_name: lane_name.to_string(), + rank, + raw_score: r.score, + weighted_contribution: contribution, + }); + } + } + + let mut results: Vec = acc_map + .into_values() + .map(|a| FusedResult { + file_path: a.file_path, + file_id: a.file_id, + rrf_score: a.rrf_score, + heading: a.heading, + snippet: a.snippet, + docid: a.docid, + lane_contributions: a.lane_contributions, + }) + .collect(); + + // Sort by rrf_score descending + results.sort_by(|a, b| { + b.rrf_score + .partial_cmp(&a.rrf_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + results +} + +/// Format explain output for a single fused result. +pub fn format_explain(result: &FusedResult) -> String { + let mut out = format!(" RRF: {:.4}\n", result.rrf_score); + for lc in &result.lane_contributions { + out.push_str(&format!( + " {}: rank #{}, raw {:.2}, +{:.4}\n", + lc.lane_name, lc.rank, lc.raw_score, lc.weighted_contribution, + )); + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_result(file_path: &str, score: f64) -> RankedResult { + RankedResult { + file_path: file_path.to_string(), + file_id: 0, + score, + heading: Some(format!("heading for {}", file_path)), + snippet: format!("snippet for {}", file_path), + docid: None, + } + } + + #[test] + fn test_rrf_basic() { + // Item appearing in both lanes should rank highest + let semantic = vec![ + make_result("both.md", 0.87), + make_result("sem_only.md", 0.75), + ]; + let fts = vec![make_result("fts_only.md", 5.0), make_result("both.md", 3.2)]; + + let fused = rrf_fuse(&[("semantic", &semantic, 1.0), ("fts", &fts, 1.0)], 60); + + assert_eq!(fused.len(), 3); + // "both.md" should be first because it appears in both lanes + assert_eq!(fused[0].file_path, "both.md"); + + // Verify the RRF score for "both.md": + // semantic rank 1: 1.0 / (60 + 1) = 0.01639... + // fts rank 2: 1.0 / (60 + 2) = 0.01613... + // total = 0.03252... + let expected = 1.0 / 61.0 + 1.0 / 62.0; + assert!((fused[0].rrf_score - expected).abs() < 1e-10); + + // Both single-lane items should have lower scores + assert!(fused[0].rrf_score > fused[1].rrf_score); + assert!(fused[0].rrf_score > fused[2].rrf_score); + + // "both.md" should have 2 lane contributions + assert_eq!(fused[0].lane_contributions.len(), 2); + } + + #[test] + fn test_rrf_weighted() { + // FTS weighted 3x should make FTS-only item win over semantic-only item + let semantic = vec![make_result("sem.md", 0.95)]; + let fts = vec![make_result("fts.md", 8.0)]; + + let fused = rrf_fuse(&[("semantic", &semantic, 1.0), ("fts", &fts, 3.0)], 60); + + assert_eq!(fused.len(), 2); + // FTS item at rank 1 with weight 3.0: 3.0 / 61 = 0.04918... + // Semantic item at rank 1 with weight 1.0: 1.0 / 61 = 0.01639... + assert_eq!(fused[0].file_path, "fts.md"); + assert_eq!(fused[1].file_path, "sem.md"); + + let fts_expected = 3.0 / 61.0; + let sem_expected = 1.0 / 61.0; + assert!((fused[0].rrf_score - fts_expected).abs() < 1e-10); + assert!((fused[1].rrf_score - sem_expected).abs() < 1e-10); + } + + #[test] + fn test_rrf_single_lane() { + let semantic = vec![ + make_result("a.md", 0.9), + make_result("b.md", 0.8), + make_result("c.md", 0.7), + ]; + + let fused = rrf_fuse(&[("semantic", &semantic, 1.0)], 60); + + assert_eq!(fused.len(), 3); + assert_eq!(fused[0].file_path, "a.md"); + assert_eq!(fused[1].file_path, "b.md"); + assert_eq!(fused[2].file_path, "c.md"); + + // Each should have exactly 1 lane contribution + for f in &fused { + assert_eq!(f.lane_contributions.len(), 1); + assert_eq!(f.lane_contributions[0].lane_name, "semantic"); + } + } + + #[test] + fn test_format_explain() { + let result = FusedResult { + file_path: "test.md".to_string(), + file_id: 1, + rrf_score: 0.0328, + heading: None, + snippet: "test".to_string(), + docid: None, + lane_contributions: vec![ + LaneContribution { + lane_name: "semantic".to_string(), + rank: 1, + raw_score: 0.87, + weighted_contribution: 0.0164, + }, + LaneContribution { + lane_name: "fts".to_string(), + rank: 3, + raw_score: 5.23, + weighted_contribution: 0.0159, + }, + ], + }; + + let output = format_explain(&result); + assert!(output.contains("RRF: 0.0328")); + assert!(output.contains("semantic: rank #1, raw 0.87, +0.0164")); + assert!(output.contains("fts: rank #3, raw 5.23, +0.0159")); + } + + #[test] + fn test_rrf_empty_lanes() { + let fused = rrf_fuse(&[], 60); + assert!(fused.is_empty()); + } + + #[test] + fn test_rrf_empty_results() { + let empty: Vec = vec![]; + let fused = rrf_fuse(&[("semantic", &empty, 1.0), ("fts", &empty, 1.0)], 60); + assert!(fused.is_empty()); + } +} diff --git a/src/indexer.rs b/src/indexer.rs index 4c229b2..c2f69ff 100644 --- a/src/indexer.rs +++ b/src/indexer.rs @@ -10,6 +10,7 @@ use tracing::info; use crate::chunker::{chunk_markdown, split_oversized_chunks}; use crate::config::Config; +use crate::docid::generate_docid; use crate::embedder::Embedder; use crate::hnsw::HnswIndex; use crate::store::{FileRecord, Store}; @@ -166,6 +167,7 @@ pub fn run_index(vault_path: &Path, config: &Config, rebuild: bool) -> Result Result Result Result, + + /// Show per-lane RRF score breakdown for each result. + #[arg(long, conflicts_with = "json")] + explain: bool, }, /// Show index status and statistics. @@ -60,6 +66,29 @@ enum Command { #[arg(long)] all: bool, }, + + /// Initialize vault profile with auto-detection. + Init { + /// Path to the vault (defaults to current directory). + path: Option, + }, + + /// Interactively configure vault profile. + Configure, + + /// Manage embedding models. + Models { + #[command(subcommand)] + action: ModelsAction, + }, +} + +#[derive(Subcommand, Debug)] +enum ModelsAction { + /// List available models. + List, + /// Show info about a model. + Info { name: String }, } /// Check whether an index has been built by looking for engraph.db in data_dir. @@ -161,7 +190,11 @@ fn main() -> Result<()> { ); } - Command::Search { query, top_n } => { + Command::Search { + query, + top_n, + explain, + } => { cfg.merge_top_n(top_n); if !index_exists(&data_dir) { @@ -169,7 +202,7 @@ fn main() -> Result<()> { std::process::exit(1); } - search::run_search(&query, cfg.top_n, cli.json, &data_dir)?; + search::run_search(&query, cfg.top_n, cli.json, explain, &data_dir)?; } Command::Status => { @@ -210,6 +243,105 @@ fn main() -> Result<()> { } } } + + Command::Init { path } => { + // Resolve vault path: CLI arg > config > cwd. + cfg.merge_vault_path(path); + let vault_path = match &cfg.vault_path { + Some(p) => p.clone(), + None => std::env::current_dir()?, + }; + let vault_path = vault_path.canonicalize().unwrap_or(vault_path); + + println!("Detecting vault profile for: {}", vault_path.display()); + + let vault_type = profile::detect_vault_type(&vault_path); + let structure = profile::detect_structure(&vault_path)?; + let stats = profile::scan_vault_stats(&vault_path)?; + + // Print detection results. + println!(); + println!(" Vault type: {:?}", vault_type); + println!(" Structure: {:?}", structure.method); + if let Some(ref inbox) = structure.folders.inbox { + println!(" inbox: {}", inbox); + } + if let Some(ref projects) = structure.folders.projects { + println!(" projects: {}", projects); + } + if let Some(ref areas) = structure.folders.areas { + println!(" areas: {}", areas); + } + if let Some(ref resources) = structure.folders.resources { + println!(" resources: {}", resources); + } + if let Some(ref archive) = structure.folders.archive { + println!(" archive: {}", archive); + } + if let Some(ref templates) = structure.folders.templates { + println!(" templates: {}", templates); + } + if let Some(ref daily) = structure.folders.daily { + println!(" daily: {}", daily); + } + if let Some(ref people) = structure.folders.people { + println!(" people: {}", people); + } + println!(); + println!(" Total .md files: {}", stats.total_files); + println!(" With frontmatter: {}", stats.files_with_frontmatter); + println!(" Wikilinks: {}", stats.wikilink_count); + println!(" Unique tags: {}", stats.unique_tags); + println!(" Folders: {}", stats.folder_count); + println!(" Max folder depth: {}", stats.folder_depth); + + let vault_profile = profile::VaultProfile { + vault_path, + vault_type, + structure, + stats, + }; + + // Ensure data dir exists and write vault.toml. + std::fs::create_dir_all(&data_dir)?; + profile::write_vault_toml(&vault_profile, &data_dir)?; + + println!(); + println!("Wrote {}", data_dir.join("vault.toml").display()); + } + + Command::Configure => { + println!( + "Interactive configuration not yet implemented. Run 'engraph init' for auto-detection." + ); + } + + Command::Models { action } => { + let registry = model::ModelRegistry::default(); + match action { + ModelsAction::List => { + println!("{:<30} {:>5} DESCRIPTION", "NAME", "DIM"); + println!("{}", "-".repeat(70)); + for entry in ®istry.entries { + println!("{:<30} {:>5} {}", entry.name, entry.dim, entry.description); + } + } + ModelsAction::Info { name } => { + if let Some(entry) = registry.get(&name) { + println!("Name: {}", entry.name); + println!("Format: {:?}", entry.format); + println!("Dimensions: {}", entry.dim); + println!("SHA-256: {}", entry.sha256); + println!("URL: {}", entry.url); + println!("Description: {}", entry.description); + } else { + eprintln!("Unknown model: {name}"); + eprintln!("Run 'engraph models list' to see available models."); + std::process::exit(1); + } + } + } + } } Ok(()) diff --git a/src/model.rs b/src/model.rs new file mode 100644 index 0000000..b8c966c --- /dev/null +++ b/src/model.rs @@ -0,0 +1,141 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; + +/// Trait for embedding backends. Any model that can embed text implements this. +pub trait ModelBackend { + fn embed_batch(&mut self, texts: &[&str]) -> Result>>; + fn embed_one(&mut self, text: &str) -> Result>; + fn token_count(&self, text: &str) -> usize; + fn dim(&self) -> usize; + fn name(&self) -> &str; +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ModelFormat { + Onnx, + Gguf, + File, +} + +#[derive(Debug, Clone)] +pub struct ModelSpec { + pub format: ModelFormat, + pub name: String, + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelRegistryEntry { + pub name: String, + pub format: ModelFormat, + pub url: String, + pub sha256: String, + pub dim: usize, + pub description: String, +} + +pub struct ModelRegistry { + pub entries: Vec, +} + +impl Default for ModelRegistry { + fn default() -> Self { + Self { + entries: vec![ModelRegistryEntry { + name: "onnx:all-MiniLM-L6-v2".to_string(), + format: ModelFormat::Onnx, + url: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".to_string(), + sha256: "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452".to_string(), + dim: 384, + description: "Lightweight general-purpose sentence embeddings".to_string(), + }], + } + } +} + +impl ModelRegistry { + pub fn get(&self, name: &str) -> Option<&ModelRegistryEntry> { + self.entries.iter().find(|e| e.name == name) + } +} + +pub fn parse_model_spec(spec: &str) -> ModelSpec { + if let Some(path) = spec.strip_prefix("file:") { + return ModelSpec { + format: ModelFormat::File, + name: spec.to_string(), + path: path.to_string(), + }; + } + if let Some((format_str, name)) = spec.split_once(':') { + let format = match format_str { + "onnx" => ModelFormat::Onnx, + "gguf" => ModelFormat::Gguf, + _ => ModelFormat::Onnx, + }; + ModelSpec { + format, + name: name.to_string(), + path: String::new(), + } + } else { + ModelSpec { + format: ModelFormat::Onnx, + name: spec.to_string(), + path: String::new(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_registry_default() { + let registry = ModelRegistry::default(); + assert_eq!(registry.entries.len(), 1); + let entry = ®istry.entries[0]; + assert_eq!(entry.name, "onnx:all-MiniLM-L6-v2"); + assert_eq!(entry.dim, 384); + assert_eq!(entry.format, ModelFormat::Onnx); + } + + #[test] + fn test_parse_model_spec_onnx() { + let spec = parse_model_spec("onnx:all-MiniLM-L6-v2"); + assert_eq!(spec.format, ModelFormat::Onnx); + assert_eq!(spec.name, "all-MiniLM-L6-v2"); + assert!(spec.path.is_empty()); + } + + #[test] + fn test_parse_model_spec_file() { + let spec = parse_model_spec("file:/path/to/model.onnx"); + assert_eq!(spec.format, ModelFormat::File); + assert_eq!(spec.name, "file:/path/to/model.onnx"); + assert_eq!(spec.path, "/path/to/model.onnx"); + } + + #[test] + fn test_parse_model_spec_bare() { + let spec = parse_model_spec("my-custom-model"); + assert_eq!(spec.format, ModelFormat::Onnx); + assert_eq!(spec.name, "my-custom-model"); + assert!(spec.path.is_empty()); + } + + #[test] + fn test_registry_get_existing() { + let registry = ModelRegistry::default(); + let entry = registry.get("onnx:all-MiniLM-L6-v2"); + assert!(entry.is_some()); + assert_eq!(entry.unwrap().dim, 384); + } + + #[test] + fn test_registry_get_missing() { + let registry = ModelRegistry::default(); + assert!(registry.get("nonexistent-model").is_none()); + } +} diff --git a/src/profile.rs b/src/profile.rs new file mode 100644 index 0000000..07dbc89 --- /dev/null +++ b/src/profile.rs @@ -0,0 +1,635 @@ +use std::path::{Path, PathBuf}; + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; + +/// How the vault organizes its notes. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum StructureMethod { + Flat, + Folders, + Para, + Custom, +} + +/// What kind of vault this is. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum VaultType { + Obsidian, + Logseq, + Plain, + Custom, +} + +/// Complete vault profile, persisted to `vault.toml`. +#[derive(Debug, Serialize, Deserialize)] +pub struct VaultProfile { + pub vault_path: PathBuf, + pub vault_type: VaultType, + pub structure: StructureDetection, + pub stats: VaultStats, +} + +/// Detected folder structure. +#[derive(Debug, Serialize, Deserialize)] +pub struct StructureDetection { + pub method: StructureMethod, + pub folders: FolderMap, +} + +/// Known folder roles mapped to their detected names. +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct FolderMap { + pub inbox: Option, + pub projects: Option, + pub areas: Option, + pub resources: Option, + pub archive: Option, + pub templates: Option, + pub daily: Option, + pub people: Option, +} + +/// Aggregate vault statistics. +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct VaultStats { + pub total_files: usize, + pub files_with_frontmatter: usize, + pub wikilink_count: usize, + pub unique_tags: usize, + pub folder_depth: usize, + pub folder_count: usize, +} + +// --------------------------------------------------------------------------- +// Detection helpers +// --------------------------------------------------------------------------- + +/// Map a folder name (case-insensitive, ignoring leading number prefixes like `00-`) +/// to a PARA role. +fn para_role(name: &str) -> Option<&'static str> { + // Strip optional leading digits and separator (e.g. "00-", "01-"). + let stripped = name + .trim_start_matches(|c: char| c.is_ascii_digit()) + .trim_start_matches(['-', '_', ' ']); + + match stripped.to_ascii_lowercase().as_str() { + "inbox" => Some("inbox"), + "projects" => Some("projects"), + "areas" => Some("areas"), + "resources" => Some("resources"), + "archive" => Some("archive"), + "templates" => Some("templates"), + "daily" => Some("daily"), + "people" => Some("people"), + _ => None, + } +} + +/// Detect vault structure by checking for PARA-style numbered folders. +/// +/// - If at least 3 of the 4 core PARA folders (inbox, projects, areas, resources) exist -> Para +/// - If there are subdirectories but no PARA pattern -> Folders +/// - If mostly flat .md files -> Flat +pub fn detect_structure(path: &Path) -> Result { + let mut folders = FolderMap::default(); + let mut para_hits = 0u32; + let mut dir_count = 0usize; + + let entries = std::fs::read_dir(path) + .with_context(|| format!("cannot read directory {}", path.display()))?; + + for entry in entries { + let entry = entry?; + let ft = entry.file_type()?; + if !ft.is_dir() { + continue; + } + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + + // Skip hidden directories. + if name_str.starts_with('.') { + continue; + } + + dir_count += 1; + + if let Some(role) = para_role(&name_str) { + let folder_name = name_str.to_string(); + match role { + "inbox" => { + folders.inbox = Some(folder_name); + para_hits += 1; + } + "projects" => { + folders.projects = Some(folder_name); + para_hits += 1; + } + "areas" => { + folders.areas = Some(folder_name); + para_hits += 1; + } + "resources" => { + folders.resources = Some(folder_name); + para_hits += 1; + } + "archive" => { + folders.archive = Some(folder_name); + } + "templates" => { + folders.templates = Some(folder_name); + } + "daily" => { + folders.daily = Some(folder_name); + } + "people" => { + folders.people = Some(folder_name); + } + _ => {} + } + } + } + + let method = if para_hits >= 3 { + StructureMethod::Para + } else if dir_count > 0 { + StructureMethod::Folders + } else { + StructureMethod::Flat + }; + + Ok(StructureDetection { method, folders }) +} + +/// Count wikilinks (`[[...]]`) in a string. Handles nested brackets conservatively. +fn count_wikilinks(text: &str) -> usize { + let bytes = text.as_bytes(); + let mut count = 0usize; + let mut i = 0; + while i + 1 < bytes.len() { + if bytes[i] == b'[' && bytes[i + 1] == b'[' { + // Find closing ]]. + if let Some(rest) = text.get(i + 2..) + && let Some(close) = rest.find("]]") + { + // Only count if the content is non-empty and doesn't span lines. + let inner = &rest[..close]; + if !inner.is_empty() && !inner.contains('\n') { + count += 1; + } + i += 2 + close + 2; + continue; + } + } + i += 1; + } + count +} + +/// Check whether a file starts with YAML frontmatter (`---` on the first line). +fn has_frontmatter(text: &str) -> bool { + text.starts_with("---\n") || text.starts_with("---\r\n") +} + +/// Extract tags from YAML frontmatter. Handles both list and inline formats: +/// ```yaml +/// tags: +/// - foo +/// - bar +/// ``` +/// and +/// ```yaml +/// tags: [foo, bar] +/// ``` +fn extract_tags(text: &str) -> Vec { + // Find frontmatter block. + let fm = if text.starts_with("---\n") { + text.get(4..) + .and_then(|rest| rest.find("\n---").map(|end| &rest[..end])) + } else if text.starts_with("---\r\n") { + text.get(5..) + .and_then(|rest| rest.find("\n---").map(|end| &rest[..end])) + } else { + None + }; + + let Some(fm) = fm else { + return Vec::new(); + }; + + let mut tags = Vec::new(); + let mut in_tags_block = false; + + for line in fm.lines() { + let trimmed = line.trim(); + + if trimmed.starts_with("tags:") { + let after = trimmed.strip_prefix("tags:").unwrap().trim(); + if after.is_empty() { + // Multi-line list follows. + in_tags_block = true; + continue; + } + // Inline list: tags: [a, b] or tags: a, b + let after = after.trim_start_matches('[').trim_end_matches(']'); + for tag in after.split(',') { + let t = tag + .trim() + .trim_matches('"') + .trim_matches('\'') + .trim_matches('#'); + if !t.is_empty() { + tags.push(t.to_string()); + } + } + return tags; + } + + if in_tags_block { + if trimmed.starts_with("- ") { + let t = trimmed + .strip_prefix("- ") + .unwrap() + .trim() + .trim_matches('"') + .trim_matches('\'') + .trim_matches('#'); + if !t.is_empty() { + tags.push(t.to_string()); + } + } else if !trimmed.is_empty() { + // End of tags block (new key). + break; + } + } + } + + tags +} + +/// Walk all `.md` files under `path` recursively. +fn walk_md_files(path: &Path) -> Result> { + let mut files = Vec::new(); + walk_md_recursive(path, &mut files)?; + Ok(files) +} + +fn walk_md_recursive(dir: &Path, out: &mut Vec) -> Result<()> { + let entries = std::fs::read_dir(dir) + .with_context(|| format!("cannot read directory {}", dir.display()))?; + + for entry in entries { + let entry = entry?; + let ft = entry.file_type()?; + let path = entry.path(); + + if ft.is_dir() { + // Skip hidden directories. + if entry.file_name().to_string_lossy().starts_with('.') { + continue; + } + walk_md_recursive(&path, out)?; + } else if ft.is_file() + && let Some(ext) = path.extension() + && ext == "md" + { + out.push(path); + } + } + + Ok(()) +} + +/// Count distinct folders and maximum depth relative to `root`. +fn folder_stats(root: &Path) -> Result<(usize, usize)> { + let mut count = 0usize; + let mut max_depth = 0usize; + folder_stats_recursive(root, root, &mut count, &mut max_depth)?; + Ok((count, max_depth)) +} + +fn folder_stats_recursive( + dir: &Path, + root: &Path, + count: &mut usize, + max_depth: &mut usize, +) -> Result<()> { + let entries = std::fs::read_dir(dir) + .with_context(|| format!("cannot read directory {}", dir.display()))?; + + for entry in entries { + let entry = entry?; + if !entry.file_type()?.is_dir() { + continue; + } + let name = entry.file_name(); + if name.to_string_lossy().starts_with('.') { + continue; + } + *count += 1; + let depth = entry + .path() + .strip_prefix(root) + .map(|p| p.components().count()) + .unwrap_or(0); + if depth > *max_depth { + *max_depth = depth; + } + folder_stats_recursive(&entry.path(), root, count, max_depth)?; + } + + Ok(()) +} + +/// Scan vault files for statistics. +pub fn scan_vault_stats(path: &Path) -> Result { + let md_files = walk_md_files(path)?; + let mut all_tags = std::collections::HashSet::new(); + let mut files_with_frontmatter = 0; + let mut wikilink_count = 0; + + for file in &md_files { + let text = std::fs::read_to_string(file).unwrap_or_default(); + if has_frontmatter(&text) { + files_with_frontmatter += 1; + } + wikilink_count += count_wikilinks(&text); + for tag in extract_tags(&text) { + all_tags.insert(tag); + } + } + + let (fc, fd) = folder_stats(path)?; + + Ok(VaultStats { + total_files: md_files.len(), + files_with_frontmatter, + wikilink_count, + unique_tags: all_tags.len(), + folder_count: fc, + folder_depth: fd, + }) +} + +/// Detect vault type based on marker files/directories. +pub fn detect_vault_type(path: &Path) -> VaultType { + if path.join(".obsidian").is_dir() { + VaultType::Obsidian + } else if path.join(".logseq").is_dir() { + VaultType::Logseq + } else { + VaultType::Plain + } +} + +/// Write vault profile to `vault.toml` in the given directory. +pub fn write_vault_toml(profile: &VaultProfile, config_dir: &Path) -> Result<()> { + std::fs::create_dir_all(config_dir) + .with_context(|| format!("cannot create directory {}", config_dir.display()))?; + + let toml_str = toml::to_string_pretty(profile).context("failed to serialize vault profile")?; + let dest = config_dir.join("vault.toml"); + std::fs::write(&dest, toml_str) + .with_context(|| format!("failed to write {}", dest.display()))?; + + Ok(()) +} + +/// Load vault profile from `vault.toml` in the given directory. +/// Returns `Ok(None)` if the file does not exist. +pub fn load_vault_toml(config_dir: &Path) -> Result> { + let path = config_dir.join("vault.toml"); + if !path.exists() { + return Ok(None); + } + + let contents = std::fs::read_to_string(&path) + .with_context(|| format!("failed to read {}", path.display()))?; + let profile: VaultProfile = + toml::from_str(&contents).with_context(|| format!("failed to parse {}", path.display()))?; + + Ok(Some(profile)) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_detect_para_structure() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + + // Create PARA-style numbered folders. + std::fs::create_dir(root.join("00-Inbox")).unwrap(); + std::fs::create_dir(root.join("01-Projects")).unwrap(); + std::fs::create_dir(root.join("02-Areas")).unwrap(); + std::fs::create_dir(root.join("03-Resources")).unwrap(); + std::fs::create_dir(root.join("04-Archive")).unwrap(); + std::fs::create_dir(root.join("05-Templates")).unwrap(); + std::fs::create_dir(root.join("07-Daily")).unwrap(); + + let result = detect_structure(root).unwrap(); + assert_eq!(result.method, StructureMethod::Para); + assert_eq!(result.folders.inbox.as_deref(), Some("00-Inbox")); + assert_eq!(result.folders.projects.as_deref(), Some("01-Projects")); + assert_eq!(result.folders.areas.as_deref(), Some("02-Areas")); + assert_eq!(result.folders.resources.as_deref(), Some("03-Resources")); + assert_eq!(result.folders.archive.as_deref(), Some("04-Archive")); + assert_eq!(result.folders.templates.as_deref(), Some("05-Templates")); + assert_eq!(result.folders.daily.as_deref(), Some("07-Daily")); + } + + #[test] + fn test_detect_flat_structure() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + + // Create only .md files, no subdirectories. + std::fs::write(root.join("note1.md"), "Hello").unwrap(); + std::fs::write(root.join("note2.md"), "World").unwrap(); + + let result = detect_structure(root).unwrap(); + assert_eq!(result.method, StructureMethod::Flat); + } + + #[test] + fn test_detect_folders_structure() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + + // Create non-PARA subdirectories. + std::fs::create_dir(root.join("notes")).unwrap(); + std::fs::create_dir(root.join("journal")).unwrap(); + std::fs::create_dir(root.join("references")).unwrap(); + + let result = detect_structure(root).unwrap(); + assert_eq!(result.method, StructureMethod::Folders); + } + + #[test] + fn test_detect_wikilinks() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + + std::fs::write( + root.join("a.md"), + "See [[Note One]] and [[Note Two]] for details.", + ) + .unwrap(); + std::fs::write(root.join("b.md"), "No links here.").unwrap(); + std::fs::write(root.join("c.md"), "Link to [[Note One|alias]] only.").unwrap(); + + let stats = scan_vault_stats(root).unwrap(); + assert_eq!(stats.total_files, 3); + assert_eq!(stats.wikilink_count, 3); + } + + #[test] + fn test_detect_frontmatter() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + + std::fs::write( + root.join("with_fm.md"), + "---\ntitle: Test\ntags:\n - foo\n---\nContent here.", + ) + .unwrap(); + std::fs::write(root.join("without_fm.md"), "Just some text.").unwrap(); + std::fs::write( + root.join("also_fm.md"), + "---\ntags: [bar, baz]\n---\nMore content.", + ) + .unwrap(); + + let stats = scan_vault_stats(root).unwrap(); + assert_eq!(stats.total_files, 3); + assert_eq!(stats.files_with_frontmatter, 2); + assert_eq!(stats.unique_tags, 3); // foo, bar, baz + } + + #[test] + fn test_detect_vault_type_obsidian() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + std::fs::create_dir(root.join(".obsidian")).unwrap(); + + assert_eq!(detect_vault_type(root), VaultType::Obsidian); + } + + #[test] + fn test_detect_vault_type_logseq() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + std::fs::create_dir(root.join(".logseq")).unwrap(); + + assert_eq!(detect_vault_type(root), VaultType::Logseq); + } + + #[test] + fn test_detect_vault_type_plain() { + let tmp = TempDir::new().unwrap(); + assert_eq!(detect_vault_type(tmp.path()), VaultType::Plain); + } + + #[test] + fn test_write_and_load_vault_toml() { + let tmp = TempDir::new().unwrap(); + let config_dir = tmp.path(); + + let profile = VaultProfile { + vault_path: PathBuf::from("/test/vault"), + vault_type: VaultType::Obsidian, + structure: StructureDetection { + method: StructureMethod::Para, + folders: FolderMap { + inbox: Some("00-Inbox".to_string()), + projects: Some("01-Projects".to_string()), + areas: Some("02-Areas".to_string()), + resources: Some("03-Resources".to_string()), + archive: Some("04-Archive".to_string()), + templates: Some("05-Templates".to_string()), + daily: Some("07-Daily".to_string()), + people: None, + }, + }, + stats: VaultStats { + total_files: 100, + files_with_frontmatter: 80, + wikilink_count: 500, + unique_tags: 25, + folder_depth: 3, + folder_count: 10, + }, + }; + + write_vault_toml(&profile, config_dir).unwrap(); + + // Verify the file was created. + assert!(config_dir.join("vault.toml").exists()); + + // Load it back. + let loaded = load_vault_toml(config_dir).unwrap().unwrap(); + assert_eq!(loaded.vault_path, PathBuf::from("/test/vault")); + assert_eq!(loaded.vault_type, VaultType::Obsidian); + assert_eq!(loaded.structure.method, StructureMethod::Para); + assert_eq!(loaded.structure.folders.inbox.as_deref(), Some("00-Inbox")); + assert_eq!(loaded.stats.total_files, 100); + assert_eq!(loaded.stats.wikilink_count, 500); + assert_eq!(loaded.stats.unique_tags, 25); + } + + #[test] + fn test_load_vault_toml_missing_file() { + let tmp = TempDir::new().unwrap(); + let result = load_vault_toml(tmp.path()).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_count_wikilinks_empty() { + assert_eq!(count_wikilinks(""), 0); + assert_eq!(count_wikilinks("no links here"), 0); + } + + #[test] + fn test_count_wikilinks_multiple() { + assert_eq!(count_wikilinks("[[a]] text [[b|alias]] more [[c]]"), 3); + } + + #[test] + fn test_extract_tags_inline_list() { + let text = "---\ntags: [foo, bar, baz]\n---\ncontent"; + assert_eq!(extract_tags(text), vec!["foo", "bar", "baz"]); + } + + #[test] + fn test_extract_tags_multiline() { + let text = "---\ntags:\n - alpha\n - beta\n---\ncontent"; + assert_eq!(extract_tags(text), vec!["alpha", "beta"]); + } + + #[test] + fn test_extract_tags_no_frontmatter() { + let text = "just some text"; + assert!(extract_tags(text).is_empty()); + } + + #[test] + fn test_folder_stats_depth() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + std::fs::create_dir_all(root.join("a/b/c")).unwrap(); + std::fs::create_dir(root.join("d")).unwrap(); + + let (count, depth) = folder_stats(root).unwrap(); + assert_eq!(count, 4); // a, a/b, a/b/c, d + assert_eq!(depth, 3); // a/b/c is depth 3 + } +} diff --git a/src/search.rs b/src/search.rs index bcbaf90..40376a5 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,9 +1,11 @@ +use std::collections::HashMap; use std::path::Path; use anyhow::{Context, Result}; use serde_json::json; use crate::embedder::Embedder; +use crate::fusion::{self, RankedResult}; use crate::hnsw::HnswIndex; use crate::store::{Store, StoreStats}; @@ -13,10 +15,21 @@ pub struct SearchResult { pub file_path: String, pub heading: Option, pub snippet: String, + pub docid: Option, } /// Run a search query and print results. -pub fn run_search(query: &str, top_n: usize, json: bool, data_dir: &Path) -> Result<()> { +/// +/// Performs both semantic (HNSW) and keyword (FTS5) search, then fuses +/// results using Reciprocal Rank Fusion. When `explain` is true, each +/// result includes per-lane score breakdown. +pub fn run_search( + query: &str, + top_n: usize, + json: bool, + explain: bool, + data_dir: &Path, +) -> Result<()> { let models_dir = data_dir.join("models"); let mut embedder = Embedder::new(&models_dir).context("loading embedder")?; @@ -26,38 +39,129 @@ pub fn run_search(query: &str, top_n: usize, json: bool, data_dir: &Path) -> Res let db_path = data_dir.join("engraph.db"); let store = Store::open(&db_path).context("opening store")?; + // --- Semantic lane --- let query_vec = embedder.embed_one(query).context("embedding query")?; - let tombstones = store.get_tombstones().context("loading tombstones")?; - // Request extra results to account for tombstone filtering. - let raw_results = index.search(&query_vec, top_n, &tombstones); + // Request extra results to account for tombstone filtering and file-level dedup. + let raw_results = index.search(&query_vec, top_n * 3, &tombstones); - let mut results = Vec::new(); + // Group semantic results by file_path, keeping best per file. + let mut sem_by_file: HashMap = HashMap::new(); for (vector_id, distance) in raw_results { if let Some(chunk) = store.get_chunk_by_vector_id(vector_id)? { - let file_path = store - .get_file_path_by_id(chunk.file_id)? - .unwrap_or_else(|| "".to_string()); - - // Convert cosine distance to similarity score. - let score = 1.0 - distance; + let (file_path, docid) = match store.get_file_by_id(chunk.file_id)? { + Some(f) => (f.path, f.docid), + None => ("".to_string(), None), + }; + let score = (1.0 - distance) as f64; let heading = if chunk.heading.is_empty() { None } else { Some(chunk.heading) }; - results.push(SearchResult { - score, - file_path, - heading, - snippet: chunk.snippet, - }); + // Keep the best-scoring chunk per file. + let better = match sem_by_file.get(&file_path) { + Some(existing) => score > existing.score, + None => true, + }; + if better { + sem_by_file.insert( + file_path.clone(), + RankedResult { + file_path, + file_id: chunk.file_id, + score, + heading, + snippet: chunk.snippet, + docid, + }, + ); + } + } + } + + // Sort semantic results by score descending for rank assignment. + let mut semantic_results: Vec = sem_by_file.into_values().collect(); + semantic_results.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // --- FTS lane --- + let fts_raw = store.fts_search(query, top_n * 3).unwrap_or_default(); + + // Group FTS results by file_path, keeping best per file. + let mut fts_by_file: HashMap = HashMap::new(); + for fr in fts_raw { + let (file_path, docid) = match store.get_file_by_id(fr.file_id)? { + Some(f) => (f.path, f.docid), + None => continue, + }; + + let better = match fts_by_file.get(&file_path) { + Some(existing) => fr.score > existing.score, + None => true, + }; + if better { + fts_by_file.insert( + file_path.clone(), + RankedResult { + file_path, + file_id: fr.file_id, + score: fr.score, + heading: None, // FTS doesn't return headings + snippet: fr.snippet, + docid, + }, + ); + } + } + + let mut fts_results: Vec = fts_by_file.into_values().collect(); + fts_results.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // --- RRF Fusion --- + const RRF_K: usize = 60; + let fused = fusion::rrf_fuse( + &[ + ("semantic", &semantic_results, 1.0), + ("fts", &fts_results, 1.0), + ], + RRF_K, + ); + + // Convert to SearchResult, taking top_n. + let results: Vec = fused + .iter() + .take(top_n) + .map(|f| SearchResult { + score: f.rrf_score as f32, + file_path: f.file_path.clone(), + heading: f.heading.clone(), + snippet: f.snippet.clone(), + docid: f.docid.clone(), + }) + .collect(); + + let mut output = format_results(&results, json); + + if explain && !json { + // Append explain info after results. + let mut explain_out = String::from("\n--- Explain ---\n"); + for f in fused.iter().take(top_n) { + explain_out.push_str(&format!("{}\n", f.file_path)); + explain_out.push_str(&fusion::format_explain(f)); } + output.push_str(&explain_out); } - let output = format_results(&results, json); print!("{output}"); Ok(()) } @@ -82,7 +186,11 @@ pub fn run_status(json: bool, data_dir: &Path) -> Result<()> { /// Format search results for display (pure function, no I/O). pub fn format_results(results: &[SearchResult], json: bool) -> String { if results.is_empty() { - return "No results found.\n".to_string(); + return if json { + "[]\n".to_string() + } else { + "No results found.\n".to_string() + }; } if json { @@ -98,6 +206,7 @@ pub fn format_results(results: &[SearchResult], json: bool) -> String { "file": r.file_path, "heading": r.heading, "snippet": r.snippet, + "docid": r.docid, }) }) .collect(); @@ -109,13 +218,18 @@ pub fn format_results(results: &[SearchResult], json: bool) -> String { Some(h) => format!(" > {h}"), None => String::new(), }; + let docid_part = match &r.docid { + Some(d) => format!(" #{d}"), + None => String::new(), + }; let snippet = truncate_snippet(&r.snippet, 200); out.push_str(&format!( - "{:>2}. [{:.2}] {}{}\n {}\n", + "{:>2}. [{:.2}] {}{}{}\n {}\n", i + 1, r.score, r.file_path, heading_part, + docid_part, snippet, )); } @@ -219,6 +333,23 @@ mod tests { file_path: "foo.md".to_string(), heading: Some("## Bar".to_string()), snippet: "Some text...".to_string(), + docid: Some("ab12cd".to_string()), + }]; + let output = format_results(&results, false); + assert_eq!( + output, + " 1. [0.87] foo.md > ## Bar #ab12cd\n Some text...\n" + ); + } + + #[test] + fn test_format_human_result_no_docid() { + let results = vec![SearchResult { + score: 0.87, + file_path: "foo.md".to_string(), + heading: Some("## Bar".to_string()), + snippet: "Some text...".to_string(), + docid: None, }]; let output = format_results(&results, false); assert_eq!(output, " 1. [0.87] foo.md > ## Bar\n Some text...\n"); @@ -231,6 +362,7 @@ mod tests { file_path: "foo.md".to_string(), heading: Some("## Bar".to_string()), snippet: "Some text...".to_string(), + docid: Some("ab12cd".to_string()), }]; let output = format_results(&results, true); let parsed: Vec = serde_json::from_str(&output).unwrap(); @@ -240,6 +372,7 @@ mod tests { assert_eq!(parsed[0]["file"], "foo.md"); assert_eq!(parsed[0]["heading"], "## Bar"); assert_eq!(parsed[0]["snippet"], "Some text..."); + assert_eq!(parsed[0]["docid"], "ab12cd"); } #[test] @@ -248,7 +381,7 @@ mod tests { assert_eq!(output, "No results found.\n"); let json_output = format_results(&[], true); - assert_eq!(json_output, "No results found.\n"); + assert_eq!(json_output, "[]\n"); } #[test] diff --git a/src/store.rs b/src/store.rs index 92d7d83..30d1eef 100644 --- a/src/store.rs +++ b/src/store.rs @@ -12,6 +12,7 @@ pub struct FileRecord { pub mtime: i64, pub tags: Vec, pub indexed_at: String, + pub docid: Option, } /// A record representing a chunk of a file. @@ -25,6 +26,15 @@ pub struct ChunkRecord { pub token_count: i64, } +/// A single result from an FTS5 full-text search. +#[derive(Debug, Clone)] +pub struct FtsResult { + pub file_id: i64, + pub chunk_seq: i64, + pub score: f64, + pub snippet: String, +} + /// Summary statistics for the store. #[derive(Debug)] pub struct StoreStats { @@ -49,7 +59,8 @@ CREATE TABLE IF NOT EXISTS files ( content_hash TEXT NOT NULL, mtime INTEGER NOT NULL, tags TEXT NOT NULL DEFAULT '[]', - indexed_at TEXT NOT NULL + indexed_at TEXT NOT NULL, + docid TEXT ); CREATE TABLE IF NOT EXISTS chunks ( @@ -95,6 +106,33 @@ impl Store { self.conn .execute_batch(SCHEMA) .context("failed to initialize schema")?; + self.migrate()?; + self.ensure_fts_table()?; + Ok(()) + } + + /// Run migrations for existing databases that may be missing newer columns. + fn migrate(&self) -> Result<()> { + // Check if docid column exists on files table. + let has_docid: bool = { + let mut stmt = self.conn.prepare("PRAGMA table_info(files)")?; + let rows = stmt.query_map([], |row| row.get::<_, String>(1))?; + let mut found = false; + for row in rows { + if row.as_deref() == Ok("docid") { + found = true; + break; + } + } + found + }; + if !has_docid { + self.conn + .execute_batch("ALTER TABLE files ADD COLUMN docid TEXT;")?; + } + // Always ensure the index exists (safe for both fresh and migrated DBs). + self.conn + .execute_batch("CREATE INDEX IF NOT EXISTS idx_files_docid ON files(docid);")?; Ok(()) } @@ -120,25 +158,38 @@ impl Store { // ── Files ─────────────────────────────────────────────────── - pub fn insert_file(&self, path: &str, hash: &str, mtime: i64, tags: &[String]) -> Result { + pub fn insert_file( + &self, + path: &str, + hash: &str, + mtime: i64, + tags: &[String], + docid: &str, + ) -> Result { let tags_json = serde_json::to_string(tags).unwrap_or_else(|_| "[]".into()); let now = chrono_now(); self.conn.execute( - "INSERT INTO files (path, content_hash, mtime, tags, indexed_at) - VALUES (?1, ?2, ?3, ?4, ?5) + "INSERT INTO files (path, content_hash, mtime, tags, indexed_at, docid) + VALUES (?1, ?2, ?3, ?4, ?5, ?6) ON CONFLICT(path) DO UPDATE SET content_hash = excluded.content_hash, mtime = excluded.mtime, tags = excluded.tags, - indexed_at = excluded.indexed_at", - params![path, hash, mtime, tags_json, now], + indexed_at = excluded.indexed_at, + docid = excluded.docid", + params![path, hash, mtime, tags_json, now, docid], + )?; + let file_id: i64 = self.conn.query_row( + "SELECT id FROM files WHERE path = ?1", + params![path], + |row| row.get(0), )?; - Ok(self.conn.last_insert_rowid()) + Ok(file_id) } pub fn get_file(&self, path: &str) -> Result> { let mut stmt = self.conn.prepare( - "SELECT id, path, content_hash, mtime, tags, indexed_at FROM files WHERE path = ?1", + "SELECT id, path, content_hash, mtime, tags, indexed_at, docid FROM files WHERE path = ?1", )?; let mut rows = stmt.query_map(params![path], |row| { Ok(FileRecord { @@ -148,6 +199,7 @@ impl Store { mtime: row.get(3)?, tags: parse_tags(&row.get::<_, String>(4)?), indexed_at: row.get(5)?, + docid: row.get(6)?, }) })?; match rows.next() { @@ -159,7 +211,7 @@ impl Store { pub fn get_all_files(&self) -> Result> { let mut stmt = self .conn - .prepare("SELECT id, path, content_hash, mtime, tags, indexed_at FROM files")?; + .prepare("SELECT id, path, content_hash, mtime, tags, indexed_at, docid FROM files")?; let rows = stmt.query_map([], |row| { Ok(FileRecord { id: row.get(0)?, @@ -168,6 +220,7 @@ impl Store { mtime: row.get(3)?, tags: parse_tags(&row.get::<_, String>(4)?), indexed_at: row.get(5)?, + docid: row.get(6)?, }) })?; let mut files = Vec::new(); @@ -358,6 +411,124 @@ impl Store { } } + /// Look up a file record by its row ID. + pub fn get_file_by_id(&self, file_id: i64) -> Result> { + let mut stmt = self.conn.prepare( + "SELECT id, path, content_hash, mtime, tags, indexed_at, docid FROM files WHERE id = ?1", + )?; + let mut rows = stmt.query_map(params![file_id], |row| { + Ok(FileRecord { + id: row.get(0)?, + path: row.get(1)?, + content_hash: row.get(2)?, + mtime: row.get(3)?, + tags: parse_tags(&row.get::<_, String>(4)?), + indexed_at: row.get(5)?, + docid: row.get(6)?, + }) + })?; + match rows.next() { + Some(rec) => Ok(Some(rec?)), + None => Ok(None), + } + } + + /// Look up a file by its 6-character docid. + pub fn get_file_by_docid(&self, docid: &str) -> Result> { + let mut stmt = self.conn.prepare( + "SELECT id, path, content_hash, mtime, tags, indexed_at, docid FROM files WHERE docid = ?1", + )?; + let mut rows = stmt.query_map(params![docid], |row| { + Ok(FileRecord { + id: row.get(0)?, + path: row.get(1)?, + content_hash: row.get(2)?, + mtime: row.get(3)?, + tags: parse_tags(&row.get::<_, String>(4)?), + indexed_at: row.get(5)?, + docid: row.get(6)?, + }) + })?; + match rows.next() { + Some(rec) => Ok(Some(rec?)), + None => Ok(None), + } + } + + // ── FTS5 ────────────────────────────────────────────────── + + /// Ensure the FTS5 virtual table exists. Called during init. + pub fn ensure_fts_table(&self) -> Result<()> { + self.conn + .execute_batch( + "CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5( + content, + file_id UNINDEXED, + chunk_seq UNINDEXED + );", + ) + .context("failed to create FTS5 virtual table")?; + Ok(()) + } + + /// Insert a chunk's text into the FTS5 table. + pub fn insert_fts_chunk(&self, file_id: i64, chunk_seq: i64, text: &str) -> Result<()> { + self.conn.execute( + "INSERT INTO chunks_fts (content, file_id, chunk_seq) VALUES (?1, ?2, ?3)", + params![text, file_id, chunk_seq], + )?; + Ok(()) + } + + /// Delete all FTS5 entries for a file. + pub fn delete_fts_chunks_for_file(&self, file_id: i64) -> Result<()> { + self.conn.execute( + "DELETE FROM chunks_fts WHERE file_id = ?1", + params![file_id], + )?; + Ok(()) + } + + /// Search the FTS5 index. Returns results ranked by BM25 score. + /// BM25 in SQLite returns negative values (more negative = better match), + /// so we negate them to get positive scores where higher = better. + /// + /// The query is wrapped in double quotes so that FTS5 treats it as a + /// phrase/literal rather than interpreting operators like `-`. + pub fn fts_search(&self, query: &str, limit: usize) -> Result> { + // Escape any double quotes in the query, then wrap in double quotes + // so FTS5 treats hyphens etc. as literal characters. + let escaped = query.replace('"', "\"\""); + let fts_query = format!("\"{}\"", escaped); + + let mut stmt = self.conn.prepare( + "SELECT file_id, chunk_seq, bm25(chunks_fts) as score, + snippet(chunks_fts, 0, '', '', '...', 64) + FROM chunks_fts + WHERE chunks_fts MATCH ?1 + ORDER BY score + LIMIT ?2", + )?; + + let rows = stmt.query_map(params![fts_query, limit as i64], |row| { + Ok(FtsResult { + file_id: row.get(0)?, + chunk_seq: row.get(1)?, + score: { + let raw: f64 = row.get(2)?; + -raw // negate: SQLite BM25 returns negative, more negative = better + }, + snippet: row.get(3)?, + }) + })?; + + let mut results = Vec::new(); + for row in rows { + results.push(row?); + } + Ok(results) + } + /// Return vector_ids for all chunks belonging to a file. /// Useful for tombstoning before re-indexing a changed file. pub fn get_vector_ids_for_file(&self, file_id: i64) -> Result> { @@ -394,6 +565,7 @@ fn chrono_now() -> String { #[cfg(test)] mod tests { use super::*; + use crate::docid::generate_docid; #[test] fn test_create_schema() { @@ -417,8 +589,9 @@ mod tests { fn test_insert_and_get_file() { let store = Store::open_memory().unwrap(); let tags = vec!["rust".to_string(), "programming".to_string()]; + let docid = generate_docid("notes/test.md"); let file_id = store - .insert_file("notes/test.md", "abc123", 1700000000, &tags) + .insert_file("notes/test.md", "abc123", 1700000000, &tags, &docid) .unwrap(); assert!(file_id > 0); @@ -427,13 +600,20 @@ mod tests { assert_eq!(rec.content_hash, "abc123"); assert_eq!(rec.mtime, 1700000000); assert_eq!(rec.tags, tags); + assert_eq!(rec.docid.unwrap(), docid); } #[test] fn test_insert_and_get_chunks() { let store = Store::open_memory().unwrap(); let file_id = store - .insert_file("notes/chunk_test.md", "hash1", 100, &[]) + .insert_file( + "notes/chunk_test.md", + "hash1", + 100, + &[], + &generate_docid("notes/chunk_test.md"), + ) .unwrap(); store @@ -457,7 +637,15 @@ mod tests { #[test] fn test_delete_file_cascades_chunks() { let store = Store::open_memory().unwrap(); - let file_id = store.insert_file("notes/del.md", "hash", 100, &[]).unwrap(); + let file_id = store + .insert_file( + "notes/del.md", + "hash", + 100, + &[], + &generate_docid("notes/del.md"), + ) + .unwrap(); store.insert_chunk(file_id, "H", "snippet", 10, 5).unwrap(); store .insert_chunk(file_id, "H2", "snippet2", 11, 6) @@ -497,8 +685,15 @@ mod tests { #[test] fn test_file_hash_changed() { let store = Store::open_memory().unwrap(); + let docid = generate_docid("notes/change.md"); let file_id = store - .insert_file("notes/change.md", "old_hash", 100, &["tag1".to_string()]) + .insert_file( + "notes/change.md", + "old_hash", + 100, + &["tag1".to_string()], + &docid, + ) .unwrap(); store.insert_chunk(file_id, "H", "text", 50, 10).unwrap(); store.insert_chunk(file_id, "H2", "text2", 51, 12).unwrap(); @@ -514,7 +709,13 @@ mod tests { store.delete_file(file_id).unwrap(); let new_file_id = store - .insert_file("notes/change.md", "new_hash", 200, &["tag1".to_string()]) + .insert_file( + "notes/change.md", + "new_hash", + 200, + &["tag1".to_string()], + &docid, + ) .unwrap(); store .insert_chunk(new_file_id, "H", "new text", 60, 15) @@ -556,4 +757,20 @@ mod tests { let st = store.stats().unwrap(); assert_eq!(st.vault_path.unwrap(), "/other/vault"); } + + #[test] + fn test_get_file_by_docid() { + let store = Store::open_memory().unwrap(); + let docid = generate_docid("notes/findme.md"); + store + .insert_file("notes/findme.md", "hash", 100, &[], &docid) + .unwrap(); + + let rec = store.get_file_by_docid(&docid).unwrap().unwrap(); + assert_eq!(rec.path, "notes/findme.md"); + assert_eq!(rec.docid.unwrap(), docid); + + // Non-existent docid returns None. + assert!(store.get_file_by_docid("ffffff").unwrap().is_none()); + } } diff --git a/tests/integration.rs b/tests/integration.rs index 0a13769..89923d1 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -7,6 +7,7 @@ use std::path::{Path, PathBuf}; use engraph::chunker::chunk_markdown; use engraph::config::Config; +use engraph::docid::generate_docid; use engraph::embedder::Embedder; use engraph::hnsw::HnswIndex; use engraph::indexer::{compute_file_hash, diff_vault, walk_vault}; @@ -100,7 +101,10 @@ fn index_vault(vault_path: &Path, data_dir: &Path, config: &Config, rebuild: boo let tags = parsed.tags; let chunks = parsed.chunks; - let file_id = store.insert_file(&rel_str, &hash, 0, &tags).unwrap(); + let docid = generate_docid(&rel_str); + let file_id = store + .insert_file(&rel_str, &hash, 0, &tags, &docid) + .unwrap(); for chunk in &chunks { let heading = chunk.heading.clone().unwrap_or_default();