Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion codegraph/core/db_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,23 @@ def find_nodes(
cols = [d[0] for d in (cursor.description or [])]
return [dict(zip(cols, row)) for row in cursor.fetchall()]

def count_nodes(self, label: str, where: dict[str, Any] | None = None) -> int:
from codegraph.core.graph_model import NODES

if label not in NODES:
raise ValueError(f"Unknown node label: {label!r}")
spec = NODES[label]
params: list[Any] = []
clauses: list[str] = []
if where:
for field, value in where.items():
clauses.append(f"{field} = ?")
params.append(value)
where_clause = "WHERE " + " AND ".join(clauses) if clauses else ""
sql = f"SELECT count(*) FROM {spec.table} {where_clause}"
row = self._conn.execute(sql, params).fetchone()
return int(row[0]) if row else 0

def find_nodes_without_incoming(
self,
label: str,
Expand Down Expand Up @@ -356,6 +373,7 @@ def find_neighbors(
return_src: list[str] | None = None,
return_dst: list[str] | None = None,
return_edge: list[str] | None = None,
limit: int | None = None,
) -> list[dict[str, Any]]:
from codegraph.core.graph_model import EDGES, NODES

Expand Down Expand Up @@ -396,9 +414,10 @@ def find_neighbors(

select_clause = ", ".join(select_parts)
where_clause = "WHERE " + " AND ".join(clauses)
limit_clause = f"LIMIT {int(limit)}" if limit else ""
sql = (
f"SELECT {select_clause} FROM {edge.table} e, "
f"{src.table} a, {dst.table} b {where_clause}"
f"{src.table} a, {dst.table} b {where_clause} {limit_clause}"
)

cursor = self._conn.execute(sql, params)
Expand Down
26 changes: 25 additions & 1 deletion codegraph/core/db_kuzu.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,25 @@ def find_nodes(
out.append({c.removeprefix("n."): v for c, v in zip(cols, row)})
return out

def count_nodes(self, label: str, where: dict[str, Any] | None = None) -> int:
from codegraph.core.graph_model import NODES

if label not in NODES:
raise ValueError(f"Unknown node label: {label!r}")
params: dict[str, Any] = {}
clauses: list[str] = []
if where:
for i, (field, value) in enumerate(where.items()):
bind = f"_w{i}"
clauses.append(f"n.{field} = ${bind}")
params[bind] = value
where_clause = "WHERE " + " AND ".join(clauses) if clauses else ""
cypher = f"MATCH (n:{label}) {where_clause} RETURN count(n) AS c"
result = self._inner.execute(cypher, params) if params else self._inner.execute(cypher)
if result.has_next():
return int(result.get_next()[0])
return 0

def find_nodes_without_incoming(
self,
label: str,
Expand Down Expand Up @@ -318,6 +337,7 @@ def find_neighbors(
return_src: list[str] | None = None,
return_dst: list[str] | None = None,
return_edge: list[str] | None = None,
limit: int | None = None,
) -> list[dict[str, Any]]:
from codegraph.core.graph_model import EDGES, NODES

Expand Down Expand Up @@ -369,7 +389,11 @@ def find_neighbors(
match_clause = "".join(match_parts)
where_clause = "WHERE " + " AND ".join(clauses) if clauses else ""
return_clause = ", ".join(return_parts)
cypher = f"MATCH {match_clause} {where_clause} RETURN {return_clause}"
limit_clause = f"LIMIT {int(limit)}" if limit else ""
cypher = (
f"MATCH {match_clause} {where_clause} "
f"RETURN {return_clause} {limit_clause}"
)

result = self._inner.execute(cypher, params) if params else self._inner.execute(cypher)
cols = result.get_column_names()
Expand Down
9 changes: 9 additions & 0 deletions codegraph/core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,26 @@ def find_neighbors(
return_src: list[str] | None = None,
return_dst: list[str] | None = None,
return_edge: list[str] | None = None,
limit: int | None = None,
) -> list[dict[str, Any]]:
"""Walk an edge type with optional anchoring on either side.

Result dicts use prefixed keys: ``src_<field>``, ``dst_<field>``,
``edge_<field>``. The caller picks which prefixed fields they
want via the three ``return_*`` lists; the rest are dropped.
``limit`` caps the row count.

Used by find_callers, find_callees, imports_of, who_imports,
and similar "walk an edge from / to a known anchor" tools.
"""

def count_nodes(self, label: str, where: dict[str, Any] | None = None) -> int:
"""Count nodes of ``label`` matching the optional ``where`` filter.

Used by graph_stats and friends — cheaper than fetching every
row just to call ``len()``.
"""

def reach_via_edge(
self,
edge_type: str,
Expand Down
Loading
Loading