diff --git a/topoexplorer/d3_graph_html.py b/topoexplorer/d3_graph_html.py index 2c5b6a68..3809c5b2 100644 --- a/topoexplorer/d3_graph_html.py +++ b/topoexplorer/d3_graph_html.py @@ -83,6 +83,7 @@ def build_standalone_d3_html( header {{ padding: 12px 16px; background: #fff; border-bottom: 1px solid #ccc; }} header h1 {{ margin: 0; font-size: 1.1rem; }} header p {{ margin: 4px 0 0; color: #555; font-size: 0.85rem; }} + header p:empty {{ display: none; margin: 0; }} #chart-wrap {{ position: relative; {chart_size_css} }} #chart {{ width: 100%; height: 100%; min-height: inherit; }} #err {{ color: #b00; padding: 16px; white-space: pre-wrap; font-family: monospace; }} @@ -105,6 +106,20 @@ def build_standalone_d3_html( #legend .legend-dot {{ display: inline-block; width: 12px; height: 12px; border-radius: 50%; border: 1px solid rgba(0,0,0,0.18); flex-shrink: 0; }} #legend .legend-label {{ color: #222; }} + #metrics-hud {{ + display: none; position: absolute; top: 10px; left: 10px; z-index: 5; + background: rgba(255, 255, 255, 0.92); border: 1px solid #d0d0d0; + border-radius: 6px; padding: 8px 10px; font-size: 12px; line-height: 1.4; + box-shadow: 0 1px 4px rgba(0,0,0,0.08); pointer-events: none; + max-width: min(240px, 60%); + }} + #metrics-hud .hud-title {{ display: block; font-weight: 600; font-size: 11px; + color: #555; letter-spacing: 0.02em; text-transform: uppercase; + margin-bottom: 6px; }} + #metrics-hud .hud-row {{ display: flex; justify-content: space-between; + gap: 12px; margin: 2px 0; }} + #metrics-hud .hud-label {{ color: #555; }} + #metrics-hud .hud-value {{ color: #111; font-variant-numeric: tabular-nums; }} @@ -114,6 +129,7 @@ def build_standalone_d3_html(
+
@@ -127,6 +143,20 @@ def build_standalone_d3_html( function showError(msg) { document.getElementById("err").textContent = msg; } + function metricLines(obj) { + if (!obj) return ""; + var parts = []; + Object.keys(obj).forEach(function(k) { parts.push(k + ": " + obj[k]); }); + return parts.length ? ("\\n" + parts.join("\\n")) : ""; + } + function nodeTooltip(d) { + return (d.label || d.id) + " — degree " + (d.degree || 0) + metricLines(d.metrics); + } + function edgeTooltip(l) { + var s = (l.source && l.source.id) ? l.source.id : l.source; + var t = (l.target && l.target.id) ? l.target.id : l.target; + return s + " — " + t + metricLines(l.metrics); + } try { const raw = document.getElementById("graph-payload").textContent; const payload = JSON.parse(raw); @@ -154,6 +184,26 @@ def build_standalone_d3_html( box.style.display = "block"; })(); + (function renderMetricsHud() { + var box = document.getElementById("metrics-hud"); + if (!box) return; + var rows = payload.graphMetrics || []; + if (!payload.showMetricsHud || rows.length === 0) { + box.innerHTML = ""; + box.style.display = "none"; + return; + } + var html = 'Displayed graph'; + rows.forEach(function(r) { + html += '
' + + '' + r.label + '' + + '' + r.value + '' + + '
'; + }); + box.innerHTML = html; + box.style.display = "block"; + })(); + const nodes = (payload.nodes || []).map(function(d) { const o = Object.assign({}, d); o.id = String(o.id); @@ -190,7 +240,8 @@ def build_standalone_d3_html( return { source: String(l.source), target: String(l.target), - color: l.color + color: l.color, + metrics: l.metrics }; }); const graph3d = ForceGraph3D()(chart) @@ -199,9 +250,10 @@ def build_standalone_d3_html( .backgroundColor("#fafafa") .nodeRelSize(4) .nodeColor(function(n) { return n.color || "#666"; }) - .nodeLabel(function(n) { return (n.label || n.id) + " — degree " + (n.degree || 0); }) + .nodeLabel(function(n) { return nodeTooltip(n); }) .nodeVal(function(n) { return 1 + Math.log1p(n.degree || 1); }) .linkColor(function(l) { return l.color || "#888"; }) + .linkLabel(function(l) { return edgeTooltip(l); }) .linkOpacity(0.6) .linkWidth(0.6) .graphData({ nodes: nodes3d, links: links3d }); @@ -234,7 +286,8 @@ def build_standalone_d3_html( return { source: s, target: t, - color: l.color + color: l.color, + metrics: l.metrics }; }).filter(function(l) { return l.source && l.target; }); @@ -320,6 +373,8 @@ def build_standalone_d3_html( .attr("stroke", function(d) { return d.color || "#888"; }) .attr("stroke-width", 1.2); + link.append("title").text(function(d) { return edgeTooltip(d); }); + const node = g.append("g").selectAll("g") .data(nodes) .join("g") @@ -344,9 +399,7 @@ def build_standalone_d3_html( .attr("fill", function(d) { return d.color || "#666"; }) .attr("stroke", function(d) { return d.stroke || "#fff"; }); - node.append("title").text(function(d) { - return (d.label || d.id) + " — degree " + (d.degree || 0); - }); + node.append("title").text(function(d) { return nodeTooltip(d); }); node.filter(function(d) { return String(d.id).length <= 12; }) .append("text") diff --git a/topoexplorer/graph_metrics.py b/topoexplorer/graph_metrics.py new file mode 100644 index 00000000..33b8fff5 --- /dev/null +++ 
"""Graph metrics for the Neighborhood Explorer (displayed-graph only).

All metrics are computed on the *displayed* NetworkX graph (after sampling and
lifting), on an undirected view so the same code path covers adjacency,
incidence/bipartite, and layered views.

Metric tiers
------------
- cheap (always): counts, density, degree stats, components, clustering,
  transitivity, per-node degree centrality, per-edge Forman-Ricci.
- default (size-gated, largest connected component): diameter, radius.
- advanced (opt-in, expensive): betweenness, closeness, eccentricity,
  eigenvector, pagerank, edge betweenness, degree assortativity. When the
  default tier skipped diameter/radius because of graph size, the advanced
  tier back-fills them from the eccentricities it computes anyway.

Forman-Ricci uses the Weber-Jost-Saucan (2018) definition; with unit weights
(the displayed graph carries none) it reduces exactly to ``4 - deg(u) - deg(v)``
(no triangle term), matching TopoBench's implementation.
"""

from __future__ import annotations

from typing import Any

import networkx as nx

# Above these sizes, distance measures (diameter/radius) are skipped by default
# and betweenness (node and edge) is approximated via k-sampling.
DISTANCE_NODE_LIMIT = 1200
DISTANCE_EDGE_LIMIT = 6000
BETWEENNESS_EXACT_LIMIT = 800
BETWEENNESS_SAMPLE_K = 500


def _edge_key(u, v):
    """Order-independent key for an undirected edge (string node ids).

    Both endpoints are stringified and ordered lexicographically, so
    ``(u, v)`` and ``(v, u)`` produce the same tuple.
    """
    a, b = str(u), str(v)
    return (a, b) if a <= b else (b, a)


def forman_ricci_unweighted(UG: nx.Graph, u, v) -> int:
    """Weber-Jost-Saucan Forman-Ricci with unit weights: ``4 - deg(u) - deg(v)``."""
    return 4 - UG.degree(u) - UG.degree(v)


def compute_graph_metrics(
    G: nx.Graph,
    *,
    view_label: str = "graph",
    expensive: bool = False,
    betweenness_k: int | None = None,
) -> dict[str, Any]:
    """Compute displayed-graph metrics.

    Parameters
    ----------
    G : networkx graph
        The displayed graph (may be directed for incidence/bipartite views).
    view_label : str
        Human-readable view name (e.g. ``"adjacency"``) for UI captions.
    expensive : bool
        When True, also compute advanced per-node/edge centralities.
    betweenness_k : int, optional
        Sample size for approximate betweenness; defaults to an internal policy.

    Returns
    -------
    dict
        ``{"graph": {...}, "graph_scope": {...}, "nodes": {...}, "edges": {...},
        "flags": {...}}``. ``nodes`` is keyed by ``str(node)`` and ``edges`` by
        the tuple returned from ``_edge_key``. ``flags["distance_skipped"]`` is
        True only while diameter/radius are actually missing because of the
        size gate.
    """
    notes: list[str] = []
    directed = G.is_directed()
    UG = G.to_undirected() if directed else G
    if directed:
        notes.append(f"Computed on an undirected view of the {view_label} graph.")

    n = UG.number_of_nodes()
    m = UG.number_of_edges()
    degrees = dict(UG.degree())
    # Fallback keeps max()/min() defined for the empty graph.
    deg_values = list(degrees.values()) if degrees else [0]

    graph: dict[str, Any] = {
        "nodes": n,
        "edges": m,
        "density": nx.density(UG) if n > 1 else 0.0,
        "avg_degree": (sum(deg_values) / n) if n else 0.0,
        "max_degree": max(deg_values),
        "min_degree": min(deg_values),
    }

    components = list(nx.connected_components(UG))
    graph["components"] = len(components)
    largest_cc_nodes = max(components, key=len) if components else set()
    graph["largest_cc_size"] = len(largest_cc_nodes)

    # Best effort: a failure here must not prevent the remaining metrics.
    try:
        graph["avg_clustering"] = nx.average_clustering(UG) if n else 0.0
    except Exception:
        graph["avg_clustering"] = None
    try:
        graph["transitivity"] = nx.transitivity(UG) if n else 0.0
    except Exception:
        graph["transitivity"] = None

    graph_scope: dict[str, str] = {}

    # Default tier: distance measures on the largest connected component.
    graph["diameter"] = None
    graph["radius"] = None
    distance_skipped = n > DISTANCE_NODE_LIMIT or m > DISTANCE_EDGE_LIMIT
    if distance_skipped:
        # With ``expensive`` the advanced tier back-fills diameter/radius, so
        # the "enable advanced metrics" hint is only shown when it is off.
        if not expensive:
            notes.append(
                "Diameter/radius skipped (graph too large); enable advanced metrics "
                "or reduce sampling."
            )
    elif largest_cc_nodes:
        cc = UG.subgraph(largest_cc_nodes)
        try:
            ecc = nx.eccentricity(cc)
            graph["diameter"] = max(ecc.values()) if ecc else None
            graph["radius"] = min(ecc.values()) if ecc else None
            if graph["components"] > 1:
                graph_scope["diameter"] = "largest CC"
                graph_scope["radius"] = "largest CC"
        except Exception:
            notes.append("Diameter/radius unavailable for this graph.")

    # Per-node cheap metrics.
    nodes: dict[str, dict[str, Any]] = {}
    deg_centrality = nx.degree_centrality(UG) if n else {}
    try:
        clustering = nx.clustering(UG)
    except Exception:
        clustering = {}
    for node in UG.nodes():
        key = str(node)
        nodes[key] = {
            "degree": degrees.get(node, 0),
            "degree_centrality": deg_centrality.get(node),
            "clustering": clustering.get(node),
        }

    # Per-edge cheap metrics (Forman-Ricci).
    edges: dict[tuple, dict[str, Any]] = {}
    for u, v in UG.edges():
        edges[_edge_key(u, v)] = {"forman_ricci": forman_ricci_unweighted(UG, u, v)}

    flags: dict[str, Any] = {
        "view_label": view_label,
        "expensive": bool(expensive),
        "betweenness_approx": False,
        "distance_skipped": distance_skipped,
        "notes": notes,
    }

    if expensive and n:
        _add_advanced_metrics(
            UG,
            nodes,
            edges,
            graph,
            flags,
            betweenness_k=betweenness_k,
            graph_scope=graph_scope,
        )

    return {
        "graph": graph,
        "graph_scope": graph_scope,
        "nodes": nodes,
        "edges": edges,
        "flags": flags,
    }


def _add_advanced_metrics(
    UG, nodes, edges, graph, flags, *, betweenness_k=None, graph_scope=None
):
    """Augment node/edge dicts with expensive centralities (in place).

    Parameters
    ----------
    UG : networkx graph
        Undirected view of the displayed graph.
    nodes, edges : dict
        Per-node / per-edge metric dicts to extend (keyed as in
        ``compute_graph_metrics``).
    graph : dict
        Whole-graph metrics; receives ``assortativity`` and, when the default
        tier skipped them, ``diameter`` / ``radius``.
    flags : dict
        Receives ``betweenness_approx`` and extra ``notes``; must already
        contain a ``"notes"`` list.
    betweenness_k : int, optional
        Sample size for betweenness. ``None`` or a non-positive value selects
        the size-based policy (exact up to ``BETWEENNESS_EXACT_LIMIT`` nodes).
    graph_scope : dict, optional
        When given, receives the ``"largest CC"`` scope labels for a
        back-filled diameter/radius on a disconnected graph.
    """
    n = UG.number_of_nodes()

    # Betweenness: exact for small graphs, k-sampled otherwise.
    if betweenness_k is None or betweenness_k <= 0:
        betweenness_k = (
            None if n <= BETWEENNESS_EXACT_LIMIT else min(n, BETWEENNESS_SAMPLE_K)
        )
    # ``None`` means exact; sampling with k >= n would not be an approximation.
    sample_k = (
        betweenness_k if betweenness_k is not None and betweenness_k < n else None
    )
    try:
        if sample_k is not None:
            betw = nx.betweenness_centrality(UG, k=sample_k, seed=0)
            flags["betweenness_approx"] = True
            flags["notes"].append(
                f"Betweenness approximated with k={sample_k} samples."
            )
        else:
            betw = nx.betweenness_centrality(UG)
    except Exception:
        betw = {}

    # Same sampling policy as node betweenness: exact edge betweenness costs
    # as much as the exact node variant that is avoided on large graphs.
    try:
        ebetw = nx.edge_betweenness_centrality(UG, k=sample_k, seed=0)
    except Exception:
        ebetw = {}

    try:
        closeness = nx.closeness_centrality(UG)
    except Exception:
        closeness = {}

    # Eccentricity per node on the largest CC (finite only there).
    eccentricity: dict[Any, Any] = {}
    components: list = []
    try:
        components = list(nx.connected_components(UG))
        if components:
            cc_nodes = max(components, key=len)
            eccentricity = nx.eccentricity(UG.subgraph(cc_nodes))
    except Exception:
        eccentricity = {}

    # Back-fill diameter/radius when the default tier skipped them for size;
    # the eccentricities needed are already available at this point.
    if flags.get("distance_skipped") and graph.get("diameter") is None:
        if eccentricity:
            graph["diameter"] = max(eccentricity.values())
            graph["radius"] = min(eccentricity.values())
            flags["distance_skipped"] = False
            if graph_scope is not None and len(components) > 1:
                graph_scope["diameter"] = "largest CC"
                graph_scope["radius"] = "largest CC"
        else:
            flags["notes"].append("Diameter/radius unavailable for this graph.")

    try:
        eigenvector = nx.eigenvector_centrality(UG, max_iter=500, tol=1e-04)
    except Exception:
        eigenvector = {}
        flags["notes"].append("Eigenvector centrality did not converge.")

    try:
        pagerank = nx.pagerank(UG)
    except Exception:
        pagerank = {}

    for node in UG.nodes():
        slot = nodes.setdefault(str(node), {})
        slot["betweenness"] = betw.get(node)
        slot["closeness"] = closeness.get(node)
        slot["eccentricity"] = eccentricity.get(node)
        slot["eigenvector"] = eigenvector.get(node)
        slot["pagerank"] = pagerank.get(node)

    for edge, val in ebetw.items():
        # Index instead of unpacking: keys are presumably (u, v, key) triples
        # for multigraphs (verify against the networkx version in use), and
        # only the endpoints identify a displayed edge.
        edges.setdefault(_edge_key(edge[0], edge[1]), {})["edge_betweenness"] = val

    try:
        graph["assortativity"] = nx.degree_assortativity_coefficient(UG)
    except Exception:
        graph["assortativity"] = None


# ---------------------------------------------------------------------------
# Formatting helpers (kept here so the app stays lean and the format is shared)
# ---------------------------------------------------------------------------

def fmt_value(v: Any) -> str:
    """Compact human-readable formatting for metric values.

    ``None`` and NaN render as ``"N/A"``; ints get thousands separators;
    floats use three decimals (trailing zeros stripped) or scientific notation
    when ``abs(v) >= 1000`` or ``abs(v) < 0.001``; anything else is ``str(v)``.
    """
    if v is None:
        return "N/A"
    # bool is a subclass of int, so it must be tested first.
    if isinstance(v, bool):
        return str(v)
    if isinstance(v, int):
        return f"{v:,}"
    if isinstance(v, float):
        if v != v:  # NaN is the only float that differs from itself
            return "N/A"
        if v == 0:
            return "0"
        if abs(v) >= 1000 or abs(v) < 0.001:
            return f"{v:.2e}"
        return f"{v:.3f}".rstrip("0").rstrip(".")
    return str(v)


# Labels for whole-graph metrics, in display order.
GRAPH_METRIC_LABELS = [
    ("nodes", "Nodes"),
    ("edges", "Edges"),
    ("density", "Density"),
    ("avg_degree", "Avg degree"),
    ("max_degree", "Max degree"),
    ("min_degree", "Min degree"),
    ("components", "Components"),
    ("largest_cc_size", "Largest CC size"),
    ("avg_clustering", "Avg clustering"),
    ("transitivity", "Transitivity"),
    ("diameter", "Diameter"),
    ("radius", "Radius"),
    ("assortativity", "Degree assortativity"),
]

# Compact subset for the floating HUD overlay.
HUD_METRIC_KEYS = ["nodes", "edges", "density", "avg_degree", "components", "diameter"]

NODE_TOOLTIP_LABELS = [
    ("degree_centrality", "deg cent"),
    ("clustering", "clustering"),
    ("betweenness", "betw"),
    ("closeness", "closeness"),
    ("eccentricity", "ecc"),
]

EDGE_TOOLTIP_LABELS = [
    ("forman_ricci", "Forman"),
    ("edge_betweenness", "edge betw"),
]


def build_hud_rows(metrics: dict) -> list[dict]:
    """Compact ``[{label, value}]`` rows for the canvas HUD overlay.

    Only keys in ``HUD_METRIC_KEYS`` with a non-``None`` value are emitted, in
    that order. A key present in ``metrics["graph_scope"]`` gets its scope
    appended to the label in parentheses.
    """
    graph = metrics.get("graph", {})
    scope = metrics.get("graph_scope", {})
    label_map = dict(GRAPH_METRIC_LABELS)
    rows = []
    for key in HUD_METRIC_KEYS:
        value = graph.get(key)
        if value is None:
            continue
        label = label_map.get(key, key)
        if key in scope:
            label = f"{label} ({scope[key]})"
        rows.append({"label": label, "value": fmt_value(value)})
    return rows


def node_payload_metrics(metrics: dict, node_id: str) -> dict:
    """Formatted per-node metric strings for a D3 node tooltip.

    Returns ``{tooltip label: formatted value}`` for the entries of
    ``NODE_TOOLTIP_LABELS`` that are present and not ``None``; an empty dict
    when the node is unknown.
    """
    nd = (metrics.get("nodes") or {}).get(str(node_id))
    if not nd:
        return {}
    return {
        label: fmt_value(nd[key])
        for key, label in NODE_TOOLTIP_LABELS
        if nd.get(key) is not None
    }


def edge_payload_metrics(metrics: dict, u: str, v: str) -> dict:
    """Formatted per-edge metric strings for a D3 edge tooltip.

    The edge is looked up by the order-independent ``_edge_key(u, v)``, so the
    endpoint order of the payload link does not matter.
    """
    ed = (metrics.get("edges") or {}).get(_edge_key(u, v))
    if not ed:
        return {}
    return {
        label: fmt_value(ed[key])
        for key, label in EDGE_TOOLTIP_LABELS
        if ed.get(key) is not None
    }


def summarize_metrics(metrics: dict, centrality: str = "degree_centrality",
                      top_k: int = 5) -> dict:
    """Top-k nodes by a centrality and Forman-Ricci extremes for the Explore tab.

    Parameters
    ----------
    metrics : dict
        Result of ``compute_graph_metrics`` (only ``"nodes"`` and ``"edges"``
        are read).
    centrality : str
        Per-node metric key to rank by; nodes lacking it are ignored.
    top_k : int
        Maximum number of entries per list. Zero or negative values yield
        empty lists.

    Returns
    -------
    dict
        ``centrality`` (echoed), ``top_nodes`` as ``(node_id, value)`` in
        descending order, ``forman_most_negative`` as ``(edge_key, value)`` in
        ascending order, and ``forman_most_positive`` in descending order.
    """
    nodes = metrics.get("nodes") or {}
    edges = metrics.get("edges") or {}
    # Clamp once: ``seq[-0:]`` is the whole sequence and a negative k would
    # slice from the wrong end.
    k = max(0, int(top_k))

    ranked = [
        (nid, vals.get(centrality))
        for nid, vals in nodes.items()
        if vals.get(centrality) is not None
    ]
    ranked.sort(key=lambda kv: kv[1], reverse=True)

    forman = [
        (key, vals["forman_ricci"])
        for key, vals in edges.items()
        if vals.get("forman_ricci") is not None
    ]
    forman.sort(key=lambda kv: kv[1])

    return {
        "centrality": centrality,
        "top_nodes": ranked[:k],
        "forman_most_negative": forman[:k],
        # First k of the reversed list == last k of the list, reversed.
        "forman_most_positive": forman[::-1][:k],
    }
-_SIDEBAR_TARGET_WIDTH_PX = 512 +# Fixed, slightly-wider sidebar. The native collapse arrow is kept as a +# "focus mode" so the user can momentarily hide controls for a full-width graph. +_SIDEBAR_TARGET_WIDTH_PX = 560 st.markdown( f""" """, unsafe_allow_html=True, @@ -1654,26 +1693,6 @@ def launch_html_in_browser(path: Path) -> bool: return False -def open_d3_graph_window(payload): - """Write standalone HTML to a temp file and open in browser.""" - if payload is None: - st.warning("No graph to display.") - return - sel_ids = st.session_state.get("selected_neighborhood_ids") or [] - marker = "+".join(sel_ids) if sel_ids else None - html_doc = build_standalone_d3_html(payload, cache_marker=marker) - st.session_state["_d3_last_html"] = html_doc - path = Path(tempfile.gettempdir()) / f"topobench_graph_{uuid.uuid4().hex}.html" - path.write_text(html_doc, encoding="utf-8") - if launch_html_in_browser(path): - st.success("Opened graph in your browser.") - else: - st.warning( - "Could not launch a browser automatically. Use **Download last D3 graph** " - "below and open the file in Edge or Chrome (not VS Code preview)." - ) - - # ============================================================================ # Lifting Application # ============================================================================ @@ -2038,7 +2057,7 @@ def render_basic_lifting_editor(selected_lifting): if st.button( "Reset edited config to defaults", key=f"cfg_reset::{editor_id}", - use_container_width=True, + width="stretch", ): for k in BASIC_EDITABLE_KEYS: wk = f"editcfg::{editor_id}::{k}" @@ -2063,7 +2082,26 @@ def render_basic_lifting_editor(selected_lifting): # Streamlit App # ============================================================================ -D3_EMBED_HEIGHT = 760 +# Graph block height for ``components.html`` (header + chart inside the iframe). 
+D3_EMBED_HEIGHT = 820 + + +def _format_graph_header_title(vdesc, dataset_name, lift_line, num_nodes) -> str: + """Single-line title for the D3 embed header (View, Dataset, Lift, Nodes).""" + nodes_str = str(num_nodes) if num_nodes is not None else "N/A" + return ( + f"View: {vdesc} Dataset: {dataset_name} | {lift_line} | Nodes: {nodes_str}" + ) + + +def _node_count_from_dataset(dset0): + """Return node count for header display, or None if unavailable.""" + if dset0 is None or not hasattr(dset0, "num_nodes"): + return None + num_nodes = dset0.num_nodes + if num_nodes is None and getattr(dset0, "edge_index", None) is not None: + num_nodes = int(dset0.edge_index.max().item()) + 1 + return num_nodes def _render_dataset_metadata_card(domain, dataset_name): @@ -2093,19 +2131,38 @@ def _render_left_config(available_datasets): st.header("Data configuration") with st.expander("Dataset", expanded=True): + # Persist domain/dataset across sidebar tab switches via shadow keys + # (these selectboxes can't use a widget key because the Dataset options + # change with the Domain, which would invalidate a stored key value). 
+ domain_options = list(available_datasets.keys()) + prev_domain = st.session_state.get("cfg_domain") + domain_index = ( + domain_options.index(prev_domain) + if prev_domain in domain_options + else 0 + ) selected_domain = st.selectbox( "Domain", - options=list(available_datasets.keys()), - index=0, + options=domain_options, + index=domain_index, help="Topological domain (folder under configs/dataset).", ) + st.session_state["cfg_domain"] = selected_domain + datasets_in_domain = available_datasets.get(selected_domain, []) + prev_dataset = st.session_state.get("cfg_dataset") + dataset_index = ( + datasets_in_domain.index(prev_dataset) + if prev_dataset in datasets_in_domain + else 0 + ) selected_dataset = st.selectbox( "Dataset", options=datasets_in_domain, - index=0, + index=dataset_index, help=f"YAML stem under configs/dataset/{selected_domain}/", ) + st.session_state["cfg_dataset"] = selected_dataset st.caption(f"**{selected_domain}** / **{selected_dataset}**") _render_dataset_metadata_card(selected_domain, selected_dataset) @@ -2177,21 +2234,11 @@ def _render_left_config(available_datasets): st.session_state["edited_lifting_config"] = edited_lifting_config st.session_state["edited_lifting_errors"] = edited_lifting_errors - st.toggle( - "3D layered view (orbit/zoom)", - help=( - "Stack ranks as horizontal planes in 3D (multi-incidence, 2+ adjacency, " - "or combined incidence+adjacency). Single-matrix views stay 2D." 
- ), - key="layered_3d_view", - on_change=_on_layered_3d_toggle, - ) - st.subheader("Actions") load_clicked = st.button( "Load graph", type="primary", - use_container_width=True, + width="stretch", key="load_graph_btn", ) @@ -2368,6 +2415,16 @@ def _on_layered_3d_toggle(): _rebuild_embed_for_neighborhoods(ids) +def _on_metrics_option_change(): + """Re-embed when a metrics toggle changes (HUD visibility / advanced compute).""" + if st.session_state.get("data") is None: + return + ids = list(st.session_state.get("selected_neighborhood_ids") or []) + if not ids: + return + _rebuild_embed_for_neighborhoods(ids) + + def _on_adjacency_toggle(): """Toggling an adjacency checkbox keeps the union of checked incidences and adjacencies.""" @@ -2406,6 +2463,66 @@ def _commit_selection(new_ids): _rebuild_embed_for_neighborhoods(new_ids) +def _metrics_marker(neigh_ids, expensive): + """Cache key for displayed-graph metrics. + + Includes everything that changes the displayed graph (selection, sampling + snapshot, graph index, lifting) plus the expensive flag. Deliberately + excludes the 2D/3D toggle and the HUD on/off toggle so those never trigger + a recompute. 
+ """ + return "|".join( + [ + "+".join(neigh_ids), + f"max={st.session_state.get('_loaded_max_nodes')}", + f"min={st.session_state.get('_loaded_min_degree')}", + f"idx={st.session_state.get('active_graph_index')}", + f"lift={st.session_state.get('lifting_applied')}", + f"exp={bool(expensive)}", + ] + ) + + +def _compute_and_attach_metrics(payload, G_load, neigh_ids): + """Compute displayed-graph metrics (cached) and attach them to the payload.""" + if not isinstance(payload, dict): + return + + expensive = bool(st.session_state.get("metrics_expensive", False)) + marker = _metrics_marker(neigh_ids, expensive) + cached = st.session_state.get("_graph_metrics") + if cached and cached.get("marker") == marker: + metrics = cached.get("data") + else: + try: + metrics = gm.compute_graph_metrics( + G_load, + view_label=payload.get("graphType", "graph"), + expensive=expensive, + ) + st.session_state.pop("_graph_metrics_error", None) + except Exception as e: # never break the embed over metrics + metrics = None + st.session_state["_graph_metrics_error"] = str(e) + st.session_state["_graph_metrics"] = {"marker": marker, "data": metrics} + + if not metrics: + return + + payload["graphMetrics"] = gm.build_hud_rows(metrics) + payload["showMetricsHud"] = bool( + st.session_state.get("show_metrics_hud", True) + ) + for nd in payload.get("nodes", []): + nm = gm.node_payload_metrics(metrics, nd.get("id")) + if nm: + nd["metrics"] = nm + for ln in payload.get("links", []): + em = gm.edge_payload_metrics(metrics, ln.get("source"), ln.get("target")) + if em: + ln["metrics"] = em + + def _rebuild_embed_for_neighborhoods(neigh_ids): """Rebuild the embedded D3 view for a list of neighborhood ids. 
@@ -2731,6 +2848,16 @@ def _rebuild_embed_for_neighborhoods(neigh_ids): cache_marker = "+".join(neigh_ids) if st.session_state.get("layered_3d_view"): cache_marker += ":3d" + + dataset_name = st.session_state.get("dataset_name") or "—" + num_nodes = _node_count_from_dataset(dset0) + payload["title"] = _format_graph_header_title( + vdesc, dataset_name, lift_subtitle, num_nodes + ) + payload["subtitle"] = "" + + _compute_and_attach_metrics(payload, G_load, neigh_ids) + embed_html = build_standalone_d3_html( payload, embed=True, @@ -2754,8 +2881,18 @@ def _rebuild_embed_for_neighborhoods(neigh_ids): return True -def _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name): - """Populate neighborhoods, sampling, and embed for a working sample.""" +def _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name, + sync_widgets=True): + """Populate neighborhoods, sampling, and embed for a working sample. + + ``sync_widgets`` controls whether the sampling widget keys (``ui_min_degree``, + ``ui_rank_cap_*``, ``ui_hyperedge_cap``) are written. This is safe only on + the initial load or from on_change callbacks (before widgets are + re-instantiated). It must be ``False`` when called inline after those + widgets already exist in the current run (e.g. the "Apply to all" button), + otherwise Streamlit raises "cannot be modified after the widget is + instantiated". 
+ """ rank_labels_for_payload = get_rank_labels( loaded_domain, dataset_name, dset0 ) @@ -2802,15 +2939,16 @@ def _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name): hyperedge_cap = max(0, min(int(hyperedge_cap), int(num_hyperedges))) min_degree = int(cfg.get("min_degree", 0)) - for rank, cap in caps_by_rank.items(): - st.session_state[f"ui_rank_cap_{rank}"] = int(cap) - if ( - hyperedge_cap is not None - and num_hyperedges is not None - and int(num_hyperedges) > 1 - ): - st.session_state["ui_hyperedge_cap"] = int(hyperedge_cap) - st.session_state["ui_min_degree"] = min_degree + if sync_widgets: + for rank, cap in caps_by_rank.items(): + st.session_state[f"ui_rank_cap_{rank}"] = int(cap) + if ( + hyperedge_cap is not None + and num_hyperedges is not None + and int(num_hyperedges) > 1 + ): + st.session_state["ui_hyperedge_cap"] = int(hyperedge_cap) + st.session_state["ui_min_degree"] = min_degree st.session_state["rank_populations"] = rank_pops st.session_state["hyperedge_population"] = num_hyperedges @@ -2994,7 +3132,7 @@ def _render_rank_cap_sliders(rank_populations, hyperedge_population): st.write("") if st.button( "Apply to all", - use_container_width=True, + width="stretch", key="ui_apply_all_rank_caps", ): for rank, pop in rank_populations.items(): @@ -3089,19 +3227,17 @@ def _render_inline_rank_cap(rank_populations, hyperedge_population): def _render_rank_cap_controls(rank_populations, hyperedge_population): - """Render rank-cap UI (inline or popover); return (ui_rank_caps, ui_hyperedge_cap).""" + """Render rank-cap UI (inline or expander); return (ui_rank_caps, ui_hyperedge_cap).""" if not rank_populations: return {}, None configurable = [ rank for rank, pop in rank_populations.items() if int(pop) >= 2 ] - use_popover = len(configurable) >= 2 + use_expander = len(configurable) >= 2 - if use_popover: - summary = _format_rank_cap_summary(rank_populations, hyperedge_population) - st.caption(summary) - with st.popover("Per-rank caps", 
use_container_width=True): + if use_expander: + with st.expander("Per-rank caps", expanded=True): return _render_rank_cap_sliders(rank_populations, hyperedge_population) return _render_inline_rank_cap(rank_populations, hyperedge_population) @@ -3174,7 +3310,9 @@ def _on_sampling_control_change(): dset0 = data[0] if hasattr(data, "__getitem__") else data loaded_domain = st.session_state.get("data_domain") dataset_name = st.session_state.get("dataset_name") - ok = _finalize_loaded_sample(dset0, cfg, loaded_domain, dataset_name) + ok = _finalize_loaded_sample( + dset0, cfg, loaded_domain, dataset_name, sync_widgets=False + ) if ok: snap = st.session_state.get("_load_cfg_snapshot") or {} snap["caps_by_rank"] = copy.deepcopy(cfg.get("caps_by_rank") or {}) @@ -3252,41 +3390,38 @@ def _render_graph_sample_section(): if st.session_state.get("data") is None: return - st.subheader("Graph sample") - _dset0, rank_populations, hyperedge_population, _is_loaded = ( _get_sampling_context() ) meta = st.session_state.get("dataset_metadata") or {} is_inductive = (meta.get("learning_setting") or "").lower() == "inductive" - with st.container(border=True): - col_left, col_right = st.columns([0.4, 0.6], gap="large") - - with col_left: - if is_inductive: - total = int(st.session_state.get("loaded_dataset_size", 1) or 1) - max_idx = max(0, total - 1) - active = int(st.session_state.get("active_graph_index", 0)) - err = st.session_state.get("_graph_index_error") - _render_graph_index_input(active, error=err) - st.caption( - f"**Available graphs:** `0` to `{max_idx}` ({total} total)" - ) - else: - st.caption("Single graph (transductive dataset)") - - st.slider( - "Minimum degree", - min_value=0, - max_value=20, - value=int(st.session_state.get("ui_min_degree", 0)), - key="ui_min_degree", - on_change=_on_sampling_control_change, - ) + if is_inductive: + total = int(st.session_state.get("loaded_dataset_size", 1) or 1) + max_idx = max(0, total - 1) + active = 
int(st.session_state.get("active_graph_index", 0)) + err = st.session_state.get("_graph_index_error") + _render_graph_index_input(active, error=err) + st.caption( + f"**Available graphs:** `0` to `{max_idx}` ({total} total)" + ) + else: + st.caption("Single graph (transductive dataset)") + + st.slider( + "Minimum degree", + min_value=0, + max_value=20, + value=int(st.session_state.get("ui_min_degree", 0)), + key="ui_min_degree", + help="Drop nodes below this degree after rank caps are applied.", + on_change=_on_sampling_control_change, + ) - with col_right: - _render_rank_cap_controls(rank_populations, hyperedge_population) + if rank_populations: + summary = _format_rank_cap_summary(rank_populations, hyperedge_population) + st.caption(summary) + _render_rank_cap_controls(rank_populations, hyperedge_population) def _do_load_graph(cfg, progress=None): @@ -3361,99 +3496,270 @@ def _do_load_graph(cfg, progress=None): ) -def _render_graph_view_rest(): - """Neighborhood picker and D3 embed (after graph sample section).""" +def _render_sidebar_tab_selector(data_loaded): + """Large three-tab selector backed by session state (survives reruns).""" + options = ["Load graph", "Explore", "Metrics"] + + # Apply a programmatic switch requested on a previous run (e.g. after a + # successful load). Must happen BEFORE the tab buttons are rendered. 
+ pending = st.session_state.pop("_pending_tab", None) + if pending in options: + st.session_state["active_sidebar_tab"] = pending + if "active_sidebar_tab" not in st.session_state: + st.session_state["active_sidebar_tab"] = "Load graph" + + active = st.session_state.get("active_sidebar_tab") or "Load graph" + col_load, col_explore, col_metrics = st.columns(3) + with col_load: + if st.button( + "Load graph", + key="tab_load_btn", + type="primary" if active == "Load graph" else "secondary", + width="stretch", + ): + st.session_state["active_sidebar_tab"] = "Load graph" + st.rerun() + with col_explore: + if st.button( + "Explore", + key="tab_explore_btn", + type="primary" if active == "Explore" else "secondary", + width="stretch", + ): + st.session_state["active_sidebar_tab"] = "Explore" + st.rerun() + with col_metrics: + if st.button( + "Metrics", + key="tab_metrics_btn", + type="primary" if active == "Metrics" else "secondary", + width="stretch", + ): + st.session_state["active_sidebar_tab"] = "Metrics" + st.rerun() + + if not data_loaded: + st.caption("Load a graph to enable **Explore** and **Metrics**.") + + return active + + +def _render_explore_tab(): + """Post-load controls: 3D view, graph sample, and neighborhoods.""" + if st.session_state.get("data") is None: + st.info("Load a graph from the **Load graph** tab to explore neighborhoods.") + return + + st.toggle( + "3D layered view (orbit/zoom)", + help=( + "Stack ranks as horizontal planes in 3D (multi-incidence, 2+ adjacency, " + "or combined incidence+adjacency). Single-matrix views stay 2D." + ), + key="layered_3d_view", + on_change=_on_layered_3d_toggle, + ) + + st.divider() + st.subheader("Graph sample") + _render_graph_sample_section() + + st.divider() _render_neighborhood_picker() - embed_html = st.session_state.get("_d3_embed_html") - if not embed_html: - st.info( - "Use the **sidebar** to choose dataset and lifting options, then click " - "**Load graph**." 
def _render_metrics_tab():
    """Render the Metrics sidebar tab for the currently displayed graph.

    Reads everything from ``st.session_state``:

    * ``"data"`` -- when ``None`` only a hint to load a graph is shown.
    * ``"_graph_metrics_error"`` -- when truthy it is shown and nothing else
      is rendered.
    * ``"_graph_metrics"`` -- mapping whose ``"data"`` entry holds the
      metrics dict (keys read here: ``"flags"``, ``"graph"``,
      ``"graph_scope"``).

    Renders the overlay toggle, any flag notes, the whole-graph metric table,
    the top-element tables, and the "Advanced metrics" expander.
    """
    if st.session_state.get("data") is None:
        st.info("Load a graph from the **Load graph** tab to see metrics.")
        return

    err = st.session_state.get("_graph_metrics_error")
    if err:
        st.caption(f"Metrics unavailable: {err}")
        return

    cached = st.session_state.get("_graph_metrics") or {}
    metrics = cached.get("data")
    if not metrics:
        st.caption("Metrics will appear once a graph is displayed.")
        return

    # ``show_metrics_hud`` is re-assigned through the Session State API by
    # ``_persist_widget_state``. Passing ``value=`` as well makes Streamlit
    # warn that the widget has both a default value and an API-set value, so
    # the default (on) is seeded in session state and the key alone drives
    # the widget.
    if "show_metrics_hud" not in st.session_state:
        st.session_state["show_metrics_hud"] = True
    st.toggle(
        "Show metrics overlay",
        key="show_metrics_hud",
        help="Compact whole-graph metrics in the top-left of the graph canvas.",
        on_change=_on_metrics_option_change,
    )

    flags = metrics.get("flags") or {}
    for note in flags.get("notes", []):
        st.caption(note)

    graph = metrics.get("graph") or {}
    scope = metrics.get("graph_scope") or {}
    rows = []
    # GRAPH_METRIC_LABELS supplies both the display order and the labels.
    for key, label in gm.GRAPH_METRIC_LABELS:
        value = graph.get(key)
        if value is None:
            continue
        if key in scope:
            label = f"{label} ({scope[key]})"
        rows.append({"Metric": label, "Value": gm.fmt_value(value)})
    if rows:
        st.caption("Displayed graph (after sampling)")
        st.dataframe(rows, hide_index=True, width="stretch")

    _render_metrics_top_elements(metrics)

    with st.expander("Advanced metrics"):
        # No ``value=`` here for the same reason as the toggle above; an
        # unset checkbox already defaults to unchecked.
        st.checkbox(
            "Compute advanced metrics (betweenness, closeness, eccentricity, …)",
            key="metrics_expensive",
            help=(
                "Heavier centralities. May be slow or approximated on large "
                "graphs; betweenness uses sampling above ~800 nodes."
            ),
            on_change=_on_metrics_option_change,
        )
        if flags.get("betweenness_approx"):
            st.caption("Betweenness is approximated (k-sampled) for this graph.")
        if not flags.get("expensive"):
            st.caption("Enable to add per-node/edge centralities to tooltips and tables.")


def _render_metrics_top_elements(metrics):
    """Render top-5 nodes for a chosen centrality and the Forman-Ricci table.

    Args:
        metrics: Metrics dict; ``"nodes"`` and ``"edges"`` are read here and
            the whole dict is passed on to ``gm.summarize_metrics``.

    Nothing is rendered when ``metrics["nodes"]`` is missing or empty.
    """
    nodes = metrics.get("nodes") or {}
    if not nodes:
        return

    # The first node's entry decides which centralities are offered.
    # ``or {}`` guards against that entry being None.
    sample = next(iter(nodes.values()), None) or {}
    known_centralities = (
        ("degree_centrality", "Degree centrality"),
        ("betweenness", "Betweenness"),
        ("closeness", "Closeness"),
        ("eigenvector", "Eigenvector"),
        ("pagerank", "PageRank"),
        ("clustering", "Clustering"),
    )
    labels = {
        key: label
        for key, label in known_centralities
        if sample.get(key) is not None
    }
    if labels:
        choice = st.selectbox(
            "Top nodes by",
            options=list(labels),
            format_func=lambda k: labels.get(k, k),
            key="metrics_top_centrality",
        )
        summary = gm.summarize_metrics(metrics, centrality=choice, top_k=5)
        column = labels.get(choice, choice)
        top_rows = [
            {"Node": nid, column: gm.fmt_value(val)}
            for nid, val in summary["top_nodes"]
        ]
        if top_rows:
            st.dataframe(top_rows, hide_index=True, width="stretch")

    edges = metrics.get("edges") or {}
    if edges:
        summary = gm.summarize_metrics(metrics, top_k=5)
        neg = summary["forman_most_negative"]
        if neg:
            st.caption("Most negative Forman-Ricci edges")
            st.dataframe(
                [
                    {"Edge": f"{u} — {v}", "Forman": gm.fmt_value(val)}
                    for (u, v), val in neg
                ],
                hide_index=True,
                width="stretch",
            )


def _render_graph_canvas():
    """Render the D3 embed and, directly below it, the HTML download button.

    Reads ``"_d3_embed_html"``, ``"selected_neighborhood_ids"`` and
    ``"_d3_last_html"`` from ``st.session_state``. Renders nothing when no
    embed HTML is available.
    """
    embed_html = st.session_state.get("_d3_embed_html")
    if not embed_html:
        return

    sel_ids = list(st.session_state.get("selected_neighborhood_ids") or [])
    # ``str()`` keeps the join safe if ids are not strings.
    neigh_key = "+".join(str(s) for s in sel_ids) if sel_ids else "default"
    # The container key changes with the selection.
    # NOTE(review): presumably this forces the embed to remount when the
    # neighborhood changes -- confirm.
    with st.container(key=f"d3_embed::{neigh_key}"):
        st.iframe(embed_html, height=D3_EMBED_HEIGHT)

    last_html = st.session_state.get("_d3_last_html")
    if last_html:
        st.download_button(
            label="Download graph (HTML)",
            data=last_html,
            file_name="topobench_graph.html",
            mime="text/html",
            key="d3_download_main",
            width="stretch",
        )


# Widget keys whose session_state entries are re-assigned at the start of
# every run by ``_persist_widget_state``. Keys prefixed ``ui_rank_cap_`` are
# matched separately there.
_PERSIST_WIDGET_KEYS = frozenset(
    {
        "use_lifting",
        "ui_min_degree",
        "ui_hyperedge_cap",
        "ui_set_all_rank_caps",
        "layered_3d_view",
        "main_graph_index_input",
        "show_metrics_hud",
        "metrics_expensive",
    }
)


def _persist_widget_state():
    """Keep sidebar widget selections alive across tab switches.

    Streamlit drops the session_state entry of any keyed widget that is not
    rendered on a run (e.g. Explore/Metrics widgets while the Load graph tab
    is showing). Re-assigning those keys here -- before any widget is
    instantiated this run -- opts them out of that cleanup so selections
    survive when the user switches sidebar tabs.

    NOTE(review): because these keys are written through the Session State
    API, the widgets that own them should not also pass ``value=``, or
    Streamlit warns about a duplicated default -- confirm for the remaining
    widgets in ``_PERSIST_WIDGET_KEYS``.
    """
    # Snapshot the keys first; entries are re-assigned while looping.
    for key in list(st.session_state.keys()):
        if key in _PERSIST_WIDGET_KEYS or key.startswith("ui_rank_cap_"):
            st.session_state[key] = st.session_state[key]
TopoExplorer
', + unsafe_allow_html=True, + ) + tab = _render_sidebar_tab_selector(data_loaded) + st.divider() + if tab == "Load graph": + sidebar_cfg = _render_left_config(available_datasets) + elif tab == "Metrics": + _render_metrics_tab() + else: + _render_explore_tab() - if sidebar_cfg["load_clicked"]: + if sidebar_cfg is not None and sidebar_cfg["load_clicked"]: load_cfg = {**sidebar_cfg, **_default_load_sampling_cfg()} with st.status("Loading graph…", expanded=True) as status: def _progress(msg): @@ -3466,12 +3772,10 @@ def _progress(msg): else: status.update(label="Load stopped", state="error") if ok: + st.session_state["_pending_tab"] = "Explore" st.rerun() - if st.session_state.get("data") is not None: - _render_graph_sample_section() - - _render_graph_view_rest() + _render_graph_canvas() if __name__ == "__main__":