From 05241c4bf20b57e725327419e073db6d7eae219b Mon Sep 17 00:00:00 2001 From: ndjama Date: Tue, 2 Jun 2026 07:50:26 +0200 Subject: [PATCH] feat(core): port server/tools_viz.py to backend-neutral queries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sixth step of the Kuzu -> DuckDB migration. The Mermaid / DOT diagram generators in tools_viz.py and the graph_stats / live_graph_stats MCP tools no longer emit Cypher. Both backends produce identical diagram input and identical node counts. What lands: - codegraph/core/protocol.py: two helper additions on GraphDB. - find_neighbors gets a limit= parameter for the bounded-traversal patterns in the viz tools (call graph, class hierarchy, etc. all take a max_nodes hint). - count_nodes(label, where=...) for the 6 "count of nodes" sites in graph_stats and live_graph_stats — single COUNT(*) per call on DuckDB, one cheap MATCH count() on Kuzu. - codegraph/server/tools_viz.py: every conn.execute() against the graph DB gone. _rows import gone. - _viz_file_imports: two find_neighbors calls (outgoing + incoming) replacing the UNION ALL pattern. Filtered or unfiltered via the src_key / no-anchor distinction. - _viz_call_graph and _viz_class_hierarchy: when symbol_name is supplied, the Cypher "WHERE caller.name = $n OR callee.name = $n" becomes two filtered find_neighbors calls unioned in Python. Cheap because LIMIT bounds each side. - _viz_file_symbols: three find_nodes calls (Function / Class / MdSection by file_path) with order_by=['start_line']. - _viz_doc_structure: one find_nodes call with file_path filter + order_by + limit. - _viz_full_overview: file aggregation by language done in Python (Counter) since GROUP BY isn't a protocol primitive yet. Top files by symbol density similarly aggregated from find_neighbors over DEFINES_FN edges. Both small-N enough that the Python pass is negligible. - graph_stats / live_graph_stats: 6 conn.execute count() calls each become count_nodes() helper calls. The SQLite FTS count in live_graph_stats stays as direct fts_conn .execute() — that's a different database and out of scope for the graph backend swap. Test count: 298 (unchanged; no new tests, the existing 298-test suite already covers find_neighbors with limit + count semantics through the helper parity sweep). End-to-end smoke on a 3-file Python repo: Kuzu: Files=5 Functions=2 Classes=2 IMPORTS=2 INHERITS=1 DuckDB: Files=5 Functions=2 Classes=2 IMPORTS=2 INHERITS=1 What's left for the migration: - cli/commands_monitor.py (12 stats queries — task #10) - federation.py (kuzu-direct child DB opens — task #11) --- codegraph/core/db_duckdb.py | 21 ++- codegraph/core/db_kuzu.py | 26 +++- codegraph/core/protocol.py | 9 ++ codegraph/server/tools_viz.py | 275 ++++++++++++++++++++-------------- 4 files changed, 214 insertions(+), 117 deletions(-) diff --git a/codegraph/core/db_duckdb.py b/codegraph/core/db_duckdb.py index 318de83..d21c410 100644 --- a/codegraph/core/db_duckdb.py +++ b/codegraph/core/db_duckdb.py @@ -300,6 +300,23 @@ def find_nodes( cols = [d[0] for d in (cursor.description or [])] return [dict(zip(cols, row)) for row in cursor.fetchall()] + def count_nodes(self, label: str, where: dict[str, Any] | None = None) -> int: + from codegraph.core.graph_model import NODES + + if label not in NODES: + raise ValueError(f"Unknown node label: {label!r}") + spec = NODES[label] + params: list[Any] = [] + clauses: list[str] = [] + if where: + for field, value in where.items(): + clauses.append(f"{field} = ?") + params.append(value) + where_clause = "WHERE " + " AND ".join(clauses) if clauses else "" + sql = f"SELECT count(*) FROM {spec.table} {where_clause}" + row = self._conn.execute(sql, params).fetchone() + return int(row[0]) if row else 0 + def find_nodes_without_incoming( self, label: str, @@ -356,6 +373,7 @@ def find_neighbors( return_src: list[str] | None = None, return_dst: list[str] | None = None, return_edge: list[str] | None = None, + limit: int | None = None, ) -> list[dict[str, Any]]: from codegraph.core.graph_model import EDGES, NODES @@ -396,9 +414,10 @@ def find_neighbors( select_clause = ", ".join(select_parts) where_clause = "WHERE " + " AND ".join(clauses) + limit_clause = f"LIMIT {int(limit)}" if limit else "" sql = ( f"SELECT {select_clause} FROM {edge.table} e, " - f"{src.table} a, {dst.table} b {where_clause}" + f"{src.table} a, {dst.table} b {where_clause} {limit_clause}" ) cursor = self._conn.execute(sql, params) diff --git a/codegraph/core/db_kuzu.py b/codegraph/core/db_kuzu.py index 6afba62..fbff4c9 100644 --- a/codegraph/core/db_kuzu.py +++ b/codegraph/core/db_kuzu.py @@ -267,6 +267,25 @@ def find_nodes( out.append({c.removeprefix("n."): v for c, v in zip(cols, row)}) return out + def count_nodes(self, label: str, where: dict[str, Any] | None = None) -> int: + from codegraph.core.graph_model import NODES + + if label not in NODES: + raise ValueError(f"Unknown node label: {label!r}") + params: dict[str, Any] = {} + clauses: list[str] = [] + if where: + for i, (field, value) in enumerate(where.items()): + bind = f"_w{i}" + clauses.append(f"n.{field} = ${bind}") + params[bind] = value + where_clause = "WHERE " + " AND ".join(clauses) if clauses else "" + cypher = f"MATCH (n:{label}) {where_clause} RETURN count(n) AS c" + result = self._inner.execute(cypher, params) if params else self._inner.execute(cypher) + if result.has_next(): + return int(result.get_next()[0]) + return 0 + def find_nodes_without_incoming( self, label: str, @@ -318,6 +337,7 @@ def find_neighbors( return_src: list[str] | None = None, return_dst: list[str] | None = None, return_edge: list[str] | None = None, + limit: int | None = None, ) -> list[dict[str, Any]]: from codegraph.core.graph_model import EDGES, NODES @@ -369,7 +389,11 @@ def find_neighbors( match_clause = "".join(match_parts) where_clause = "WHERE " + " AND ".join(clauses) if clauses else "" return_clause = ", ".join(return_parts) - cypher = f"MATCH {match_clause} {where_clause} RETURN {return_clause}" + limit_clause = f"LIMIT {int(limit)}" if limit else "" + cypher = ( + f"MATCH {match_clause} {where_clause} " + f"RETURN {return_clause} {limit_clause}" + ) result = self._inner.execute(cypher, params) if params else self._inner.execute(cypher) cols = result.get_column_names() diff --git a/codegraph/core/protocol.py b/codegraph/core/protocol.py index 190570b..0d4ee8c 100644 --- a/codegraph/core/protocol.py +++ b/codegraph/core/protocol.py @@ -188,17 +188,26 @@ def find_neighbors( return_src: list[str] | None = None, return_dst: list[str] | None = None, return_edge: list[str] | None = None, + limit: int | None = None, ) -> list[dict[str, Any]]: """Walk an edge type with optional anchoring on either side. Result dicts use prefixed keys: ``src_``, ``dst_``, ``edge_``. The caller picks which prefixed fields they want via the three ``return_*`` lists; the rest are dropped. + ``limit`` caps the row count. Used by find_callers, find_callees, imports_of, who_imports, and similar "walk an edge from / to a known anchor" tools. """ + def count_nodes(self, label: str, where: dict[str, Any] | None = None) -> int: + """Count nodes of ``label`` matching the optional ``where`` filter. + + Used by graph_stats and friends — cheaper than fetching every + row just to call ``len()``. + """ + def reach_via_edge( self, edge_type: str, diff --git a/codegraph/server/tools_viz.py b/codegraph/server/tools_viz.py index 89dd2bd..9a71bb2 100644 --- a/codegraph/server/tools_viz.py +++ b/codegraph/server/tools_viz.py @@ -13,7 +13,6 @@ import json import os -from codegraph.core.utils import rows as _rows from codegraph.core.utils import safe_id as _safe_id @@ -30,21 +29,31 @@ def _viz_file_imports(conn, file_path: str, max_nodes: int, fmt: str) -> str: if file_path: if not os.path.isabs(file_path) and _srv._root: file_path = str(_srv._root / file_path) - r = conn.execute( - """MATCH (src:File {path:$p})-[:IMPORTS]->(dep:File) - RETURN src.path AS src, dep.path AS tgt - UNION ALL - MATCH (up:File)-[:IMPORTS]->(src:File {path:$p}) - RETURN up.path AS src, src.path AS tgt""", - {"p": file_path}, + # Outgoing imports (file imports X) + incoming imports (X imports file) + outgoing = conn.find_neighbors( + "IMPORTS", + src_key=file_path, + return_src=["path"], + return_dst=["path"], ) - else: - r = conn.execute( - f"MATCH (src:File)-[:IMPORTS]->(tgt:File) " - f"RETURN src.path AS src, tgt.path AS tgt LIMIT {max_nodes * 2}", + incoming = conn.find_neighbors( + "IMPORTS", + dst_key=file_path, + return_src=["path"], + return_dst=["path"], ) + rows = [{"src": r["src_path"], "tgt": r["dst_path"]} for r in outgoing + incoming] + else: + rows = [ + {"src": r["src_path"], "tgt": r["dst_path"]} + for r in conn.find_neighbors( + "IMPORTS", + return_src=["path"], + return_dst=["path"], + limit=max_nodes * 2, + ) + ] - rows = _rows(r) if not rows: return "graph LR\n NO_IMPORTS[No import edges found]" @@ -74,23 +83,41 @@ def _viz_file_imports(conn, file_path: str, max_nodes: int, fmt: str) -> str: return "\n".join(lines) def _viz_call_graph(conn, symbol_name: str, max_nodes: int, fmt: str) -> str: + return_args = dict( + return_src=["name", "file_path"], + return_dst=["name", "file_path"], + ) if symbol_name: - r = conn.execute( - """MATCH (caller:Function)-[:CALLS]->(callee:Function) - WHERE caller.name = $n OR callee.name = $n - RETURN caller.name AS src, callee.name AS tgt, - caller.file_path AS src_file, callee.file_path AS tgt_file - LIMIT $lim""", - {"n": symbol_name, "lim": max_nodes * 2}, + # Cypher's "WHERE caller.name = $n OR callee.name = $n" → two + # filtered queries unioned in Python. Most names point to a + # handful of edges, so the duplication is cheap. + edges_a = conn.find_neighbors( + "CALLS", + src_where={"name": symbol_name}, + **return_args, + limit=max_nodes * 2, + ) + edges_b = conn.find_neighbors( + "CALLS", + dst_where={"name": symbol_name}, + **return_args, + limit=max_nodes * 2, ) + edges = edges_a + edges_b else: - r = conn.execute( - f"MATCH (caller:Function)-[:CALLS]->(callee:Function) " - f"RETURN caller.name AS src, callee.name AS tgt, " - f"caller.file_path AS src_file, callee.file_path AS tgt_file LIMIT {max_nodes * 2}", + edges = conn.find_neighbors( + "CALLS", **return_args, limit=max_nodes * 2, ) - rows = _rows(r) + rows = [ + { + "src": r["src_name"], + "tgt": r["dst_name"], + "src_file": r["src_file_path"], + "tgt_file": r["dst_file_path"], + } + for r in edges + ] if not rows: return "graph LR\n NO_CALLS[No call edges found]" @@ -118,23 +145,37 @@ def _viz_call_graph(conn, symbol_name: str, max_nodes: int, fmt: str) -> str: return "\n".join(lines) def _viz_class_hierarchy(conn, symbol_name: str, max_nodes: int, fmt: str) -> str: + return_args = dict( + return_src=["name", "file_path"], + return_dst=["name"], + ) if symbol_name: - r = conn.execute( - """MATCH (child:Class)-[:INHERITS]->(parent:Class) - WHERE child.name = $n OR parent.name = $n - RETURN child.name AS child, parent.name AS parent, - child.file_path AS child_file - LIMIT $lim""", - {"n": symbol_name, "lim": max_nodes}, + edges_a = conn.find_neighbors( + "INHERITS", + src_where={"name": symbol_name}, + **return_args, + limit=max_nodes, ) + edges_b = conn.find_neighbors( + "INHERITS", + dst_where={"name": symbol_name}, + **return_args, + limit=max_nodes, + ) + edges = edges_a + edges_b else: - r = conn.execute( - f"MATCH (child:Class)-[:INHERITS]->(parent:Class) " - f"RETURN child.name AS child, parent.name AS parent, " - f"child.file_path AS child_file LIMIT {max_nodes}", + edges = conn.find_neighbors( + "INHERITS", **return_args, limit=max_nodes, ) - rows = _rows(r) + rows = [ + { + "child": r["src_name"], + "parent": r["dst_name"], + "child_file": r["src_file_path"], + } + for r in edges + ] if not rows: return "graph BT\n NO_INHERITANCE[No inheritance edges found]" @@ -170,41 +211,39 @@ def _viz_file_symbols(conn, file_path: str, fmt: str) -> str: short = _short_path(file_path) file_id = _safe_id(short) - # Functions - r_fn = conn.execute( - "MATCH (fn:Function) WHERE fn.file_path = $p RETURN fn.name, fn.start_line ORDER BY fn.start_line", - {"p": file_path}, + fns = conn.find_nodes( + "Function", + where={"file_path": file_path}, + return_fields=["name", "start_line"], + order_by=["start_line"], ) - fns = _rows(r_fn) - - # Classes - r_cls = conn.execute( - "MATCH (c:Class) WHERE c.file_path = $p RETURN c.name, c.start_line ORDER BY c.start_line", - {"p": file_path}, + classes = conn.find_nodes( + "Class", + where={"file_path": file_path}, + return_fields=["name", "start_line"], + order_by=["start_line"], ) - classes = _rows(r_cls) - - # MdSections - r_md = conn.execute( - "MATCH (s:MdSection) WHERE s.file_path = $p RETURN s.title, s.level, s.start_line ORDER BY s.start_line", - {"p": file_path}, + sections = conn.find_nodes( + "MdSection", + where={"file_path": file_path}, + return_fields=["title", "level", "start_line"], + order_by=["start_line"], ) - sections = _rows(r_md) if fmt == "mermaid": lines = ["graph TD", f' {file_id}["{short}"]:::file'] for cls in classes: - cls_id = _safe_id(f"cls_{cls['c.name']}") - lines.append(f' {cls_id}["{cls["c.name"]} (L{cls["c.start_line"]})"]:::class') + cls_id = _safe_id(f"cls_{cls['name']}") + lines.append(f' {cls_id}["{cls["name"]} (L{cls["start_line"]})"]:::class') lines.append(f" {file_id} --> {cls_id}") for fn in fns: - fn_id = _safe_id(f"fn_{fn['fn.name']}_{fn['fn.start_line']}") - lines.append(f' {fn_id}["{fn["fn.name"]}() L{fn["fn.start_line"]}"]:::func') + fn_id = _safe_id(f"fn_{fn['name']}_{fn['start_line']}") + lines.append(f' {fn_id}["{fn["name"]}() L{fn["start_line"]}"]:::func') lines.append(f" {file_id} --> {fn_id}") for sec in sections: - sec_id = _safe_id(f"sec_{sec['s.title']}_{sec['s.start_line']}") - prefix = "#" * sec["s.level"] - lines.append(f' {sec_id}["{prefix} {sec["s.title"]} L{sec["s.start_line"]}"]:::doc') + sec_id = _safe_id(f"sec_{sec['title']}_{sec['start_line']}") + prefix = "#" * sec["level"] + lines.append(f' {sec_id}["{prefix} {sec["title"]} L{sec["start_line"]}"]:::doc') lines.append(f" {file_id} --> {sec_id}") lines.append(" classDef file fill:#e1f5fe,stroke:#0288d1") lines.append(" classDef class fill:#fff3e0,stroke:#f57c00") @@ -214,32 +253,34 @@ def _viz_file_symbols(conn, file_path: str, fmt: str) -> str: else: lines = ["digraph file_symbols {", " rankdir=TD;", f' "{short}" [shape=folder];'] for cls in classes: - lines.append(f' "{cls["c.name"]}" [shape=box,style=filled,fillcolor=lightyellow];') - lines.append(f' "{short}" -> "{cls["c.name"]}";') + lines.append(f' "{cls["name"]}" [shape=box,style=filled,fillcolor=lightyellow];') + lines.append(f' "{short}" -> "{cls["name"]}";') for fn in fns: - lines.append(f' "{fn["fn.name"]}" [shape=ellipse];') - lines.append(f' "{short}" -> "{fn["fn.name"]}";') + lines.append(f' "{fn["name"]}" [shape=ellipse];') + lines.append(f' "{short}" -> "{fn["name"]}";') lines.append("}") return "\n".join(lines) def _viz_doc_structure(conn, file_path: str, max_nodes: int, fmt: str) -> str: + common_fields = ["id", "title", "level", "start_line", "file_path"] if file_path: if not os.path.isabs(file_path) and _srv._root: file_path = str(_srv._root / file_path) - r = conn.execute( - "MATCH (s:MdSection) WHERE s.file_path = $p " - "RETURN s.id, s.title, s.level, s.start_line, s.file_path " - "ORDER BY s.start_line LIMIT $lim", - {"p": file_path, "lim": max_nodes}, + rows = conn.find_nodes( + "MdSection", + where={"file_path": file_path}, + return_fields=common_fields, + order_by=["start_line"], + limit=max_nodes, ) else: - r = conn.execute( - f"MATCH (s:MdSection) " - f"RETURN s.id, s.title, s.level, s.start_line, s.file_path " - f"ORDER BY s.file_path, s.start_line LIMIT {max_nodes}", + rows = conn.find_nodes( + "MdSection", + return_fields=common_fields, + order_by=["file_path", "start_line"], + limit=max_nodes, ) - rows = _rows(r) if not rows: return "graph TD\n NO_DOCS[No markdown sections found]" @@ -247,7 +288,7 @@ def _viz_doc_structure(conn, file_path: str, max_nodes: int, fmt: str) -> str: lines = ["graph TD"] by_file: dict[str, list] = {} for row in rows: - fp = _short_path(row["s.file_path"]) + fp = _short_path(row["file_path"]) by_file.setdefault(fp, []).append(row) for fp, secs in by_file.items(): @@ -255,11 +296,11 @@ def _viz_doc_structure(conn, file_path: str, max_nodes: int, fmt: str) -> str: lines.append(f' {fp_id}["{fp}"]:::file') prev_by_level: dict[int, str] = {} for sec in secs: - sec_id = _safe_id(f"s_{sec['s.start_line']}_{fp}") - prefix = "#" * sec["s.level"] - lines.append(f' {sec_id}["{prefix} {sec["s.title"]}"]:::h{min(sec["s.level"], 3)}') + sec_id = _safe_id(f"s_{sec['start_line']}_{fp}") + prefix = "#" * sec["level"] + lines.append(f' {sec_id}["{prefix} {sec["title"]}"]:::h{min(sec["level"], 3)}') parent_id = None - for lvl in range(sec["s.level"] - 1, 0, -1): + for lvl in range(sec["level"] - 1, 0, -1): if lvl in prev_by_level: parent_id = prev_by_level[lvl] break @@ -267,7 +308,7 @@ def _viz_doc_structure(conn, file_path: str, max_nodes: int, fmt: str) -> str: lines.append(f" {parent_id} --> {sec_id}") else: lines.append(f" {fp_id} --> {sec_id}") - prev_by_level[sec["s.level"]] = sec_id + prev_by_level[sec["level"]] = sec_id lines.append(" classDef file fill:#e1f5fe,stroke:#0288d1") lines.append(" classDef h1 fill:#c8e6c9,stroke:#2e7d32,stroke-width:2px") @@ -277,35 +318,45 @@ def _viz_doc_structure(conn, file_path: str, max_nodes: int, fmt: str) -> str: else: lines = ["digraph docs {", " rankdir=TD;"] for row in rows: - fp = _short_path(row["s.file_path"]) - lines.append(f' "{row["s.title"]}" [label="{"#" * row["s.level"]} {row["s.title"]}"];') + lines.append( + f' "{row["title"]}" [label="{"#" * row["level"]} {row["title"]}"];' + ) lines.append("}") return "\n".join(lines) def _viz_full_overview(conn, max_nodes: int, fmt: str) -> str: """High-level codebase overview: files grouped by language with counts.""" - r = conn.execute( - "MATCH (f:File) RETURN f.lang AS lang, count(f) AS cnt ORDER BY cnt DESC", - ) - lang_stats = _rows(r) - - fn_rows = _rows(conn.execute("MATCH (n:Function) RETURN count(n) AS c")) - fn_count = fn_rows[0]["c"] if fn_rows else 0 - - cls_rows = _rows(conn.execute("MATCH (n:Class) RETURN count(n) AS c")) - cls_count = cls_rows[0]["c"] if cls_rows else 0 - - md_rows = _rows(conn.execute("MATCH (n:MdSection) RETURN count(n) AS c")) - md_count = md_rows[0]["c"] if md_rows else 0 - - # Top files by symbol density - r5 = conn.execute( - "MATCH (f:File)-[:DEFINES_FN]->(fn:Function) " - "RETURN f.path AS file, f.lang AS lang, count(fn) AS fn_count " - "ORDER BY fn_count DESC LIMIT $lim", - {"lim": max_nodes}, + # Aggregate "files per language" in Python — small N, no benefit from + # GROUP BY at the backend. + from collections import Counter + + files = conn.find_nodes("File", return_fields=["path", "lang"]) + lang_counts = Counter(f.get("lang") for f in files) + lang_stats = [ + {"lang": lang, "cnt": cnt} + for lang, cnt in sorted(lang_counts.items(), key=lambda kv: -kv[1]) + ] + + fn_count = conn.count_nodes("Function") + cls_count = conn.count_nodes("Class") + md_count = conn.count_nodes("MdSection") + + # Top files by symbol density: count DEFINES_FN edges per file. + # find_neighbors gives us (file_path, function_id) pairs; tally. + defines = conn.find_neighbors( + "DEFINES_FN", return_src=["path", "lang"] ) - top_files = _rows(r5) + fn_per_file: dict[tuple[str, str | None], int] = {} + for r in defines: + key = (r["src_path"], r.get("src_lang")) + fn_per_file[key] = fn_per_file.get(key, 0) + 1 + top_files = sorted( + ( + {"file": path, "lang": lang, "fn_count": cnt} + for (path, lang), cnt in fn_per_file.items() + ), + key=lambda r: -r["fn_count"], + )[:max_nodes] if fmt == "mermaid": lines = ["graph TD"] @@ -415,13 +466,10 @@ def graph_stats() -> str: Useful to confirm the index is populated. """ conn = _get_conn() - stats = {} - for label in ("File", "Function", "Class", "TFResource", "TFVar", "MdSection"): - # Kuzu Cypher requires literal labels — safe: fixed allowlist - query = "MATCH (n:" + label + ") RETURN count(n) AS c" - r = conn.execute(query) - rows = _rows(r) - stats[label] = rows[0]["c"] if rows else 0 + stats = { + label: conn.count_nodes(label) + for label in ("File", "Function", "Class", "TFResource", "TFVar", "MdSection") + } return json.dumps(stats, indent=2) @mcp.tool() @@ -443,10 +491,7 @@ def live_graph_stats() -> str: nodes: dict[str, int] = {} for label in ("File", "Function", "Class", "TFResource", "TFVar", "MdSection"): try: - query = "MATCH (n:" + label + ") RETURN count(n) AS c" - r = conn.execute(query) - rows = _rows(r) - nodes[label] = rows[0]["c"] if rows else 0 + nodes[label] = conn.count_nodes(label) except Exception: nodes[label] = 0