diff --git a/topoexplorer/d3_graph_html.py b/topoexplorer/d3_graph_html.py
index 2c5b6a68..3809c5b2 100644
--- a/topoexplorer/d3_graph_html.py
+++ b/topoexplorer/d3_graph_html.py
@@ -83,6 +83,7 @@ def build_standalone_d3_html(
header {{ padding: 12px 16px; background: #fff; border-bottom: 1px solid #ccc; }}
header h1 {{ margin: 0; font-size: 1.1rem; }}
header p {{ margin: 4px 0 0; color: #555; font-size: 0.85rem; }}
+ header p:empty {{ display: none; margin: 0; }}
#chart-wrap {{ position: relative; {chart_size_css} }}
#chart {{ width: 100%; height: 100%; min-height: inherit; }}
#err {{ color: #b00; padding: 16px; white-space: pre-wrap; font-family: monospace; }}
@@ -105,6 +106,20 @@ def build_standalone_d3_html(
#legend .legend-dot {{ display: inline-block; width: 12px; height: 12px;
border-radius: 50%; border: 1px solid rgba(0,0,0,0.18); flex-shrink: 0; }}
#legend .legend-label {{ color: #222; }}
+ #metrics-hud {{
+ display: none; position: absolute; top: 10px; left: 10px; z-index: 5;
+ background: rgba(255, 255, 255, 0.92); border: 1px solid #d0d0d0;
+ border-radius: 6px; padding: 8px 10px; font-size: 12px; line-height: 1.4;
+ box-shadow: 0 1px 4px rgba(0,0,0,0.08); pointer-events: none;
+ max-width: min(240px, 60%);
+ }}
+ #metrics-hud .hud-title {{ display: block; font-weight: 600; font-size: 11px;
+ color: #555; letter-spacing: 0.02em; text-transform: uppercase;
+ margin-bottom: 6px; }}
+ #metrics-hud .hud-row {{ display: flex; justify-content: space-between;
+ gap: 12px; margin: 2px 0; }}
+ #metrics-hud .hud-label {{ color: #555; }}
+ #metrics-hud .hud-value {{ color: #111; font-variant-numeric: tabular-nums; }}
@@ -114,6 +129,7 @@ def build_standalone_d3_html(
@@ -127,6 +143,20 @@ def build_standalone_d3_html(
function showError(msg) {
document.getElementById("err").textContent = msg;
}
+    // Render a metrics object as one "key: value" line per entry, each
+    // preceded by a line break; returns "" when there is nothing to show.
+    function metricLines(obj) {
+      if (!obj) { return ""; }
+      var names = Object.keys(obj);
+      if (names.length === 0) { return ""; }
+      var lines = names.map(function(name) { return name + ": " + obj[name]; });
+      return "\\n" + lines.join("\\n");
+    }
+    // Tooltip text for a node: its label (or id), its degree, then any
+    // per-node metric lines.
+    function nodeTooltip(d) {
+      var name = d.label || d.id;
+      var degree = d.degree || 0;
+      return name + " — degree " + degree + metricLines(d.metrics);
+    }
+    // Tooltip text for an edge: "source — target" plus any per-edge metric
+    // lines. An endpoint that is an object with a truthy id contributes that
+    // id; any other endpoint value is used as-is.
+    function edgeTooltip(l) {
+      function endpointId(end) {
+        return (end && end.id) ? end.id : end;
+      }
+      var s = endpointId(l.source);
+      var t = endpointId(l.target);
+      return s + " — " + t + metricLines(l.metrics);
+    }
try {
const raw = document.getElementById("graph-payload").textContent;
const payload = JSON.parse(raw);
@@ -154,6 +184,26 @@ def build_standalone_d3_html(
box.style.display = "block";
})();
+      // Fill the floating metrics HUD from payload.graphMetrics, or hide it
+      // when the HUD is switched off or there are no rows to show.
+      (function renderMetricsHud() {
+        var box = document.getElementById("metrics-hud");
+        if (!box) return;
+        var rows = payload.graphMetrics || [];
+        // Start from an empty box in both branches.
+        box.textContent = "";
+        if (!payload.showMetricsHud || rows.length === 0) {
+          box.style.display = "none";
+          return;
+        }
+        // Build elements and set textContent rather than concatenating into
+        // innerHTML, so labels and values are never parsed as markup.
+        // Class names match the #metrics-hud CSS rules (hud-title, hud-row,
+        // hud-label, hud-value).
+        function el(tag, cls, text) {
+          var node = document.createElement(tag);
+          node.className = cls;
+          node.textContent = text;
+          return node;
+        }
+        box.appendChild(el("span", "hud-title", "Displayed graph"));
+        rows.forEach(function(r) {
+          var row = el("div", "hud-row", "");
+          row.appendChild(el("span", "hud-label", r.label));
+          row.appendChild(el("span", "hud-value", r.value));
+          box.appendChild(row);
+        });
+        box.style.display = "block";
+      })();
+
const nodes = (payload.nodes || []).map(function(d) {
const o = Object.assign({}, d);
o.id = String(o.id);
@@ -190,7 +240,8 @@ def build_standalone_d3_html(
return {
source: String(l.source),
target: String(l.target),
- color: l.color
+ color: l.color,
+ metrics: l.metrics
};
});
const graph3d = ForceGraph3D()(chart)
@@ -199,9 +250,10 @@ def build_standalone_d3_html(
.backgroundColor("#fafafa")
.nodeRelSize(4)
.nodeColor(function(n) { return n.color || "#666"; })
- .nodeLabel(function(n) { return (n.label || n.id) + " — degree " + (n.degree || 0); })
+ .nodeLabel(function(n) { return nodeTooltip(n); })
.nodeVal(function(n) { return 1 + Math.log1p(n.degree || 1); })
.linkColor(function(l) { return l.color || "#888"; })
+ .linkLabel(function(l) { return edgeTooltip(l); })
.linkOpacity(0.6)
.linkWidth(0.6)
.graphData({ nodes: nodes3d, links: links3d });
@@ -234,7 +286,8 @@ def build_standalone_d3_html(
return {
source: s,
target: t,
- color: l.color
+ color: l.color,
+ metrics: l.metrics
};
}).filter(function(l) { return l.source && l.target; });
@@ -320,6 +373,8 @@ def build_standalone_d3_html(
.attr("stroke", function(d) { return d.color || "#888"; })
.attr("stroke-width", 1.2);
+ link.append("title").text(function(d) { return edgeTooltip(d); });
+
const node = g.append("g").selectAll("g")
.data(nodes)
.join("g")
@@ -344,9 +399,7 @@ def build_standalone_d3_html(
.attr("fill", function(d) { return d.color || "#666"; })
.attr("stroke", function(d) { return d.stroke || "#fff"; });
- node.append("title").text(function(d) {
- return (d.label || d.id) + " — degree " + (d.degree || 0);
- });
+ node.append("title").text(function(d) { return nodeTooltip(d); });
node.filter(function(d) { return String(d.id).length <= 12; })
.append("text")
diff --git a/topoexplorer/graph_metrics.py b/topoexplorer/graph_metrics.py
new file mode 100644
index 00000000..33b8fff5
--- /dev/null
+++ b/topoexplorer/graph_metrics.py
@@ -0,0 +1,363 @@
+"""Graph metrics for the Neighborhood Explorer (displayed-graph only).
+
+All metrics are computed on the *displayed* NetworkX graph (after sampling and
+lifting), on an undirected view so the same code path covers adjacency,
+incidence/bipartite, and layered views.
+
+Metric tiers
+------------
+- cheap (always): counts, density, degree stats, components, clustering,
+ transitivity, per-node degree centrality, per-edge Forman-Ricci.
+- default (size-gated, largest connected component): diameter, radius.
+- advanced (opt-in, expensive): betweenness, closeness, eccentricity,
+ eigenvector, pagerank, edge betweenness, degree assortativity.
+
+Forman-Ricci uses the Weber-Jost-Saucan (2018) definition; with unit weights
+(the displayed graph carries none) it reduces exactly to ``4 - deg(u) - deg(v)``
+(no triangle term), matching TopoBench's implementation.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+import networkx as nx
+
+# Size thresholds that keep metric computation affordable.
+# Above these sizes, distance measures (diameter/radius) are skipped by default
+# and betweenness is approximated via k-sampling.
+# Diameter/radius are skipped when the graph has more nodes than
+# DISTANCE_NODE_LIMIT or more edges than DISTANCE_EDGE_LIMIT.
+DISTANCE_NODE_LIMIT = 1200
+DISTANCE_EDGE_LIMIT = 6000
+# Node betweenness is exact up to BETWEENNESS_EXACT_LIMIT nodes; larger graphs
+# use at most BETWEENNESS_SAMPLE_K samples unless the caller passes its own
+# ``betweenness_k``.
+BETWEENNESS_EXACT_LIMIT = 800
+BETWEENNESS_SAMPLE_K = 500
+
+
+def _edge_key(u, v):
+    """Return a canonical ``(str, str)`` key for the undirected edge ``u``-``v``.
+
+    Both endpoints are converted with ``str`` and ordered by string
+    comparison, so ``(u, v)`` and ``(v, u)`` map to the same key.
+    """
+    first, second = sorted((str(u), str(v)))
+    return (first, second)
+
+
+def forman_ricci_unweighted(UG: nx.Graph, u, v) -> int:
+    """Return the unit-weight Forman-Ricci value of the edge ``u``-``v``.
+
+    This is the Weber-Jost-Saucan form with unit weights, which reduces to
+    ``4 - deg(u) - deg(v)``; degrees are read from ``UG.degree``.
+    """
+    endpoint_degree_sum = UG.degree(u) + UG.degree(v)
+    return 4 - endpoint_degree_sum
+
+
+def compute_graph_metrics(
+ G: nx.Graph,
+ *,
+ view_label: str = "graph",
+ expensive: bool = False,
+ betweenness_k: int | None = None,
+) -> dict[str, Any]:
+ """Compute displayed-graph metrics.
+
+ Parameters
+ ----------
+ G : networkx graph
+ The displayed graph (may be directed for incidence/bipartite views).
+ view_label : str
+ Human-readable view name (e.g. ``"adjacency"``) for UI captions.
+ expensive : bool
+ When True, also compute advanced per-node/edge centralities.
+ betweenness_k : int, optional
+ Sample size for approximate betweenness; defaults to an internal policy.
+
+ Returns
+ -------
+ dict
+ ``{"graph": {...}, "graph_scope": {...}, "nodes": {...}, "edges": {...},
+ "flags": {...}}``.
+
+ ``graph`` holds whole-graph values; ``graph_scope`` marks values that
+ describe only the largest connected component; ``nodes`` is keyed by
+ ``str(node)`` and ``edges`` by ``_edge_key(u, v)``; ``flags`` carries
+ ``view_label``, ``expensive``, ``betweenness_approx``,
+ ``distance_skipped`` and ``notes``.
+ """
+ notes: list[str] = []
+ # Everything below runs on an undirected graph; directed input is converted.
+ directed = G.is_directed()
+ UG = G.to_undirected() if directed else G
+ if directed:
+ notes.append(f"Computed on an undirected view of the {view_label} graph.")
+
+ n = UG.number_of_nodes()
+ m = UG.number_of_edges()
+ degrees = dict(UG.degree())
+ # Fall back to [0] so max()/min() below stay defined for a graph with no nodes.
+ deg_values = list(degrees.values()) if degrees else [0]
+
+ graph: dict[str, Any] = {
+ "nodes": n,
+ "edges": m,
+ "density": nx.density(UG) if n > 1 else 0.0,
+ "avg_degree": (sum(deg_values) / n) if n else 0.0,
+ "max_degree": max(deg_values),
+ "min_degree": min(deg_values),
+ }
+
+ components = list(nx.connected_components(UG))
+ graph["components"] = len(components)
+ largest_cc_nodes = max(components, key=len) if components else set()
+ graph["largest_cc_size"] = len(largest_cc_nodes)
+
+ # Best effort: a failure in either clustering measure is recorded as None.
+ try:
+ graph["avg_clustering"] = nx.average_clustering(UG) if n else 0.0
+ except Exception:
+ graph["avg_clustering"] = None
+ try:
+ graph["transitivity"] = nx.transitivity(UG) if n else 0.0
+ except Exception:
+ graph["transitivity"] = None
+
+ graph_scope: dict[str, str] = {}
+
+ # Default tier: distance measures on the largest connected component.
+ graph["diameter"] = None
+ graph["radius"] = None
+ # Size gate: skip the eccentricity pass when either limit is exceeded.
+ distance_skipped = n > DISTANCE_NODE_LIMIT or m > DISTANCE_EDGE_LIMIT
+ if distance_skipped:
+ # NOTE(review): _add_advanced_metrics in this module never fills
+ # graph["diameter"]/graph["radius"]; confirm the "enable advanced metrics"
+ # hint in the message below.
+ notes.append(
+ "Diameter/radius skipped (graph too large); enable advanced metrics "
+ "or reduce sampling."
+ )
+ elif largest_cc_nodes:
+ cc = UG.subgraph(largest_cc_nodes)
+ try:
+ ecc = nx.eccentricity(cc)
+ graph["diameter"] = max(ecc.values()) if ecc else None
+ graph["radius"] = min(ecc.values()) if ecc else None
+ # With several components the values describe only the largest one.
+ if graph["components"] > 1:
+ graph_scope["diameter"] = "largest CC"
+ graph_scope["radius"] = "largest CC"
+ except Exception:
+ notes.append("Diameter/radius unavailable for this graph.")
+
+ # Per-node cheap metrics.
+ nodes: dict[str, dict[str, Any]] = {}
+ deg_centrality = nx.degree_centrality(UG) if n else {}
+ try:
+ clustering = nx.clustering(UG)
+ except Exception:
+ clustering = {}
+ for node in UG.nodes():
+ key = str(node)
+ nodes[key] = {
+ "degree": degrees.get(node, 0),
+ "degree_centrality": deg_centrality.get(node),
+ "clustering": clustering.get(node),
+ }
+
+ # Per-edge cheap metrics (Forman-Ricci).
+ # Keys are the order-independent (str, str) tuples produced by _edge_key.
+ edges: dict[tuple, dict[str, Any]] = {}
+ for u, v in UG.edges():
+ edges[_edge_key(u, v)] = {"forman_ricci": forman_ricci_unweighted(UG, u, v)}
+
+ flags: dict[str, Any] = {
+ "view_label": view_label,
+ "expensive": bool(expensive),
+ "betweenness_approx": False,
+ "distance_skipped": distance_skipped,
+ "notes": notes,
+ }
+
+ # Advanced tier is opt-in and skipped for a graph with no nodes; it mutates
+ # nodes, edges, graph and flags in place.
+ if expensive and n:
+ _add_advanced_metrics(
+ UG, nodes, edges, graph, flags, betweenness_k=betweenness_k
+ )
+
+ return {
+ "graph": graph,
+ "graph_scope": graph_scope,
+ "nodes": nodes,
+ "edges": edges,
+ "flags": flags,
+ }
+
+
+def _add_advanced_metrics(UG, nodes, edges, graph, flags, *, betweenness_k=None):
+    """Augment node/edge/graph dicts with expensive centralities (in place).
+
+    Every measure is best effort: a failure leaves that measure as ``None``
+    for the affected entries instead of aborting the others.
+
+    Parameters
+    ----------
+    UG : networkx graph
+        Undirected graph the metrics are computed on.
+    nodes : dict
+        ``{str(node): {...}}``; each entry gains ``betweenness``,
+        ``closeness``, ``eccentricity``, ``eigenvector`` and ``pagerank``.
+    edges : dict
+        ``{_edge_key(u, v): {...}}``; entries gain ``edge_betweenness``.
+    graph : dict
+        Whole-graph metrics; gains ``assortativity``.
+    flags : dict
+        ``betweenness_approx`` is set and ``notes`` is extended here.
+    betweenness_k : int, optional
+        Sample size for approximate betweenness. ``None`` means exact up to
+        ``BETWEENNESS_EXACT_LIMIT`` nodes and ``BETWEENNESS_SAMPLE_K`` samples
+        above it.
+    """
+    n = UG.number_of_nodes()
+
+    # Betweenness: exact for small graphs, k-sampled otherwise.
+    if betweenness_k is None:
+        betweenness_k = (
+            None if n <= BETWEENNESS_EXACT_LIMIT else min(n, BETWEENNESS_SAMPLE_K)
+        )
+    sampled = betweenness_k is not None and betweenness_k < n
+    # Node and edge betweenness share one sampling policy: running the edge
+    # variant exactly on a graph judged too large for exact node betweenness
+    # would bring back the cost the policy exists to avoid.
+    sampling = {"k": betweenness_k, "seed": 0} if sampled else {}
+    try:
+        betw = nx.betweenness_centrality(UG, **sampling)
+    except Exception:
+        betw = {}
+    try:
+        ebetw = nx.edge_betweenness_centrality(UG, **sampling)
+    except Exception:
+        ebetw = {}
+    if sampled and (betw or ebetw):
+        flags["betweenness_approx"] = True
+        flags["notes"].append(
+            f"Betweenness approximated with k={betweenness_k} samples."
+        )
+
+    try:
+        closeness = nx.closeness_centrality(UG)
+    except Exception:
+        closeness = {}
+
+    # Eccentricity per node on the largest CC (finite only there).
+    eccentricity: dict[Any, Any] = {}
+    try:
+        components = list(nx.connected_components(UG))
+        if components:
+            cc_nodes = max(components, key=len)
+            eccentricity = nx.eccentricity(UG.subgraph(cc_nodes))
+    except Exception:
+        eccentricity = {}
+
+    try:
+        eigenvector = nx.eigenvector_centrality(UG, max_iter=500, tol=1e-04)
+    except nx.PowerIterationFailedConvergence:
+        eigenvector = {}
+        flags["notes"].append("Eigenvector centrality did not converge.")
+    except Exception:
+        # Any other failure is not a convergence problem; say so honestly.
+        eigenvector = {}
+        flags["notes"].append("Eigenvector centrality unavailable for this graph.")
+
+    try:
+        pagerank = nx.pagerank(UG)
+    except Exception:
+        pagerank = {}
+
+    for node in UG.nodes():
+        slot = nodes.setdefault(str(node), {})
+        slot["betweenness"] = betw.get(node)
+        slot["closeness"] = closeness.get(node)
+        slot["eccentricity"] = eccentricity.get(node)
+        slot["eigenvector"] = eigenvector.get(node)
+        slot["pagerank"] = pagerank.get(node)
+
+    for edge, val in ebetw.items():
+        # Use only the two endpoints so a longer edge tuple (e.g. one that
+        # carries a multigraph key) cannot break the unpacking.
+        u, v = edge[0], edge[1]
+        edges.setdefault(_edge_key(u, v), {})["edge_betweenness"] = val
+
+    try:
+        graph["assortativity"] = nx.degree_assortativity_coefficient(UG)
+    except Exception:
+        graph["assortativity"] = None
+
+
+# ---------------------------------------------------------------------------
+# Formatting helpers (kept here so the app stays lean and the format is shared)
+# ---------------------------------------------------------------------------
+
+def fmt_value(v: Any) -> str:
+    """Format a metric value as a short display string.
+
+    ``None`` becomes ``"N/A"``; ints get thousands separators; floats print as
+    ``"0"``, as up to three decimals with trailing zeros trimmed, or in
+    scientific notation when the magnitude is below 0.001 or at least 1000.
+    Anything else (including ``bool``) goes through ``str``.
+    """
+    if v is None:
+        return "N/A"
+    # bool is a subclass of int, so rule it out before the numeric branches.
+    if isinstance(v, bool) or not isinstance(v, (int, float)):
+        return str(v)
+    if isinstance(v, int):
+        return format(v, ",")
+    if v == 0:
+        return "0"
+    if 0.001 <= abs(v) < 1000:
+        return format(v, ".3f").rstrip("0").rstrip(".")
+    return format(v, ".2e")
+
+
+# Labels for whole-graph metrics, in display order.
+# Each entry is (key in the ``graph`` dict of compute_graph_metrics, label).
+# "assortativity" is only written by the advanced tier.
+GRAPH_METRIC_LABELS = [
+ ("nodes", "Nodes"),
+ ("edges", "Edges"),
+ ("density", "Density"),
+ ("avg_degree", "Avg degree"),
+ ("max_degree", "Max degree"),
+ ("min_degree", "Min degree"),
+ ("components", "Components"),
+ ("largest_cc_size", "Largest CC size"),
+ ("avg_clustering", "Avg clustering"),
+ ("transitivity", "Transitivity"),
+ ("diameter", "Diameter"),
+ ("radius", "Radius"),
+ ("assortativity", "Degree assortativity"),
+]
+
+# Compact subset for the floating HUD overlay.
+# build_hud_rows leaves out any key whose value is missing or None.
+HUD_METRIC_KEYS = ["nodes", "edges", "density", "avg_degree", "components", "diameter"]
+
+# (per-node metric key, short tooltip label), in tooltip order.
+# Eigenvector and pagerank are computed by the advanced tier but are not
+# listed here, so node_payload_metrics does not emit them.
+NODE_TOOLTIP_LABELS = [
+ ("degree_centrality", "deg cent"),
+ ("clustering", "clustering"),
+ ("betweenness", "betw"),
+ ("closeness", "closeness"),
+ ("eccentricity", "ecc"),
+]
+
+# (per-edge metric key, short tooltip label), in tooltip order.
+EDGE_TOOLTIP_LABELS = [
+ ("forman_ricci", "Forman"),
+ ("edge_betweenness", "edge betw"),
+]
+
+
+def build_hud_rows(metrics: dict) -> list[dict]:
+    """Build the ``[{label, value}]`` rows shown in the canvas HUD overlay.
+
+    Keys are taken from ``HUD_METRIC_KEYS`` in order; a key whose value is
+    missing or ``None`` yields no row. When ``graph_scope`` has an entry for a
+    key, that scope is appended to the label in parentheses.
+    """
+    graph_values = metrics.get("graph", {})
+    scopes = metrics.get("graph_scope", {})
+    display_names = dict(GRAPH_METRIC_LABELS)
+
+    def _row(key):
+        name = display_names.get(key, key)
+        if key in scopes:
+            name = f"{name} ({scopes[key]})"
+        return {"label": name, "value": fmt_value(graph_values[key])}
+
+    return [
+        _row(key) for key in HUD_METRIC_KEYS if graph_values.get(key) is not None
+    ]
+
+
+def node_payload_metrics(metrics: dict, node_id: str) -> dict:
+    """Return ``{tooltip label: formatted value}`` for one node.
+
+    The node is looked up by ``str(node_id)``. Only the metrics listed in
+    ``NODE_TOOLTIP_LABELS`` that are present and not ``None`` are included;
+    an unknown node gives ``{}``.
+    """
+    per_node = (metrics.get("nodes") or {}).get(str(node_id))
+    if not per_node:
+        return {}
+    return {
+        label: fmt_value(per_node[key])
+        for key, label in NODE_TOOLTIP_LABELS
+        if per_node.get(key) is not None
+    }
+
+
+def edge_payload_metrics(metrics: dict, u: str, v: str) -> dict:
+    """Return ``{tooltip label: formatted value}`` for the edge ``u``-``v``.
+
+    The edge is looked up by its order-independent ``_edge_key``. Only the
+    metrics listed in ``EDGE_TOOLTIP_LABELS`` that are present and not
+    ``None`` are included; an unknown edge gives ``{}``.
+    """
+    per_edge = (metrics.get("edges") or {}).get(_edge_key(u, v))
+    if not per_edge:
+        return {}
+    return {
+        label: fmt_value(per_edge[key])
+        for key, label in EDGE_TOOLTIP_LABELS
+        if per_edge.get(key) is not None
+    }
+
+
+def summarize_metrics(metrics: dict, centrality: str = "degree_centrality",
+                      top_k: int = 5) -> dict:
+    """Summarize top nodes and Forman-Ricci extremes for the Explore tab.
+
+    Parameters
+    ----------
+    metrics : dict
+        Result of ``compute_graph_metrics`` (only ``nodes`` and ``edges`` are
+        read; either may be missing or empty).
+    centrality : str
+        Per-node metric to rank by; nodes without a value are ignored.
+    top_k : int
+        Maximum number of entries per list. Zero or a negative value yields
+        empty lists.
+
+    Returns
+    -------
+    dict
+        ``centrality`` (the key used), ``top_nodes`` as ``(node_id, value)``
+        in descending order, and ``forman_most_negative`` /
+        ``forman_most_positive`` as ``(edge_key, value)`` ordered from the
+        respective extreme inwards.
+    """
+    nodes = metrics.get("nodes") or {}
+    edges = metrics.get("edges") or {}
+    # Clamp first: ``seq[-0:]`` is the whole list and a negative count would
+    # slice from the wrong end.
+    top_k = max(0, top_k)
+
+    ranked = sorted(
+        (
+            (nid, vals[centrality])
+            for nid, vals in nodes.items()
+            if vals.get(centrality) is not None
+        ),
+        key=lambda kv: kv[1],
+        reverse=True,
+    )
+    forman = sorted(
+        (
+            (key, vals["forman_ricci"])
+            for key, vals in edges.items()
+            if vals.get("forman_ricci") is not None
+        ),
+        key=lambda kv: kv[1],
+    )
+
+    return {
+        "centrality": centrality,
+        "top_nodes": ranked[:top_k],
+        "forman_most_negative": forman[:top_k],
+        "forman_most_positive": forman[-top_k:][::-1] if top_k else [],
+    }
diff --git a/topoexplorer/neighborhood_explorer_app.py b/topoexplorer/neighborhood_explorer_app.py
index 7c212865..66a56a17 100644
--- a/topoexplorer/neighborhood_explorer_app.py
+++ b/topoexplorer/neighborhood_explorer_app.py
@@ -30,6 +30,7 @@
import torch
from d3_graph_html import build_standalone_d3_html
+import graph_metrics as gm
from omegaconf import OmegaConf
from torch_geometric.utils import to_undirected
import rootutils
@@ -178,8 +179,9 @@ def extract_dataset_metadata(domain, dataset_name):
initial_sidebar_state="expanded"
)
-# ~2× Streamlit’s default sidebar width (~256px in recent releases).
-_SIDEBAR_TARGET_WIDTH_PX = 512
+# Fixed, slightly-wider sidebar. The native collapse arrow is kept as a
+# "focus mode" so the user can momentarily hide controls for a full-width graph.
+_SIDEBAR_TARGET_WIDTH_PX = 560
st.markdown(
f"""
""",
unsafe_allow_html=True,
@@ -1654,26 +1693,6 @@ def launch_html_in_browser(path: Path) -> bool:
return False
-def open_d3_graph_window(payload):
- """Write standalone HTML to a temp file and open in browser."""
- if payload is None:
- st.warning("No graph to display.")
- return
- sel_ids = st.session_state.get("selected_neighborhood_ids") or []
- marker = "+".join(sel_ids) if sel_ids else None
- html_doc = build_standalone_d3_html(payload, cache_marker=marker)
- st.session_state["_d3_last_html"] = html_doc
- path = Path(tempfile.gettempdir()) / f"topobench_graph_{uuid.uuid4().hex}.html"
- path.write_text(html_doc, encoding="utf-8")
- if launch_html_in_browser(path):
- st.success("Opened graph in your browser.")
- else:
- st.warning(
- "Could not launch a browser automatically. Use **Download last D3 graph** "
- "below and open the file in Edge or Chrome (not VS Code preview)."
- )
-
-
# ============================================================================
# Lifting Application
# ============================================================================
@@ -2038,7 +2057,7 @@ def render_basic_lifting_editor(selected_lifting):
if st.button(
"Reset edited config to defaults",
key=f"cfg_reset::{editor_id}",
- use_container_width=True,
+ width="stretch",
):
for k in BASIC_EDITABLE_KEYS:
wk = f"editcfg::{editor_id}::{k}"
@@ -2063,7 +2082,26 @@ def render_basic_lifting_editor(selected_lifting):
# Streamlit App
# ============================================================================
-D3_EMBED_HEIGHT = 760
+# Graph block height for ``components.html`` (header + chart inside the iframe).
+D3_EMBED_HEIGHT = 820
+
+
+def _format_graph_header_title(vdesc, dataset_name, lift_line, num_nodes) -> str:
+    """Build the one-line D3 embed header from view, dataset, lift and node count.
+
+    A ``num_nodes`` of ``None`` is shown as ``N/A``.
+    """
+    nodes_str = "N/A" if num_nodes is None else str(num_nodes)
+    segments = (
+        f"View: {vdesc} Dataset: {dataset_name}",
+        f"{lift_line}",
+        f"Nodes: {nodes_str}",
+    )
+    return " | ".join(segments)
+
+
+def _node_count_from_dataset(dset0):
+    """Return the node count for header display, or None if unavailable.
+
+    ``dset0.num_nodes`` is used when it is set. Otherwise the count is
+    inferred as the largest index in ``dset0.edge_index`` plus one. ``None``
+    is returned when ``dset0`` is ``None``, has no ``num_nodes`` attribute,
+    or has neither a node count nor any edges to infer one from.
+    """
+    if dset0 is None or not hasattr(dset0, "num_nodes"):
+        return None
+    num_nodes = dset0.num_nodes
+    if num_nodes is not None:
+        return num_nodes
+    edge_index = getattr(dset0, "edge_index", None)
+    # ``max()`` raises on an empty tensor, so an edgeless sample has no
+    # inferable node count.
+    if edge_index is None or edge_index.numel() == 0:
+        return None
+    return int(edge_index.max().item()) + 1
def _render_dataset_metadata_card(domain, dataset_name):
@@ -2093,19 +2131,38 @@ def _render_left_config(available_datasets):
st.header("Data configuration")
with st.expander("Dataset", expanded=True):
+ # Persist domain/dataset across sidebar tab switches via shadow keys
+ # (these selectboxes can't use a widget key because the Dataset options
+ # change with the Domain, which would invalidate a stored key value).
+ domain_options = list(available_datasets.keys())
+ prev_domain = st.session_state.get("cfg_domain")
+ domain_index = (
+ domain_options.index(prev_domain)
+ if prev_domain in domain_options
+ else 0
+ )
selected_domain = st.selectbox(
"Domain",
- options=list(available_datasets.keys()),
- index=0,
+ options=domain_options,
+ index=domain_index,
help="Topological domain (folder under configs/dataset).",
)
+ st.session_state["cfg_domain"] = selected_domain
+
datasets_in_domain = available_datasets.get(selected_domain, [])
+ prev_dataset = st.session_state.get("cfg_dataset")
+ dataset_index = (
+ datasets_in_domain.index(prev_dataset)
+ if prev_dataset in datasets_in_domain
+ else 0
+ )
selected_dataset = st.selectbox(
"Dataset",
options=datasets_in_domain,
- index=0,
+ index=dataset_index,
help=f"YAML stem under configs/dataset/{selected_domain}/",
)
+ st.session_state["cfg_dataset"] = selected_dataset
st.caption(f"**{selected_domain}** / **{selected_dataset}**")
_render_dataset_metadata_card(selected_domain, selected_dataset)
@@ -2177,21 +2234,11 @@ def _render_left_config(available_datasets):
st.session_state["edited_lifting_config"] = edited_lifting_config
st.session_state["edited_lifting_errors"] = edited_lifting_errors
- st.toggle(
- "3D layered view (orbit/zoom)",
- help=(
- "Stack ranks as horizontal planes in 3D (multi-incidence, 2+ adjacency, "
- "or combined incidence+adjacency). Single-matrix views stay 2D."
- ),
- key="layered_3d_view",
- on_change=_on_layered_3d_toggle,
- )
-
st.subheader("Actions")
load_clicked = st.button(
"Load graph",
type="primary",
- use_container_width=True,
+ width="stretch",
key="load_graph_btn",
)
@@ -2368,6 +2415,16 @@ def _on_layered_3d_toggle():
_rebuild_embed_for_neighborhoods(ids)
+def _on_metrics_option_change():
+    """Rebuild the embed after a metrics toggle (HUD visibility / advanced compute).
+
+    Does nothing while no data is loaded or no neighborhood is selected.
+    """
+    state = st.session_state
+    if state.get("data") is None:
+        return
+    selected = list(state.get("selected_neighborhood_ids") or [])
+    if selected:
+        _rebuild_embed_for_neighborhoods(selected)
+
+
def _on_adjacency_toggle():
"""Toggling an adjacency checkbox keeps the union of checked
incidences and adjacencies."""
@@ -2406,6 +2463,66 @@ def _commit_selection(new_ids):
_rebuild_embed_for_neighborhoods(new_ids)
+def _metrics_marker(neigh_ids, expensive):
+    """Cache key for displayed-graph metrics.
+
+    Includes everything that changes the displayed graph (selection, dataset
+    and domain, sampling snapshot, rank/hyperedge cap controls, graph index,
+    lifting) plus the expensive flag. Deliberately excludes the 2D/3D toggle
+    and the HUD on/off toggle so those never trigger a recompute.
+
+    Extra components can only cause an additional recompute, never a stale
+    hit, so the key errs on the side of including sampling inputs.
+    """
+    state = st.session_state
+    # Cap widgets are sampling controls too; sort them so the key does not
+    # depend on session-state iteration order.
+    caps = sorted(
+        (key, state.get(key))
+        for key in list(state.keys())
+        if isinstance(key, str)
+        and (key.startswith("ui_rank_cap_") or key == "ui_hyperedge_cap")
+    )
+    return "|".join(
+        [
+            "+".join(neigh_ids),
+            f"dom={state.get('data_domain')}",
+            f"ds={state.get('dataset_name')}",
+            f"max={state.get('_loaded_max_nodes')}",
+            f"min={state.get('_loaded_min_degree')}",
+            f"caps={caps}",
+            f"idx={state.get('active_graph_index')}",
+            f"lift={state.get('lifting_applied')}",
+            f"exp={bool(expensive)}",
+        ]
+    )
+
+
+def _compute_and_attach_metrics(payload, G_load, neigh_ids):
+    """Attach displayed-graph metrics to ``payload``, reusing the session cache.
+
+    Metrics are recomputed only when the marker from ``_metrics_marker``
+    differs from the cached one. A failure is stored in
+    ``_graph_metrics_error`` and leaves the payload without metrics rather
+    than breaking the embed. A non-dict payload is ignored.
+    """
+    if not isinstance(payload, dict):
+        return
+
+    state = st.session_state
+    want_expensive = bool(state.get("metrics_expensive", False))
+    marker = _metrics_marker(neigh_ids, want_expensive)
+
+    entry = state.get("_graph_metrics")
+    if entry and entry.get("marker") == marker:
+        metrics = entry.get("data")
+    else:
+        try:
+            metrics = gm.compute_graph_metrics(
+                G_load,
+                view_label=payload.get("graphType", "graph"),
+                expensive=want_expensive,
+            )
+            state.pop("_graph_metrics_error", None)
+        except Exception as exc:  # never break the embed over metrics
+            metrics = None
+            state["_graph_metrics_error"] = str(exc)
+        state["_graph_metrics"] = {"marker": marker, "data": metrics}
+
+    if not metrics:
+        return
+
+    payload["graphMetrics"] = gm.build_hud_rows(metrics)
+    payload["showMetricsHud"] = bool(state.get("show_metrics_hud", True))
+
+    # Only attach a "metrics" entry where there is something to show.
+    for node in payload.get("nodes", []):
+        node_metrics = gm.node_payload_metrics(metrics, node.get("id"))
+        if node_metrics:
+            node["metrics"] = node_metrics
+    for link in payload.get("links", []):
+        link_metrics = gm.edge_payload_metrics(
+            metrics, link.get("source"), link.get("target")
+        )
+        if link_metrics:
+            link["metrics"] = link_metrics
+
+
def _rebuild_embed_for_neighborhoods(neigh_ids):
"""Rebuild the embedded D3 view for a list of neighborhood ids.
@@ -2731,6 +2848,16 @@ def _rebuild_embed_for_neighborhoods(neigh_ids):
cache_marker = "+".join(neigh_ids)
if st.session_state.get("layered_3d_view"):
cache_marker += ":3d"
+
+ dataset_name = st.session_state.get("dataset_name") or "—"
+ num_nodes = _node_count_from_dataset(dset0)
+ payload["title"] = _format_graph_header_title(
+ vdesc, dataset_name, lift_subtitle, num_nodes
+ )
+ payload["subtitle"] = ""
+
+ _compute_and_attach_metrics(payload, G_load, neigh_ids)
+
embed_html = build_standalone_d3_html(
payload,
embed=True,
@@ -2754,8 +2881,18 @@ def _rebuild_embed_for_neighborhoods(neigh_ids):
return True
-def _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name):
- """Populate neighborhoods, sampling, and embed for a working sample."""
+def _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name,
+ sync_widgets=True):
+ """Populate neighborhoods, sampling, and embed for a working sample.
+
+ ``sync_widgets`` controls whether the sampling widget keys (``ui_min_degree``,
+ ``ui_rank_cap_*``, ``ui_hyperedge_cap``) are written. This is safe only on
+ the initial load or from on_change callbacks (before widgets are
+ re-instantiated). It must be ``False`` when called inline after those
+ widgets already exist in the current run (e.g. the "Apply to all" button),
+ otherwise Streamlit raises "cannot be modified after the widget is
+ instantiated".
+ """
rank_labels_for_payload = get_rank_labels(
loaded_domain, dataset_name, dset0
)
@@ -2802,15 +2939,16 @@ def _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name):
hyperedge_cap = max(0, min(int(hyperedge_cap), int(num_hyperedges)))
min_degree = int(cfg.get("min_degree", 0))
- for rank, cap in caps_by_rank.items():
- st.session_state[f"ui_rank_cap_{rank}"] = int(cap)
- if (
- hyperedge_cap is not None
- and num_hyperedges is not None
- and int(num_hyperedges) > 1
- ):
- st.session_state["ui_hyperedge_cap"] = int(hyperedge_cap)
- st.session_state["ui_min_degree"] = min_degree
+ if sync_widgets:
+ for rank, cap in caps_by_rank.items():
+ st.session_state[f"ui_rank_cap_{rank}"] = int(cap)
+ if (
+ hyperedge_cap is not None
+ and num_hyperedges is not None
+ and int(num_hyperedges) > 1
+ ):
+ st.session_state["ui_hyperedge_cap"] = int(hyperedge_cap)
+ st.session_state["ui_min_degree"] = min_degree
st.session_state["rank_populations"] = rank_pops
st.session_state["hyperedge_population"] = num_hyperedges
@@ -2994,7 +3132,7 @@ def _render_rank_cap_sliders(rank_populations, hyperedge_population):
st.write("")
if st.button(
"Apply to all",
- use_container_width=True,
+ width="stretch",
key="ui_apply_all_rank_caps",
):
for rank, pop in rank_populations.items():
@@ -3089,19 +3227,17 @@ def _render_inline_rank_cap(rank_populations, hyperedge_population):
def _render_rank_cap_controls(rank_populations, hyperedge_population):
- """Render rank-cap UI (inline or popover); return (ui_rank_caps, ui_hyperedge_cap)."""
+ """Render rank-cap UI (inline or expander); return (ui_rank_caps, ui_hyperedge_cap)."""
if not rank_populations:
return {}, None
configurable = [
rank for rank, pop in rank_populations.items() if int(pop) >= 2
]
- use_popover = len(configurable) >= 2
+ use_expander = len(configurable) >= 2
- if use_popover:
- summary = _format_rank_cap_summary(rank_populations, hyperedge_population)
- st.caption(summary)
- with st.popover("Per-rank caps", use_container_width=True):
+ if use_expander:
+ with st.expander("Per-rank caps", expanded=True):
return _render_rank_cap_sliders(rank_populations, hyperedge_population)
return _render_inline_rank_cap(rank_populations, hyperedge_population)
@@ -3174,7 +3310,9 @@ def _on_sampling_control_change():
dset0 = data[0] if hasattr(data, "__getitem__") else data
loaded_domain = st.session_state.get("data_domain")
dataset_name = st.session_state.get("dataset_name")
- ok = _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name)
+ ok = _finalize_loaded_sample(
+ dset0, cfg, loaded_domain, dataset_name, sync_widgets=False
+ )
if ok:
snap = st.session_state.get("_load_cfg_snapshot") or {}
snap["caps_by_rank"] = copy.deepcopy(cfg.get("caps_by_rank") or {})
@@ -3252,41 +3390,38 @@ def _render_graph_sample_section():
if st.session_state.get("data") is None:
return
- st.subheader("Graph sample")
-
_dset0, rank_populations, hyperedge_population, _is_loaded = (
_get_sampling_context()
)
meta = st.session_state.get("dataset_metadata") or {}
is_inductive = (meta.get("learning_setting") or "").lower() == "inductive"
- with st.container(border=True):
- col_left, col_right = st.columns([0.4, 0.6], gap="large")
-
- with col_left:
- if is_inductive:
- total = int(st.session_state.get("loaded_dataset_size", 1) or 1)
- max_idx = max(0, total - 1)
- active = int(st.session_state.get("active_graph_index", 0))
- err = st.session_state.get("_graph_index_error")
- _render_graph_index_input(active, error=err)
- st.caption(
- f"**Available graphs:** `0` to `{max_idx}` ({total} total)"
- )
- else:
- st.caption("Single graph (transductive dataset)")
-
- st.slider(
- "Minimum degree",
- min_value=0,
- max_value=20,
- value=int(st.session_state.get("ui_min_degree", 0)),
- key="ui_min_degree",
- on_change=_on_sampling_control_change,
- )
+ if is_inductive:
+ total = int(st.session_state.get("loaded_dataset_size", 1) or 1)
+ max_idx = max(0, total - 1)
+ active = int(st.session_state.get("active_graph_index", 0))
+ err = st.session_state.get("_graph_index_error")
+ _render_graph_index_input(active, error=err)
+ st.caption(
+ f"**Available graphs:** `0` to `{max_idx}` ({total} total)"
+ )
+ else:
+ st.caption("Single graph (transductive dataset)")
+
+ st.slider(
+ "Minimum degree",
+ min_value=0,
+ max_value=20,
+ value=int(st.session_state.get("ui_min_degree", 0)),
+ key="ui_min_degree",
+ help="Drop nodes below this degree after rank caps are applied.",
+ on_change=_on_sampling_control_change,
+ )
- with col_right:
- _render_rank_cap_controls(rank_populations, hyperedge_population)
+ if rank_populations:
+ summary = _format_rank_cap_summary(rank_populations, hyperedge_population)
+ st.caption(summary)
+ _render_rank_cap_controls(rank_populations, hyperedge_population)
def _do_load_graph(cfg, progress=None):
@@ -3361,99 +3496,270 @@ def _do_load_graph(cfg, progress=None):
)
-def _render_graph_view_rest():
- """Neighborhood picker and D3 embed (after graph sample section)."""
+def _render_sidebar_tab_selector(data_loaded):
+    """Draw the three sidebar tab buttons and return the active tab label.
+
+    The active tab is stored in ``st.session_state["active_sidebar_tab"]`` so
+    it survives reruns. A tab requested programmatically through
+    ``st.session_state["_pending_tab"]`` is consumed and applied first.
+
+    Args:
+        data_loaded: When falsy, a caption hinting that a graph must be
+            loaded is shown under the buttons.
+
+    Returns:
+        The label of the tab that was active when the buttons were drawn.
+    """
+    default_tab = "Load graph"
+    # (label, widget key) for each tab, in left-to-right display order.
+    tab_buttons = (
+        ("Load graph", "tab_load_btn"),
+        ("Explore", "tab_explore_btn"),
+        ("Metrics", "tab_metrics_btn"),
+    )
+    tab_labels = tuple(label for label, _ in tab_buttons)
+
+    # A switch queued on a previous run (e.g. after a successful load) has to
+    # be applied before the buttons below are instantiated.
+    requested = st.session_state.pop("_pending_tab", None)
+    if requested in tab_labels:
+        st.session_state["active_sidebar_tab"] = requested
+    if "active_sidebar_tab" not in st.session_state:
+        st.session_state["active_sidebar_tab"] = default_tab
+
+    active = st.session_state.get("active_sidebar_tab") or default_tab
+    for column, (label, button_key) in zip(st.columns(3), tab_buttons):
+        with column:
+            clicked = st.button(
+                label,
+                key=button_key,
+                type="primary" if active == label else "secondary",
+                width="stretch",
+            )
+            if clicked:
+                st.session_state["active_sidebar_tab"] = label
+                st.rerun()
+
+    if not data_loaded:
+        st.caption("Load a graph to enable **Explore** and **Metrics**.")
+
+    return active
+
+
+def _render_explore_tab():
+ """Render the Explore sidebar tab.
+
+ When no ``data`` entry is present in session state, an info message
+ pointing at the Load graph tab is shown and nothing else is rendered.
+ Otherwise the tab shows, in order: the 3D layered-view toggle, the
+ "Graph sample" section, and the neighborhood picker.
+ """
+ if st.session_state.get("data") is None:
+ st.info("Load a graph from the **Load graph** tab to explore neighborhoods.")
+ return
+
+ st.toggle(
+ "3D layered view (orbit/zoom)",
+ help=(
+ "Stack ranks as horizontal planes in 3D (multi-incidence, 2+ adjacency, "
+ "or combined incidence+adjacency). Single-matrix views stay 2D."
+ ),
+ key="layered_3d_view",
+ on_change=_on_layered_3d_toggle,
+ )
+
+ st.divider()
+ st.subheader("Graph sample")
+ _render_graph_sample_section()
+
+ st.divider()
 _render_neighborhood_picker()
- embed_html = st.session_state.get("_d3_embed_html")
- if not embed_html:
- st.info(
- "Use the **sidebar** to choose dataset and lifting options, then click "
- "**Load graph**."
- )
+
+def _render_metrics_tab():
+ """Render the Metrics sidebar tab for the currently displayed graph.
+
+ Shows, in order: the metrics-overlay toggle, any notes carried in the
+ metrics flags, a table of whole-graph metrics, the top-elements tables,
+ and the "Advanced metrics" expander. Returns early with a short message
+ when no data is loaded, when a metrics error is recorded, or when no
+ metrics are cached yet.
+ """
+ if st.session_state.get("data") is None:
+ st.info("Load a graph from the **Load graph** tab to see metrics.")
 return
- data = st.session_state.get("data")
- dset0 = (
- data[0] if (data is not None and hasattr(data, "__getitem__")) else data
+ # A recorded metrics error is reported before any cached result is read.
+ err = st.session_state.get("_graph_metrics_error")
+ if err:
+ st.caption(f"Metrics unavailable: {err}")
+ return
+
+ cached = st.session_state.get("_graph_metrics") or {}
+ metrics = cached.get("data")
+ if not metrics:
+ st.caption("Metrics will appear once a graph is displayed.")
+ return
+
+ # NOTE(review): "show_metrics_hud" is also re-assigned through
+ # st.session_state by _persist_widget_state; passing ``value=`` for a key
+ # that was set via session state may trigger Streamlit's duplicate-default
+ # warning -- confirm.
+ st.toggle(
+ "Show metrics overlay",
+ value=bool(st.session_state.get("show_metrics_hud", True)),
+ key="show_metrics_hud",
+ help="Compact whole-graph metrics in the top-left of the graph canvas.",
+ on_change=_on_metrics_option_change,
 )
- lifting_applied = st.session_state.get("lifting_applied")
- if dset0 is not None and hasattr(dset0, "num_nodes"):
- num_nodes = dset0.num_nodes
- if (
- num_nodes is None
- and getattr(dset0, "edge_index", None) is not None
- ):
- num_nodes = int(dset0.edge_index.max().item()) + 1
- num_features = (
- dset0.x.shape[1]
- if (
- hasattr(dset0, "x")
- and dset0.x is not None
- and len(dset0.x.shape) > 1
- )
- else 0
+
+ flags = metrics.get("flags") or {}
+ for note in flags.get("notes", []):
+ st.caption(note)
+
+ graph = metrics.get("graph") or {}
+ scope = metrics.get("graph_scope") or {}
+ label_map = dict(gm.GRAPH_METRIC_LABELS)
+ # One table row per labelled metric that has a non-None value; when the
+ # metric has a scope entry, that entry is appended to the label.
+ rows = []
+ for key, _label in gm.GRAPH_METRIC_LABELS:
+ if key not in graph or graph.get(key) is None:
+ continue
+ label = label_map.get(key, key)
+ if key in scope:
+ label = f"{label} ({scope[key]})"
+ rows.append({"Metric": label, "Value": gm.fmt_value(graph[key])})
+ if rows:
+ st.caption("Displayed graph (after sampling)")
+ st.dataframe(rows, hide_index=True, width="stretch")
+
+ _render_metrics_top_elements(metrics)
+
+ with st.expander("Advanced metrics"):
+ st.checkbox(
+ "Compute advanced metrics (betweenness, closeness, eccentricity, …)",
+ value=bool(st.session_state.get("metrics_expensive", False)),
+ key="metrics_expensive",
+ help=(
+ "Heavier centralities. May be slow or approximated on large "
+ "graphs; betweenness uses sampling above ~800 nodes."
+ ),
+ on_change=_on_metrics_option_change,
 )
- if lifting_applied:
- st.success(
- f"Viewing lifted data — **{lifting_applied['name']}** "
- f"({lifting_applied['source']} → {lifting_applied['target']}) | "
- f"Nodes: {num_nodes}, "
- f"Node features: {num_features if num_features else 'N/A'}"
- )
- else:
- st.info(
- f"Dataset loaded — Nodes: {num_nodes}, "
- f"Node features: {num_features if num_features else 'N/A'}"
+ if flags.get("betweenness_approx"):
+ st.caption("Betweenness is approximated (k-sampled) for this graph.")
+ if not flags.get("expensive"):
+ st.caption("Enable to add per-node/edge centralities to tooltips and tables.")
+
+
+def _render_metrics_top_elements(metrics):
+    """Show the top-5 nodes for a chosen centrality and the most negative edges.
+
+    Args:
+        metrics: Mapping read for its ``"nodes"`` and ``"edges"`` entries and
+            passed on to ``gm.summarize_metrics``.
+
+    Nothing is rendered when ``metrics`` has no node entries.
+    """
+    per_node = metrics.get("nodes") or {}
+    if not per_node:
+        return
+
+    # Only the first node record is inspected to decide which centralities
+    # are offered in the selector.
+    probe = next(iter(per_node.values()), {})
+    known_centralities = (
+        ("degree_centrality", "Degree centrality"),
+        ("betweenness", "Betweenness"),
+        ("closeness", "Closeness"),
+        ("eigenvector", "Eigenvector"),
+        ("pagerank", "PageRank"),
+        ("clustering", "Clustering"),
+    )
+    label_by_key = {
+        name: text
+        for name, text in known_centralities
+        if name in probe and probe.get(name) is not None
+    }
+    if label_by_key:
+        selected = st.selectbox(
+            "Top nodes by",
+            options=list(label_by_key),
+            format_func=lambda name: label_by_key.get(name, name),
+            key="metrics_top_centrality",
+        )
+        column_title = label_by_key.get(selected, selected)
+        ranking = gm.summarize_metrics(metrics, centrality=selected, top_k=5)
+        node_table = [
+            {"Node": node_id, column_title: gm.fmt_value(score)}
+            for node_id, score in ranking["top_nodes"]
+        ]
+        if node_table:
+            st.dataframe(node_table, hide_index=True, width="stretch")
+
+    if not (metrics.get("edges") or {}):
+        return
+    most_negative = gm.summarize_metrics(metrics, top_k=5)["forman_most_negative"]
+    if not most_negative:
+        return
+    st.caption("Most negative Forman-Ricci edges")
+    edge_table = [
+        {"Edge": f"{u} — {v}", "Forman": gm.fmt_value(curvature)}
+        for (u, v), curvature in most_negative
+    ]
+    st.dataframe(
+        edge_table,
+        hide_index=True,
+        width="stretch",
     )
+
+def _render_graph_canvas():
+ """Render the stored D3 embed in the main area, with a download button.
+
+ Does nothing when ``_d3_embed_html`` is missing or empty in session state.
+ The download button is shown only when ``_d3_last_html`` is set.
+ """
+ embed_html = st.session_state.get("_d3_embed_html")
+ if not embed_html:
+ return
+
 sel_ids = list(st.session_state.get("selected_neighborhood_ids") or [])
- avail = st.session_state.get("available_neighborhoods") or []
- if sel_ids:
- labels = [next((n["label"] for n in avail if n["id"] == s), s) for s in sel_ids]
- joined_ids = ", ".join(f"`{s}`" for s in sel_ids)
- joined_labels = "; ".join(labels)
- st.markdown(f"**Currently displayed:** {joined_ids} — {joined_labels}")
-
- # ``components.html`` has no ``key=`` in Streamlit 1.50; wrap in a keyed
- # container so changing the neighborhood remounts the iframe subtree.
+ # The container key is built from the selected neighborhood ids, so it
+ # changes whenever the selection does.
+ # NOTE(review): the components.html version relied on this to remount the
+ # embed; confirm the keyed container is still needed with st.iframe.
 neigh_key = "+".join(sel_ids) if sel_ids else "default"
 with st.container(key=f"d3_embed::{neigh_key}"):
- components.html(embed_html, height=D3_EMBED_HEIGHT, scrolling=False)
-
- payload = st.session_state.get("_d3_payload")
- if payload is not None and st.button(
- "Open graph in new browser window",
- key="d3_open_current",
- use_container_width=True,
- ):
- open_d3_graph_window(payload)
+ st.iframe(embed_html, height=D3_EMBED_HEIGHT)
 last_html = st.session_state.get("_d3_last_html")
 if last_html:
 st.download_button(
- label="Download last D3 graph (HTML)",
+ label="Download graph (HTML)",
 data=last_html,
 file_name="topobench_graph.html",
 mime="text/html",
 key="d3_download_main",
- use_container_width=True,
+ width="stretch",
 )
+# Session-state keys that _persist_widget_state re-assigns at the start of
+# every run. Keys starting with "ui_rank_cap_" are matched there by prefix
+# and are therefore not listed individually.
+_PERSIST_WIDGET_KEYS = frozenset(
+ {
+ "use_lifting",
+ "ui_min_degree",
+ "ui_hyperedge_cap",
+ "ui_set_all_rank_caps",
+ "layered_3d_view",
+ "main_graph_index_input",
+ "show_metrics_hud",
+ "metrics_expensive",
+ }
+)
+
+
+def _persist_widget_state():
+    """Re-assign persisted widget keys so they survive sidebar tab switches.
+
+    Streamlit discards the session_state entry of a keyed widget that is not
+    rendered during a run, which would happen to Explore/Metrics widgets
+    while the Load graph tab is showing. Writing each tracked key back to
+    itself -- before any widget is created on this run -- exempts it from
+    that cleanup. Tracked keys are those in ``_PERSIST_WIDGET_KEYS`` plus
+    any key starting with ``"ui_rank_cap_"``.
+    """
+    state = st.session_state
+    # Snapshot the key list first: entries are rewritten inside the loop.
+    for name in list(state.keys()):
+        tracked = name in _PERSIST_WIDGET_KEYS or name.startswith("ui_rank_cap_")
+        if not tracked:
+            continue
+        state[name] = state[name]
+
+
def main():
+ _persist_widget_state()
+
flash = st.session_state.pop("_flash_ok", None)
if flash:
st.success(flash)
available_datasets = discover_available_datasets()
+ data_loaded = st.session_state.get("data") is not None
+ sidebar_cfg = None
with st.sidebar:
- sidebar_cfg = _render_left_config(available_datasets)
-
- st.header("Graph")
+ st.markdown(
+ '',
+ unsafe_allow_html=True,
+ )
+ tab = _render_sidebar_tab_selector(data_loaded)
+ st.divider()
+ if tab == "Load graph":
+ sidebar_cfg = _render_left_config(available_datasets)
+ elif tab == "Metrics":
+ _render_metrics_tab()
+ else:
+ _render_explore_tab()
- if sidebar_cfg["load_clicked"]:
+ if sidebar_cfg is not None and sidebar_cfg["load_clicked"]:
load_cfg = {**sidebar_cfg, **_default_load_sampling_cfg()}
with st.status("Loading graph…", expanded=True) as status:
def _progress(msg):
@@ -3466,12 +3772,10 @@ def _progress(msg):
else:
status.update(label="Load stopped", state="error")
if ok:
+ st.session_state["_pending_tab"] = "Explore"
st.rerun()
- if st.session_state.get("data") is not None:
- _render_graph_sample_section()
-
- _render_graph_view_rest()
+ _render_graph_canvas()
if __name__ == "__main__":