Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/__tests__/search.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, it, expect } from "vitest";
import { dotProduct, searchEmbeddings } from "../search";
import type { EmbeddingRecord } from "../storage";
import { dotProduct, searchEmbeddings, searchPageEmbeddings } from "../search";
import type { EmbeddingRecord, PageEmbeddingRecord } from "../storage";

describe("dotProduct", () => {
it("computes dot product correctly", () => {
Expand Down Expand Up @@ -53,3 +53,30 @@ describe("searchEmbeddings", () => {
expect(results).toHaveLength(1);
});
});

// Unit tests for page-level search: scoring, ordering, topK truncation and threshold filtering.
describe("searchPageEmbeddings", () => {
  // Fixture: A lies on the x axis, B on the y axis, C sits between the two.
  // Against the query [1, 0, 0] the scores are A = 1, C = 0.7, B = 0.
  const pages: PageEmbeddingRecord[] = [
    { pageId: 1, pageName: "Page A", embedding: [1, 0, 0], isJournal: false, blockCount: 5, timestamp: 0 },
    { pageId: 2, pageName: "Page B", embedding: [0, 1, 0], isJournal: true, blockCount: 3, timestamp: 0 },
    { pageId: 3, pageName: "Page C", embedding: [0.7, 0.7, 0], isJournal: false, blockCount: 2, timestamp: 0 },
  ];

  it("returns page results sorted by similarity", () => {
    const results = searchPageEmbeddings([1, 0, 0], pages, 10, 0);
    expect(results[0].pageName).toBe("Page A");
    expect(results[0].similarity).toBe(1);
    expect(results[1].pageName).toBe("Page C");
  });

  it("includes isJournal flag", () => {
    const results = searchPageEmbeddings([0, 1, 0], pages, 10, 0);
    expect(results[0].pageName).toBe("Page B");
    expect(results[0].isJournal).toBe(true);
  });

  it("respects topK and threshold", () => {
    // topK: both A (1.0) and C (0.7) clear the 0.5 threshold, but only the best one is kept.
    const results = searchPageEmbeddings([1, 0, 0], pages, 1, 0.5);
    expect(results).toHaveLength(1);
    expect(results[0].pageName).toBe("Page A");

    // threshold: with a generous topK, B (0.0) must be dropped purely by the threshold.
    const aboveThreshold = searchPageEmbeddings([1, 0, 0], pages, 10, 0.5);
    expect(aboveThreshold.map((r) => r.pageName)).toEqual(["Page A", "Page C"]);
  });
});
25 changes: 25 additions & 0 deletions src/__tests__/storage.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import {
getAllEmbeddings,
deleteEmbeddings,
clearAllEmbeddings,
putPageEmbeddings,
getAllPageEmbeddings,
clearAllPageEmbeddings,
getMetadata,
setMetadata,
getEmbeddingCount,
Expand Down Expand Up @@ -82,6 +85,28 @@ describe("embeddings CRUD", () => {
});
});

// Round-trip tests for the page-level embedding store (write, read back, clear).
describe("page embeddings CRUD", () => {
  it("stores and retrieves page embeddings", async () => {
    const pageA = { pageId: 1, pageName: "Page A", embedding: [0.1, 0.2], isJournal: false, blockCount: 5, timestamp: Date.now() };
    const pageB = { pageId: 2, pageName: "Page B", embedding: [0.3, 0.4], isJournal: true, blockCount: 3, timestamp: Date.now() };

    await putPageEmbeddings([pageA, pageB]);

    const stored = await getAllPageEmbeddings();
    expect(stored).toHaveLength(2);
    const [first, second] = stored;
    expect(first.pageName).toBe("Page A");
    expect(second.isJournal).toBe(true);
  });

  it("clears all page embeddings", async () => {
    const onlyPage = { pageId: 1, pageName: "Page A", embedding: [0.1], isJournal: false, blockCount: 2, timestamp: 0 };
    await putPageEmbeddings([onlyPage]);

    await clearAllPageEmbeddings();

    const remaining = await getAllPageEmbeddings();
    expect(remaining).toHaveLength(0);
  });
});

describe("metadata", () => {
it("stores and retrieves metadata", async () => {
await setMetadata("model", "nomic-embed-text");
Expand Down
63 changes: 62 additions & 1 deletion src/indexer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import {
getAllEmbeddings,
deleteEmbeddings,
clearAllEmbeddings,
clearAllPageEmbeddings,
putPageEmbeddings,
getMetadata,
setMetadata,
getEmbeddingCount,
} from "./storage";
import type { PageEmbeddingRecord } from "./storage";
import { getSettings } from "./settings";

export interface IndexingState {
Expand Down Expand Up @@ -66,6 +69,7 @@ interface PageResult {
"original-name"?: string;
properties?: Record<string, any>;
"updated-at"?: number;
"journal?": boolean;
}

const SCHEMA_VERSION = 4;
Expand All @@ -74,6 +78,7 @@ interface PageInfo {
name: string;
properties: Record<string, any>;
updatedAt: number;
isJournal: boolean;
}

interface BlockInfo {
Expand Down Expand Up @@ -145,6 +150,7 @@ export async function indexBlocks(
const storedSchema = await getMetadata("schemaVersion");
if (!storedSchema || (storedSchema as number) < SCHEMA_VERSION) {
await clearAllEmbeddings();
await clearAllPageEmbeddings();
logseq.UI.showMsg("Embedding format changed, re-indexing all blocks...");
await setMetadata("schemaVersion", SCHEMA_VERSION);
}
Expand All @@ -153,6 +159,7 @@ export async function indexBlocks(
const storedModel = await getMetadata("model");
if (storedModel && storedModel !== settings.embeddingModel) {
await clearAllEmbeddings();
await clearAllPageEmbeddings();
logseq.UI.showMsg("Model changed, re-indexing all blocks...");
}
await setMetadata("model", settings.embeddingModel);
Expand All @@ -163,7 +170,7 @@ export async function indexBlocks(
const blockResults: BlockResult[][] = await logseq.DB.datascriptQuery(blockQuery);

// Bulk-fetch all pages
const pageQuery = `[:find (pull ?p [:db/id :block/name :block/original-name :block/properties :block/updated-at])
const pageQuery = `[:find (pull ?p [:db/id :block/name :block/original-name :block/properties :block/updated-at :block/journal?])
:where [?p :block/name _]]`;
const pageResults: PageResult[][] = await logseq.DB.datascriptQuery(pageQuery);

Expand Down Expand Up @@ -193,6 +200,7 @@ export async function indexBlocks(
name: page.originalName ?? page["original-name"] ?? page.name ?? "",
properties: page.properties ?? {},
updatedAt: page["updated-at"] ?? 0,
isJournal: page["journal?"] ?? false,
});
}

Expand Down Expand Up @@ -364,6 +372,59 @@ export async function indexBlocks(
await deleteEmbeddings(staleIds);
}

// Compute page centroids
if (!abort.signal.aborted) {
const allEmbs = staleIds.length > 0 ? await getAllEmbeddings() : allExisting;
const pageGroups = new Map<number, number[][]>();
for (const emb of allEmbs) {
let group = pageGroups.get(emb.pageId);
if (!group) {
group = [];
pageGroups.set(emb.pageId, group);
}
group.push(emb.embedding);
}

const pageRecords: PageEmbeddingRecord[] = [];
for (const [pageId, embeddings] of pageGroups) {
if (embeddings.length < 2) continue;
const pageInfo = pageMap.get(pageId);
if (!pageInfo) continue;

const dim = embeddings[0].length;
const centroid = new Array<number>(dim).fill(0);
for (const emb of embeddings) {
for (let i = 0; i < dim; i++) {
centroid[i] += emb[i];
}
}
let norm = 0;
for (let i = 0; i < dim; i++) {
centroid[i] /= embeddings.length;
norm += centroid[i] * centroid[i];
}
norm = Math.sqrt(norm);
if (norm > 0) {
for (let i = 0; i < dim; i++) {
centroid[i] /= norm;
}
}

pageRecords.push({
pageId,
pageName: pageInfo.name,
embedding: centroid,
isJournal: pageInfo.isJournal,
blockCount: embeddings.length,
timestamp: Date.now(),
});
}

if (pageRecords.length > 0) {
await putPageEmbeddings(pageRecords);
}
}

const count = await getEmbeddingCount();
await setMetadata("blockCount", count);
await setMetadata("lastIndexed", Date.now());
Expand Down
3 changes: 2 additions & 1 deletion src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import "@logseq/libs";
import { settingsSchema } from "./settings";
import { getSettings } from "./settings";
import { indexBlocks } from "./indexer";
import { setGraphName, clearAllEmbeddings } from "./storage";
import { setGraphName, clearAllEmbeddings, clearAllPageEmbeddings } from "./storage";
import { createSearchModal, showModal } from "./ui";

async function main() {
Expand Down Expand Up @@ -45,6 +45,7 @@ async function main() {
// Register rebuild command
logseq.App.registerCommandPalette({ key: "rebuild-index", label: "Semantic Search: Rebuild index" }, async () => {
await clearAllEmbeddings();
await clearAllPageEmbeddings();
logseq.UI.showMsg("Rebuilding index...");
try {
await indexBlocks();
Expand Down
31 changes: 30 additions & 1 deletion src/search.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import type { EmbeddingRecord } from "./storage";
import type { EmbeddingRecord, PageEmbeddingRecord } from "./storage";

export interface SearchResult {
blockId: string;
pageId: number;
similarity: number;
}

/** A single page-level hit returned by `searchPageEmbeddings`. */
export interface PageSearchResult {
  /** Id of the matched page, copied from the scored page embedding record. */
  pageId: number;
  /** Name of the matched page, copied from the scored page embedding record. */
  pageName: string;
  /** Journal flag, copied from the scored page embedding record. */
  isJournal: boolean;
  /** Dot product of the query embedding and the page embedding. */
  similarity: number;
}

export function dotProduct(a: number[], b: number[]): number {
let sum = 0;
for (let i = 0; i < a.length; i++) {
Expand Down Expand Up @@ -34,3 +41,25 @@ export function searchEmbeddings(
scored.sort((a, b) => b.similarity - a.similarity);
return scored.slice(0, topK);
}

/**
 * Ranks page-level embeddings against a query vector.
 *
 * Every record is scored with `dotProduct(queryEmbedding, record.embedding)`.
 * Records scoring below `threshold` are discarded, the survivors are ordered
 * best-first and the list is cut to at most `topK` entries.
 *
 * @param queryEmbedding - Query vector to compare against each page embedding.
 * @param records - Page embedding records to score.
 * @param topK - Maximum number of results to return.
 * @param threshold - Minimum score (inclusive) a page needs to be kept; defaults to 0.3.
 * @returns Matching pages, highest similarity first.
 */
export function searchPageEmbeddings(
  queryEmbedding: number[],
  records: PageEmbeddingRecord[],
  topK: number,
  threshold = 0.3,
): PageSearchResult[] {
  const hits = records
    .map(
      ({ pageId, pageName, isJournal, embedding }): PageSearchResult => ({
        pageId,
        pageName,
        isJournal,
        similarity: dotProduct(queryEmbedding, embedding),
      }),
    )
    .filter((hit) => hit.similarity >= threshold);

  hits.sort((left, right) => right.similarity - left.similarity);
  return hits.slice(0, topK);
}
56 changes: 55 additions & 1 deletion src/storage.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
const DB_PREFIX = "semantic-search-embeddings";
const DB_VERSION = 1;
const DB_VERSION = 2;
const EMBEDDINGS_STORE = "embeddings";
const METADATA_STORE = "metadata";
const PAGE_EMBEDDINGS_STORE = "pageEmbeddings";

let graphName = "";
let pageCache: PageEmbeddingRecord[] | null = null;

export function setGraphName(name: string): void {
graphName = name;
Expand All @@ -22,6 +24,15 @@ export interface EmbeddingRecord {
pageUpdatedAt: number;
}

/** One row of the page embeddings object store; the store is created with `pageId` as its key path. */
export interface PageEmbeddingRecord {
  /** Page id; doubles as the primary key of the stored row. */
  pageId: number;
  /** Page name. */
  pageName: string;
  /** Page-level embedding vector. */
  embedding: number[];
  /** Whether the page is flagged as a journal page. */
  isJournal: boolean;
  /** Block count recorded for the page — presumably the number of block embeddings behind `embedding`; confirm against the writer. */
  blockCount: number;
  /** Numeric timestamp recorded with the row — presumably epoch milliseconds; confirm against the writer. */
  timestamp: number;
}

export interface MetadataRecord {
key: string;
value: string | number;
Expand All @@ -38,6 +49,9 @@ function openDB(): Promise<IDBDatabase> {
if (!db.objectStoreNames.contains(METADATA_STORE)) {
db.createObjectStore(METADATA_STORE, { keyPath: "key" });
}
if (!db.objectStoreNames.contains(PAGE_EMBEDDINGS_STORE)) {
db.createObjectStore(PAGE_EMBEDDINGS_STORE, { keyPath: "pageId" });
}
};
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
Expand Down Expand Up @@ -135,6 +149,46 @@ export async function clearAllEmbeddings(): Promise<void> {
});
}

/**
 * Upserts page embedding records into the page embeddings store in a single
 * readwrite transaction.
 *
 * On a successful commit the in-memory page cache is invalidated so the next
 * `getAllPageEmbeddings()` call re-reads from IndexedDB.
 *
 * @param records - Records to write; rows with an existing `pageId` are overwritten.
 * @returns A promise that resolves once the transaction has committed.
 * @throws Rejects with the transaction/request error if any write or the commit
 *   fails; nothing is persisted in that case.
 */
export async function putPageEmbeddings(records: PageEmbeddingRecord[]): Promise<void> {
  const db = await openDB();
  try {
    await new Promise<void>((resolve, reject) => {
      const tx = db.transaction(PAGE_EMBEDDINGS_STORE, "readwrite");
      const store = tx.objectStore(PAGE_EMBEDDINGS_STORE);

      tx.oncomplete = () => {
        pageCache = null;
        resolve();
      };
      tx.onerror = (event) => reject(tx.error ?? (event.target as IDBRequest).error);
      // A failed commit fires only "abort" (no "error" event); without this
      // handler the promise would never settle.
      tx.onabort = () => reject(tx.error ?? new DOMException("Transaction aborted", "AbortError"));

      try {
        for (const record of records) {
          store.put(record);
        }
      } catch (err) {
        // put() can throw synchronously (e.g. a record without a valid key).
        // Abort so the puts already queued are not auto-committed.
        reject(err);
        tx.abort();
      }
    });
  } finally {
    // Always release the connection, on success and on every failure path.
    db.close();
  }
}

/**
 * Returns every stored page embedding record.
 *
 * Results are memoized in the module-level `pageCache`: once populated (even
 * with an empty array) later calls return the cached array without touching
 * IndexedDB. The cached array itself is returned, so callers should treat it
 * as read-only.
 *
 * NOTE(review): the cache is a single module-level value and is not keyed by
 * graph name — confirm that switching graphs invalidates it.
 *
 * @returns All rows of the page embeddings store.
 * @throws Rejects with the request/transaction error if the read fails.
 */
export async function getAllPageEmbeddings(): Promise<PageEmbeddingRecord[]> {
  if (pageCache) return pageCache;
  const db = await openDB();
  try {
    return await new Promise<PageEmbeddingRecord[]>((resolve, reject) => {
      const tx = db.transaction(PAGE_EMBEDDINGS_STORE, "readonly");
      const req = tx.objectStore(PAGE_EMBEDDINGS_STORE).getAll();
      req.onsuccess = () => {
        const rows: PageEmbeddingRecord[] = req.result;
        pageCache = rows;
        resolve(rows);
      };
      req.onerror = () => reject(req.error);
      // Covers aborts that are not preceded by a request error.
      tx.onabort = () => reject(tx.error ?? new DOMException("Transaction aborted", "AbortError"));
    });
  } finally {
    // Close on every path; previously the connection was only closed when the
    // transaction completed, so a failed read leaked it.
    db.close();
  }
}

/**
 * Removes every row from the page embeddings store.
 *
 * On a successful commit the in-memory page cache is set to an empty array, so
 * a following `getAllPageEmbeddings()` returns `[]` without re-reading IndexedDB.
 *
 * @returns A promise that resolves once the clear has committed.
 * @throws Rejects with the transaction/request error if the clear or the commit fails.
 */
export async function clearAllPageEmbeddings(): Promise<void> {
  const db = await openDB();
  try {
    await new Promise<void>((resolve, reject) => {
      const tx = db.transaction(PAGE_EMBEDDINGS_STORE, "readwrite");
      tx.oncomplete = () => {
        pageCache = [];
        resolve();
      };
      tx.onerror = (event) => reject(tx.error ?? (event.target as IDBRequest).error);
      // A failed commit fires only "abort" (no "error" event); without this
      // handler the promise would never settle.
      tx.onabort = () => reject(tx.error ?? new DOMException("Transaction aborted", "AbortError"));
      tx.objectStore(PAGE_EMBEDDINGS_STORE).clear();
    });
  } finally {
    // Always release the connection, on success and on every failure path.
    db.close();
  }
}

export async function getEmbeddingCount(): Promise<number> {
const db = await openDB();
return new Promise((resolve, reject) => {
Expand Down
Loading
Loading