diff --git a/lib/components/SchematicViewer.tsx b/lib/components/SchematicViewer.tsx index ab4fd20..d3ad57f 100644 --- a/lib/components/SchematicViewer.tsx +++ b/lib/components/SchematicViewer.tsx @@ -6,6 +6,7 @@ import { su } from "@tscircuit/soup-util" import { useChangeSchematicComponentLocationsInSvg } from "lib/hooks/useChangeSchematicComponentLocationsInSvg" import { useChangeSchematicTracesForMovedComponents } from "lib/hooks/useChangeSchematicTracesForMovedComponents" import { useSchematicGroupsOverlay } from "lib/hooks/useSchematicGroupsOverlay" +import { useHighlightConnectedSchematicTraces } from "lib/hooks/useHighlightConnectedSchematicTraces" import { enableDebug } from "lib/utils/debug" import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { @@ -345,6 +346,12 @@ export const SchematicViewer = ({ editEvents: editEventsWithUnappliedEditEvents, }) + useHighlightConnectedSchematicTraces({ + svgDivRef, + circuitJson, + circuitJsonKey, + }) + // Add group overlays when enabled useSchematicGroupsOverlay({ svgDivRef, @@ -406,6 +413,10 @@ export const SchematicViewer = ({ {`[data-schematic-port-id]:hover { cursor: pointer !important; }`} )} +
{ + if (!(target instanceof Element)) return null + return target.closest( + '[data-circuit-json-type="schematic_trace"][data-schematic-trace-id]', + ) +} + +const getTraceId = (traceGroup: Element | null) => { + return traceGroup?.getAttribute("data-schematic-trace-id") ?? null +} + +class UnionFind { + parent = new Map() + + find(id: string): string { + if (!this.parent.has(id)) { + this.parent.set(id, id) + return id + } + + const parentId = this.parent.get(id)! + if (parentId === id) return id + + const root = this.find(parentId) + this.parent.set(id, root) + return root + } + + union(a: string, b: string) { + const rootA = this.find(a) + const rootB = this.find(b) + + if (rootA !== rootB) { + this.parent.set(rootB, rootA) + } + } +} + +export const useHighlightConnectedSchematicTraces = ({ + svgDivRef, + circuitJson, + circuitJsonKey, +}: { + svgDivRef: React.RefObject + circuitJson: CircuitJson + circuitJsonKey: string +}) => { + const highlightedTraceIdsRef = useRef>(new Set()) + + const connectedTraceIdsByTraceId = useMemo(() => { + const traceIdsByGroupId = new Map>() + const groupIdByTraceId = new Map() + + try { + const unionFind = new UnionFind() + const sourceTraces = su(circuitJson).source_trace.list() + + for (const sourceTrace of sourceTraces) { + const sourceTraceId = sourceTrace.source_trace_id + if (!sourceTraceId) continue + + const sourceTraceNodeId = `source_trace:${sourceTraceId}` + unionFind.find(sourceTraceNodeId) + + const connectedSourcePortIds = + sourceTrace.connected_source_port_ids ?? [] + const connectedSourceNetIds = sourceTrace.connected_source_net_ids ?? [] + + for (const sourcePortId of connectedSourcePortIds) { + unionFind.union(sourceTraceNodeId, `source_port:${sourcePortId}`) + } + + for (const sourceNetId of connectedSourceNetIds) { + unionFind.union(sourceTraceNodeId, `source_net:${sourceNetId}`) + } + } + + const schematicTraces = su(circuitJson).schematic_trace.list() + + for (const schematicTrace of schematicTraces) { + const schematicTraceId = schematicTrace.schematic_trace_id + if (!schematicTraceId) continue + + const groupId = schematicTrace.source_trace_id + ? unionFind.find(`source_trace:${schematicTrace.source_trace_id}`) + : `schematic_trace:${schematicTraceId}` + + groupIdByTraceId.set(schematicTraceId, groupId) + + const traceIds = traceIdsByGroupId.get(groupId) ?? new Set() + traceIds.add(schematicTraceId) + traceIdsByGroupId.set(groupId, traceIds) + } + } catch (err) { + console.error("Failed to derive connected schematic traces", err) + } + + const connectedTraceIdsByTraceId = new Map>() + + for (const [traceId, groupId] of groupIdByTraceId) { + connectedTraceIdsByTraceId.set( + traceId, + traceIdsByGroupId.get(groupId) ?? new Set([traceId]), + ) + } + + return connectedTraceIdsByTraceId + }, [circuitJsonKey, circuitJson]) + + useEffect(() => { + const svg = svgDivRef.current + if (!svg) return + + const clearHighlightedTraces = () => { + if (highlightedTraceIdsRef.current.size === 0) return + + for (const traceId of highlightedTraceIdsRef.current) { + svg + .querySelector( + `[data-circuit-json-type="schematic_trace"][data-schematic-trace-id="${traceId}"]`, + ) + ?.classList.remove(HOVERED_TRACE_CLASS) + } + + highlightedTraceIdsRef.current.clear() + } + + const highlightConnectedTraces = (traceId: string | null) => { + clearHighlightedTraces() + + if (!traceId) return + + const connectedTraceIds = + connectedTraceIdsByTraceId.get(traceId) ?? new Set([traceId]) + + for (const connectedTraceId of connectedTraceIds) { + svg + .querySelector( + `[data-circuit-json-type="schematic_trace"][data-schematic-trace-id="${connectedTraceId}"]`, + ) + ?.classList.add(HOVERED_TRACE_CLASS) + } + + highlightedTraceIdsRef.current = new Set(connectedTraceIds) + } + + const handlePointerOver = (event: PointerEvent) => { + highlightConnectedTraces( + getTraceId(getTraceGroupFromTarget(event.target)), + ) + } + + const handlePointerOut = (event: PointerEvent) => { + const currentTraceGroup = getTraceGroupFromTarget(event.target) + const nextTraceGroup = getTraceGroupFromTarget(event.relatedTarget) + + if (currentTraceGroup && currentTraceGroup === nextTraceGroup) { + return + } + + if (!nextTraceGroup) { + clearHighlightedTraces() + return + } + + highlightConnectedTraces(getTraceId(nextTraceGroup)) + } + + svg.addEventListener("pointerover", handlePointerOver) + svg.addEventListener("pointerout", handlePointerOut) + svg.addEventListener("pointerleave", clearHighlightedTraces) + + return () => { + clearHighlightedTraces() + svg.removeEventListener("pointerover", handlePointerOver) + svg.removeEventListener("pointerout", handlePointerOut) + svg.removeEventListener("pointerleave", clearHighlightedTraces) + } + }, [svgDivRef, connectedTraceIdsByTraceId]) +}