diff --git a/packages/app/src/components/inference/ui/GPUGraph.tsx b/packages/app/src/components/inference/ui/GPUGraph.tsx index 5864089..6edf6fa 100644 --- a/packages/app/src/components/inference/ui/GPUGraph.tsx +++ b/packages/app/src/components/inference/ui/GPUGraph.tsx @@ -5,6 +5,7 @@ import * as d3 from 'd3'; import React, { useCallback, useEffect, useMemo, useRef } from 'react'; import { useTheme } from 'next-themes'; +import { GRADIENT_NUDGE_EVENT } from '@/components/gradient-label-nudge'; import { useInference } from '@/components/inference/InferenceContext'; import ChartLegend from '@/components/ui/chart-legend'; import { getModelSortIndex, HARDWARE_CONFIG } from '@/lib/constants'; @@ -12,7 +13,14 @@ import { generateGpuDateColors } from '@/lib/dynamic-colors'; import { formatNumber, getDisplayLabel, updateRepoUrl } from '@/lib/utils'; import { useThemeColors } from '@/hooks/useThemeColors'; import { D3Chart } from '@/lib/d3-chart/D3Chart'; -import type { D3ChartHandle, RenderContext, ZoomContext } from '@/lib/d3-chart/D3Chart/types'; +import type { + CustomLayerConfig, + D3ChartHandle, + LayerConfig, + RenderContext, + ZoomContext, +} from '@/lib/d3-chart/D3Chart/types'; +import type { ContinuousScale } from '@/lib/d3-chart/types'; import { applyHoverState, applyNormalState } from '@/lib/chart-rendering'; import { formatLargeNumber, logTickFormat } from '@/lib/chart-rendering'; import { @@ -30,6 +38,14 @@ import { generateGPUGraphTooltipContent, getPointLabel, } from '@/components/inference/utils/tooltipUtils'; +import { + type ParetoPointLabel, + computeParetoPointLabels, + computeGradientStops, + PARETO_LABEL_COLORS, + buildGradientColorMap, + getParetoLabel, +} from '@/components/inference/utils/paretoLabels'; const CHART_MARGIN = { top: 24, right: 10, bottom: 60, left: 60 }; @@ -57,6 +73,10 @@ const GPUGraph = React.memo( setUseAdvancedLabels, highContrast, setHighContrast, + showGradientLabels, + setShowGradientLabels, + showLineLabels, + setShowLineLabels, selectAllActiveDates, } = useInference(); const { resolvedTheme } = useTheme(); @@ -165,6 +185,34 @@ const GPUGraph = React.memo( return result; }, [groupedData, selectedYAxisMetric, chartDefinition]); + // Gradient label data + const allPointLabelsByKey = useMemo(() => { + const globalLabelColorMap = new Map(); + let globalColorIdx = 0; + const result: Record = {}; + Object.entries(rooflines).forEach(([key, rooflinePoints]) => { + if (rooflinePoints.length < 2) return; + rooflinePoints.forEach((pt) => { + const label = getParetoLabel(pt); + if (!globalLabelColorMap.has(label)) { + globalLabelColorMap.set( + label, + PARETO_LABEL_COLORS[globalColorIdx % PARETO_LABEL_COLORS.length], + ); + globalColorIdx++; + } + }); + result[key] = computeParetoPointLabels(rooflinePoints, globalLabelColorMap); + }); + return result; + }, [rooflines]); + + // Point → gradient color lookup + const gradientColorByPoint = useMemo( + () => buildGradientColorMap(allPointLabelsByKey), + [allPointLabelsByKey], + ); + const optimalPointKeys = useMemo(() => { const keys = new Set(); Object.values(rooflines).forEach((pts) => @@ -213,12 +261,16 @@ const GPUGraph = React.memo( // Color resolver for points/rooflines const getColor = useMemo(() => { return (d: InferenceData) => { + if (showGradientLabels) { + const gc = gradientColorByPoint.get(d); + if (gc) return gc; + } const graphIndex = allGraphs.findIndex( ({ date, hwKey }) => d.date === date && d.hwKey === hwKey, ); return graphIndex >= 0 ? allGraphs[graphIndex].color : '#6b7280'; }; - }, [allGraphs]); + }, [allGraphs, showGradientLabels, gradientColorByPoint]); const getRooflineColor = useMemo(() => { return (key: string) => { @@ -228,13 +280,6 @@ const GPUGraph = React.memo( }; }, [allGraphs]); - const isRooflineVisible = useMemo(() => { - return (key: string) => { - const graphId = key.split('_').slice(0, -1).join('_'); - return activeDates.has(graphId); - }; - }, [activeDates]); - // Dismiss tooltip when pinned point's combo is hidden useEffect(() => { const pp = chartRef.current?.getPinnedPoint() as InferenceData | null; @@ -264,6 +309,15 @@ const GPUGraph = React.memo( const series = key.slice(0, key.lastIndexOf('_')); return series === seriesId ? null : '0.15'; }); + root + .selectAll('.parallelism-label, .line-label') + .transition('legend-hover') + .duration(150) + .style('opacity', function () { + const series = (this as SVGGElement).getAttribute('data-series'); + if (!series) return 0; + return series === seriesId ? 1 : 0; + }); }, []); const handleLegendHoverEnd = useCallback(() => { @@ -276,8 +330,554 @@ const GPUGraph = React.memo( .transition('legend-hover') .duration(150) .style('opacity', null); + root + .selectAll('.parallelism-label, .line-label') + .transition('legend-hover') + .duration(150) + .style('opacity', null); }, []); + // Helper: parse "date_hwKey_precision" → series id "date_hwKey" + const parseSeriesId = (key: string) => { + const lastUnderscore = key.lastIndexOf('_'); + return key.slice(0, lastUnderscore); + }; + + // --- Layers --- + const gpuGraphLayers = useMemo((): LayerConfig[] => { + // ── Layer 0: Rooflines + gradient labels + line labels (custom) ── + const rooflineLayer: CustomLayerConfig = { + type: 'custom', + key: 'rooflines', + render: (zoomGroup, ctx) => { + const xScale = ctx.xScale as ContinuousScale; + const yScale = ctx.yScale as ContinuousScale; + const { defs } = ctx.layout; + + const lineGen = d3 + .line() + .x((d) => xScale(d.x)) + .y((d) => yScale(d.y)) + .curve(d3.curveMonotoneX); + + // Ensure rooflines layer exists before dot-groups + let rooflinesLayer = zoomGroup.select('.rooflines-layer'); + if (rooflinesLayer.empty()) { + const firstDotGroup = zoomGroup.select('.dot-group').node() as SVGGElement | null; + const node = document.createElementNS('http://www.w3.org/2000/svg', 'g'); + node.setAttribute('class', 'rooflines-layer'); + const parent = zoomGroup.node()!; + if (firstDotGroup) parent.insertBefore(node, firstDotGroup); + else parent.appendChild(node); + rooflinesLayer = d3.select(node); + } + + // Build roofline entries with gradient or solid stroke + type Entry = { + key: string; + seriesId: string; + precision: string; + points: InferenceData[]; + stroke: string; + visible: boolean; + }; + const entries: Entry[] = []; + const activeGradientIds = new Set(); + + Object.entries(rooflines).forEach(([key, pts]) => { + if (pts.length <= 1) return; + const seriesId = parseSeriesId(key); + const precision = key.split('_').pop()!; + const visible = activeDates.has(seriesId); + let stroke = getRooflineColor(key); + + if (showGradientLabels) { + const pointLabels = allPointLabelsByKey[key]; + if (pointLabels) { + const stops = computeGradientStops(pointLabels, xScale); + if (stops) { + const gid = `roofline-gradient-${chartId}-${key}`; + activeGradientIds.add(gid); + let gradient = defs.select(`#${CSS.escape(gid)}`); + if (gradient.empty()) gradient = defs.append('linearGradient').attr('id', gid); + gradient + .attr('gradientUnits', 'userSpaceOnUse') + .attr('x1', xScale(pts[0].x)) + .attr('y1', 0) + .attr('x2', xScale(pts[pts.length - 1].x)) + .attr('y2', 0); + gradient + .selectAll('stop') + .data(stops) + .join('stop') + .attr('offset', (s) => `${(s.offset * 100).toFixed(2)}%`) + .attr('stop-color', (s) => s.color); + stroke = `url(#${gid})`; + } + } + } + + entries.push({ key, seriesId, precision, points: pts, stroke, visible }); + }); + + // Remove stale gradients + defs.selectAll('linearGradient').each(function () { + const id = (this as SVGLinearGradientElement).id; + if (id.startsWith(`roofline-gradient-${chartId}-`) && !activeGradientIds.has(id)) { + d3.select(this).remove(); + } + }); + + // Data join for roofline paths + rooflinesLayer + .selectAll('.roofline-path') + .data(entries, (d) => d.key) + .join('path') + .attr('class', (d) => `roofline-path roofline-${d.key}`) + .attr('data-series', (d) => d.seriesId) + .attr('fill', 'none') + .attr('stroke', (d) => d.stroke) + .attr('stroke-width', 2.5) + .attr('d', (d) => lineGen(d.points)) + .style('transition', 'opacity 150ms ease') + .style('opacity', (d) => (d.visible ? 1 : 0)); + + // Parallelism labels + type LabelSeg = { + segKey: string; + seriesId: string; + label: string; + color: string; + x: number; + y: number; + visible: boolean; + }; + const labelSegments: LabelSeg[] = []; + + if (showGradientLabels) { + Object.entries(allPointLabelsByKey).forEach(([key, pointLabels]) => { + if (pointLabels.length < 2) return; + const seriesId = parseSeriesId(key); + const visible = activeDates.has(seriesId); + + const segments: { label: string; color: string; points: InferenceData[] }[] = []; + let cur = { + label: pointLabels[0].label, + color: pointLabels[0].color, + points: [pointLabels[0].point], + }; + for (let i = 1; i < pointLabels.length; i++) { + if (pointLabels[i].label === cur.label) { + cur.points.push(pointLabels[i].point); + } else { + segments.push(cur); + cur = { + label: pointLabels[i].label, + color: pointLabels[i].color, + points: [pointLabels[i].point], + }; + } + } + segments.push(cur); + + segments.forEach((seg, idx) => { + const midPt = seg.points[Math.floor(seg.points.length / 2)]; + labelSegments.push({ + segKey: `${key}-${idx}`, + seriesId, + label: seg.label, + color: seg.color, + x: xScale(midPt.x), + y: yScale(midPt.y) - 14, + visible, + }); + }); + }); + } + + zoomGroup + .selectAll('.parallelism-label') + .data(labelSegments, (d) => d.segKey) + .join( + (enter) => { + const g = enter + .append('g') + .attr('class', 'parallelism-label') + .style('pointer-events', 'none') + .attr('transform', (d) => `translate(${d.x},${d.y})`); + g.append('rect') + .attr('class', 'pl-bg') + .attr('rx', 4) + .attr('ry', 4) + .attr('opacity', 0.9); + g.append('text') + .attr('class', 'pl-text') + .attr('text-anchor', 'middle') + .attr('dominant-baseline', 'central') + .attr('fill', 'white') + .attr('font-size', '9px') + .attr('font-weight', '600'); + return g; + }, + (update) => update, + (exit) => exit.remove(), + ) + .attr('data-seg-key', (d) => d.segKey) + .attr('data-series', (d) => d.seriesId) + .attr('transform', (d) => `translate(${d.x},${d.y})`) + .style('opacity', (d) => (d.visible ? 1 : 0)) + .each(function (d) { + const g = d3.select(this); + const text = g.select('.pl-text').text(d.label); + const bbox = (text.node() as SVGTextElement).getBBox(); + const px = 4; + const py = 2; + g.select('.pl-bg') + .attr('x', bbox.x - px) + .attr('y', bbox.y - py) + .attr('width', bbox.width + px * 2) + .attr('height', bbox.height + py * 2) + .attr('fill', d.color); + }); + + // ── Line labels (GPU+date name along each roofline) ── + type LineLabel = { + key: string; + seriesId: string; + label: string; + color: string; + x: number; + y: number; + visible: boolean; + }; + const lineLabels: LineLabel[] = []; + + if (showLineLabels) { + const LABEL_H = 18; + const LABEL_W = 120; + const placed: { x: number; y: number }[] = []; + const collides = (cx: number, cy: number) => + placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W); + + // Pick the roofline with most points per seriesId + const bestBySeries = new Map(); + for (const e of entries) { + if (!e.visible || e.points.length < 2) continue; + const prev = bestBySeries.get(e.seriesId); + if (!prev || e.points.length > prev.points.length) bestBySeries.set(e.seriesId, e); + } + + const sorted = [...bestBySeries.values()].sort((a, b) => { + const ay = yScale(a.points[0].y); + const by = yScale(b.points[0].y); + return ay - by; + }); + + for (const entry of sorted) { + const pts = entry.points; + const candidates = [ + pts[Math.min(1, pts.length - 1)], + pts[Math.floor(pts.length / 2)], + pts[Math.max(0, Math.floor((pts.length * 2) / 3))], + pts[pts.length - 1], + ]; + + const hwKey = entry.seriesId.split('_').slice(1).join('_'); + const hwConfig = HARDWARE_CONFIG[hwKey]; + const label = hwConfig + ? `${getDisplayLabel(hwConfig)} ${entry.seriesId.split('_')[0]}` + : entry.seriesId; + let foundPlacement = false; + for (const pt of candidates) { + const px = xScale(pt.x); + const py = yScale(pt.y); + if (!collides(px, py)) { + lineLabels.push({ + key: entry.key, + seriesId: entry.seriesId, + label, + color: entry.stroke.startsWith('url(') + ? getRooflineColor(entry.key) + : entry.stroke, + x: px, + y: py, + visible: true, + }); + placed.push({ x: px, y: py }); + foundPlacement = true; + break; + } + } + if (!foundPlacement) { + const pt = pts[0]; + lineLabels.push({ + key: entry.key, + seriesId: entry.seriesId, + label, + color: entry.stroke.startsWith('url(') + ? getRooflineColor(entry.key) + : entry.stroke, + x: xScale(pt.x), + y: yScale(pt.y), + visible: false, + }); + } + } + + // Hidden entries for non-visible series + const labeledSeries = new Set(lineLabels.map((l) => l.seriesId)); + for (const entry of entries) { + if (entry.points.length >= 2 && !labeledSeries.has(entry.seriesId)) { + const hwKey = entry.seriesId.split('_').slice(1).join('_'); + const hwConfig = HARDWARE_CONFIG[hwKey]; + const label = hwConfig + ? `${getDisplayLabel(hwConfig)} ${entry.seriesId.split('_')[0]}` + : entry.seriesId; + lineLabels.push({ + key: entry.key, + seriesId: entry.seriesId, + label, + color: entry.stroke.startsWith('url(') + ? getRooflineColor(entry.key) + : entry.stroke, + x: xScale(entry.points[0].x), + y: yScale(entry.points[0].y), + visible: false, + }); + labeledSeries.add(entry.seriesId); + } + } + } + + zoomGroup + .selectAll('.line-label') + .data(lineLabels, (d) => d.key) + .join( + (enter) => { + const g = enter + .append('g') + .attr('class', 'line-label') + .style('pointer-events', 'none') + .attr('transform', (d) => `translate(${d.x},${d.y})`); + g.append('rect') + .attr('class', 'll-bg') + .attr('rx', 4) + .attr('ry', 4) + .attr('opacity', 0.95); + g.append('text') + .attr('class', 'll-text') + .attr('text-anchor', 'start') + .attr('dominant-baseline', 'central') + .attr('fill', 'white') + .attr('font-size', '10px') + .attr('font-weight', '600'); + return g; + }, + (update) => update, + (exit) => exit.remove(), + ) + .attr('data-line-key', (d) => d.key) + .attr('data-series', (d) => d.seriesId) + .attr('transform', (d) => `translate(${d.x + 8},${d.y - 14})`) + .style('opacity', (d) => (d.visible ? 1 : 0)) + .each(function (d) { + const g = d3.select(this); + const text = g.select('.ll-text').text(d.label); + const bbox = (text.node() as SVGTextElement).getBBox(); + const px = 5; + const py = 3; + g.select('.ll-bg') + .attr('x', bbox.x - px) + .attr('y', bbox.y - py) + .attr('width', bbox.width + px * 2) + .attr('height', bbox.height + py * 2) + .attr('fill', d.color); + }); + }, + onZoom: (zoomGroup, ctx) => { + const newXScale = ctx.newXScale as ContinuousScale; + const newYScale = ctx.newYScale as ContinuousScale; + const { defs } = ctx.layout; + + const lineGen = d3 + .line() + .x((d) => newXScale(d.x)) + .y((d) => newYScale(d.y)) + .curve(d3.curveMonotoneX); + + // Update roofline paths + Object.entries(rooflines).forEach(([key, pts]) => { + if (pts.length < 2) return; + const sel = zoomGroup.select(`.roofline-${key}`); + if (!sel.empty()) sel.attr('d', lineGen(pts) as string); + }); + + // Update gradient coordinates + if (showGradientLabels) { + Object.entries(allPointLabelsByKey).forEach(([key, pointLabels]) => { + if (pointLabels.length < 2) return; + const gid = `roofline-gradient-${chartId}-${key}`; + const gradientEl = defs.select(`#${CSS.escape(gid)}`); + if (!gradientEl.empty()) { + const newStops = computeGradientStops(pointLabels, newXScale); + if (newStops) { + gradientEl + .attr('x1', newXScale(pointLabels[0].point.x)) + .attr('x2', newXScale(pointLabels[pointLabels.length - 1].point.x)); + gradientEl + .selectAll('stop') + .data(newStops) + .join('stop') + .attr('offset', (s) => `${(s.offset * 100).toFixed(2)}%`) + .attr('stop-color', (s) => s.color); + } + } + + // Update parallelism label positions + const segments: { points: InferenceData[] }[] = []; + let cur = { points: [pointLabels[0].point] }; + for (let i = 1; i < pointLabels.length; i++) { + if (pointLabels[i].label === pointLabels[i - 1].label) { + cur.points.push(pointLabels[i].point); + } else { + segments.push(cur); + cur = { points: [pointLabels[i].point] }; + } + } + segments.push(cur); + + segments.forEach((seg, idx) => { + const segKey = `${key}-${idx}`; + const labelGroup = zoomGroup.select( + `.parallelism-label[data-seg-key="${segKey}"]`, + ); + if (!labelGroup.empty()) { + const midPt = seg.points[Math.floor(seg.points.length / 2)]; + labelGroup.attr( + 'transform', + `translate(${newXScale(midPt.x)},${newYScale(midPt.y) - 14})`, + ); + } + }); + }); + } + + // Update line label positions on zoom + if (showLineLabels) { + const LABEL_H = 18; + const LABEL_W = 120; + const placed: { x: number; y: number }[] = []; + const collides = (cx: number, cy: number) => + placed.some((p) => Math.abs(p.y - cy) < LABEL_H && Math.abs(p.x - cx) < LABEL_W); + + const bestBySeries = new Map(); + for (const [key, pts] of Object.entries(rooflines)) { + if (pts.length < 2) continue; + const seriesId = parseSeriesId(key); + if (!activeDates.has(seriesId)) continue; + const prev = bestBySeries.get(seriesId); + if (!prev || pts.length > prev[1].length) bestBySeries.set(seriesId, [key, pts]); + } + const visibleEntries = [...bestBySeries.values()].sort( + ([, a], [, b]) => newYScale(a[0].y) - newYScale(b[0].y), + ); + + const zoomResults = new Map(); + for (const [key, pts] of visibleEntries) { + const candidates = [ + pts[Math.min(1, pts.length - 1)], + pts[Math.floor(pts.length / 2)], + pts[Math.max(0, Math.floor((pts.length * 2) / 3))], + pts[pts.length - 1], + ]; + let found = false; + for (const pt of candidates) { + const px = newXScale(pt.x); + const py = newYScale(pt.y); + if (!collides(px, py)) { + zoomResults.set(key, { x: px, y: py, vis: true }); + placed.push({ x: px, y: py }); + found = true; + break; + } + } + if (!found) { + zoomResults.set(key, { + x: newXScale(pts[0].x), + y: newYScale(pts[0].y), + vis: false, + }); + } + } + + zoomGroup.selectAll('.line-label').each(function () { + const el = d3.select(this); + const k = el.attr('data-line-key'); + const zl = zoomResults.get(k); + if (zl) { + el.attr('transform', `translate(${zl.x + 8},${zl.y - 14})`); + el.style('opacity', zl.vis ? 1 : 0); + } else { + el.style('opacity', 0); + } + }); + } + }, + }; + + // ── Layer 1: Scatter points ── + const scatterLayer: LayerConfig = { + type: 'scatter', + key: 'points', + data: filteredData, + config: { + getColor, + hideLabels: hidePointLabels || showGradientLabels, + getLabelText: (d) => (useAdvancedLabels ? getPointLabel(d) : String(d.tp)), + foreground: 'var(--foreground)', + dataAttrs: { + series: (d) => `${d.date}_${d.hwKey}`, + }, + }, + }; + + return [rooflineLayer, scatterLayer]; + }, [ + rooflines, + allPointLabelsByKey, + showGradientLabels, + showLineLabels, + activeDates, + chartId, + getRooflineColor, + filteredData, + getColor, + hidePointLabels, + useAdvancedLabels, + ]); + + // --- Zoom config --- + const gpuGraphZoom = useMemo( + () => ({ + enabled: true, + axes: 'both' as const, + scaleExtent: [1, 20] as [number, number], + resetEventName: `gpu_timeseries_zoom_reset_${chartId}`, + onReset: () => { + track('interactivity_zoom_reset'); + }, + onZoom: (_event: d3.D3ZoomEvent, ctx: ZoomContext) => { + if (logScale) { + const newYScale = ctx.newYScale as d3.ScaleLogarithmic; + ctx.layout.yAxisGroup.call( + d3.axisLeft(newYScale).ticks(10).tickFormat(logTickFormat(newYScale)) as any, + ); + } + }, + }), + [chartId, logScale], + ); + if (data.length === 0) { return (
@@ -329,48 +929,8 @@ const GPUGraph = React.memo( tickFormat: logScale ? undefined : (d) => formatLargeNumber(d as number), tickCount: 10, }} - layers={[ - { - type: 'roofline', - key: 'rooflines', - rooflines: rooflines as Record, - config: { - getColor: getRooflineColor, - isVisible: isRooflineVisible, - }, - }, - { - type: 'scatter', - key: 'points', - data: filteredData, - config: { - getColor, - hideLabels: hidePointLabels, - getLabelText: (d) => (useAdvancedLabels ? getPointLabel(d) : String(d.tp)), - foreground: 'var(--foreground)', - dataAttrs: { - series: (d) => `${d.date}_${d.hwKey}`, - }, - }, - }, - ]} - zoom={{ - enabled: true, - axes: 'both', - scaleExtent: [1, 20], - resetEventName: `gpu_timeseries_zoom_reset_${chartId}`, - onReset: () => { - track('interactivity_zoom_reset'); - }, - onZoom: (_event, ctx: ZoomContext) => { - if (logScale) { - const newYScale = ctx.newYScale as d3.ScaleLogarithmic; - ctx.layout.yAxisGroup.call( - d3.axisLeft(newYScale).ticks(10).tickFormat(logTickFormat(newYScale)) as any, - ); - } - }, - }} + layers={gpuGraphLayers} + zoom={gpuGraphZoom} tooltip={{ rulerType: 'crosshair', content: (d: InferenceData, isPinned: boolean) => @@ -473,6 +1033,40 @@ const GPUGraph = React.memo( onCheckedChange: (c) => { setUseAdvancedLabels(c); track('interactivity_advanced_labels_toggled', { enabled: c }); + if (c && !showGradientLabels) { + window.dispatchEvent( + new CustomEvent(GRADIENT_NUDGE_EVENT, { + detail: { + enableGradient: () => { + setShowGradientLabels(true); + setUseAdvancedLabels(false); + track('interactivity_gradient_labels_toggled', { + enabled: true, + source: 'nudge', + }); + }, + }, + }), + ); + } + }, + }, + { + id: 'gpu-gradient-labels', + label: 'Gradient Labels', + checked: showGradientLabels, + onCheckedChange: (c) => { + setShowGradientLabels(c); + track('interactivity_gradient_labels_toggled', { enabled: c }); + }, + }, + { + id: 'gpu-line-labels', + label: 'Line Labels', + checked: showLineLabels, + onCheckedChange: (c) => { + setShowLineLabels(c); + track('interactivity_line_labels_toggled', { enabled: c }); }, }, ]}