`),
}),
[spec.yAxisLabel],
);
@@ -164,7 +171,7 @@ function ScatterChart({
content: (d) => {
const hwKey = d.hwKey ?? '';
const color = colorMap[hwKey] ?? '#888';
- return `
+ return sanitize(`
${hwKey}
@@ -175,7 +182,7 @@ function ScatterChart({
${spec.yAxisLabel}: ${d.y.toLocaleString(undefined, { maximumFractionDigits: 2 })}
-
`;
+
`);
},
}),
[colorMap, spec.yAxisLabel],
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index e971ee0..4c15a14 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -102,6 +102,9 @@ importers:
d3:
specifier: ^7.9.0
version: 7.9.0
+ dompurify:
+ specifier: ^3.3.3
+ version: 3.3.3
gray-matter:
specifier: ^4.0.3
version: 4.0.3
From bab5d4108a8c26ec31ca7a2b55125830e99e9b94 Mon Sep 17 00:00:00 2001
From: adibarra <93070681+adibarra@users.noreply.github.com>
Date: Mon, 30 Mar 2026 22:01:45 -0500
Subject: [PATCH 14/14] feat: validate LLM specs, multi-chart support, chart
quality improvements
- Validate all LLM output fields against enum whitelists before use
- Sanitize error messages to prevent API key leaks
- Support up to 2 charts for cross-model comparison queries
- Smarter chart type selection guidance in system prompt
- Stricter DOMPurify config (whitelist tags/attrs)
- Dynamic bar chart height based on data count
- Scatter chart: zoom scaleExtent, grabCursor, instructions
- Bar chart: instructions overlay
---
.../components/ai-chart/AiChartDisplay.tsx | 10 +-
.../src/components/ai-chart/AiChartResult.tsx | 58 ++--
.../components/ai-chart/example-prompts.ts | 4 +-
.../components/ai-chart/prompt-templates.ts | 52 +++-
packages/app/src/components/ai-chart/types.ts | 71 +++++
packages/app/src/hooks/api/use-ai-chart.ts | 280 +++++++++---------
packages/app/src/lib/ai-providers.ts | 7 +-
7 files changed, 285 insertions(+), 197 deletions(-)
diff --git a/packages/app/src/components/ai-chart/AiChartDisplay.tsx b/packages/app/src/components/ai-chart/AiChartDisplay.tsx
index 33af57c..c1d30c1 100644
--- a/packages/app/src/components/ai-chart/AiChartDisplay.tsx
+++ b/packages/app/src/components/ai-chart/AiChartDisplay.tsx
@@ -166,15 +166,7 @@ export default function AiChartDisplay() {
)}
{/* Result */}
- {result && (
-
- )}
+ {result &&
}
{/* Example prompts (shown when no result) */}
{!result && !isLoading && !error && (
diff --git a/packages/app/src/components/ai-chart/AiChartResult.tsx b/packages/app/src/components/ai-chart/AiChartResult.tsx
index e31e812..87e2639 100644
--- a/packages/app/src/components/ai-chart/AiChartResult.tsx
+++ b/packages/app/src/components/ai-chart/AiChartResult.tsx
@@ -16,17 +16,20 @@ import type {
import DOMPurify from 'dompurify';
import type { AiChartBarPoint, AiChartSpec } from './types';
+import type { AiSingleChartResult } from '@/hooks/api/use-ai-chart';
/** Sanitize tooltip HTML that may contain LLM-generated strings. */
function sanitize(html: string): string {
- return DOMPurify.sanitize(html);
+ return DOMPurify.sanitize(html, {
+ ALLOWED_TAGS: ['div', 'span', 'strong', 'br'],
+ ALLOWED_ATTR: ['style'],
+ });
}
+const CHART_INSTRUCTIONS = 'Hover for details';
+
interface AiChartResultProps {
- spec: AiChartSpec;
- barData: AiChartBarPoint[];
- scatterData: InferenceData[];
- colorMap: Record
;
+ charts: AiSingleChartResult[];
summary: string | null;
}
@@ -109,7 +112,7 @@ function BarChart({ data, spec }: { data: AiChartBarPoint[]; spec: AiChartSpec }
);
}
@@ -200,34 +204,32 @@ function ScatterChart({
layers={layers}
tooltip={tooltip}
watermark="logo"
- zoom={{ enabled: true, axes: 'both' }}
+ grabCursor
+ instructions={`${CHART_INSTRUCTIONS} • Scroll to zoom • Drag to pan`}
+ zoom={{ enabled: true, axes: 'both', scaleExtent: [0.7, 20] }}
/>
);
}
-export default function AiChartResult({
- spec,
- barData,
- scatterData,
- colorMap,
- summary,
-}: AiChartResultProps) {
+export default function AiChartResult({ charts, summary }: AiChartResultProps) {
return (
-
-
- {spec.title}
- {spec.description}
-
-
- {spec.chartType === 'bar' && barData.length > 0 && (
-
- )}
- {spec.chartType === 'scatter' && scatterData.length > 0 && (
-
- )}
-
-
+ {charts.map((chart, i) => (
+
+
+ {chart.spec.title}
+ {chart.spec.description}
+
+
+ {chart.spec.chartType === 'bar' && chart.barData.length > 0 && (
+
+ )}
+ {chart.spec.chartType === 'scatter' && chart.scatterData.length > 0 && (
+
+ )}
+
+
+ ))}
{summary && (
diff --git a/packages/app/src/components/ai-chart/example-prompts.ts b/packages/app/src/components/ai-chart/example-prompts.ts
index 5fc88cf..ece8c93 100644
--- a/packages/app/src/components/ai-chart/example-prompts.ts
+++ b/packages/app/src/components/ai-chart/example-prompts.ts
@@ -1,8 +1,8 @@
export const EXAMPLE_PROMPTS = [
'Compare throughput per GPU across all GPUs for DeepSeek R1 at 8k/1k',
'Bar chart: H100 vs B200 vs GB200 cost per million tokens (hyperscaler) for DeepSeek R1',
- 'Show a scatter plot of all GPU configs for DeepSeek R1 at 8k/1k with throughput per GPU',
+ 'Compare Kimi K2.5 vs DeepSeek R1 throughput per GPU at 8k/1k',
'Which GPU has the best GSM8K accuracy score for DeepSeek R1?',
'Compare reliability/success rate across all GPUs',
- 'Bar chart of energy per output token across all GPUs for gpt-oss at 8k/1k',
+ 'Show a scatter plot of all GPU configs for DeepSeek R1 at 8k/1k with throughput per GPU',
];
diff --git a/packages/app/src/components/ai-chart/prompt-templates.ts b/packages/app/src/components/ai-chart/prompt-templates.ts
index b8ae3b7..23e7c19 100644
--- a/packages/app/src/components/ai-chart/prompt-templates.ts
+++ b/packages/app/src/components/ai-chart/prompt-templates.ts
@@ -1,5 +1,5 @@
/**
- * System prompt for the LLM that parses user natural language into an AiChartSpec.
+ * System prompt for the LLM that parses user natural language into AiChartSpec(s).
* Kept compact to minimize token cost.
*/
export function buildParsePrompt(): string {
@@ -39,23 +39,38 @@ Y-axis metrics for evaluations:
Y-axis metrics for reliability:
- reliability_rate → Success Rate (%)
-## Rules
+## Chart type selection rules
+
+Choose the chart type based on the user's intent:
+- **"bar"**: Use for comparing a single metric across GPUs/configs at a fixed operating point. Best for "compare X vs Y", "which GPU is best for...", "rank by...", direct comparisons. This is the DEFAULT for most queries.
+- **"scatter"**: Use ONLY when the user explicitly wants to see the full performance curve (all data points), trade-off relationships, or Pareto frontiers. Keywords: "scatter", "plot all points", "performance curve", "trade-off", "pareto".
+
+When in doubt, prefer "bar" — it produces cleaner, more readable charts.
+
+## Multi-chart comparisons
+
+If the user asks to compare two DIFFERENT models or two fundamentally different configurations side-by-side (e.g., "compare Kimi K2.5 vs DeepSeek R1" or "compare 1k/1k vs 8k/1k"), return an ARRAY of 2 chart specs — one for each. Each spec should have its own title clearly identifying what it shows.
+
+If the user is comparing GPUs/hardware within a single model (e.g., "H100 vs B200 for DeepSeek R1"), that's a single chart with multiple hardware keys — do NOT split into two charts.
+
+## General rules
1. Map user intent to the closest available values. Be flexible with naming (e.g., "H100" → "h100", "deepseek r1" → "DeepSeek-R1-0528").
2. Pick the correct dataSource based on what the user is asking about (performance → benchmarks, accuracy → evaluations, uptime/success → reliability, trends over time → history).
3. hardwareKeys: list of GPU base names to compare. Empty array [] means "all GPUs".
4. precisions: list of precisions. Empty array [] means "all precisions".
-5. chartType: "bar" for comparing specific values across GPUs/configs, "scatter" for plotting all data points.
-6. targetInteractivity: for benchmark bar charts, the interactivity level (tok/s/user) to read from. Default 40.
-7. If the user doesn't specify a model, default to "DeepSeek-R1-0528".
-8. If the user doesn't specify a sequence, default to "8k/1k".
-9. title: a short chart title describing the comparison.
-10. description: a one-sentence description of what the chart shows.
-11. For evaluations: yAxisMetric should be "eval_score". For reliability: yAxisMetric should be "reliability_rate".
+5. targetInteractivity: for benchmark bar charts, the interactivity level (tok/s/user) to read from. Default 40.
+6. If the user doesn't specify a model, default to "DeepSeek-R1-0528".
+7. If the user doesn't specify a sequence, default to "8k/1k".
+8. title: a short chart title describing the comparison.
+9. description: a one-sentence description of what the chart shows.
+10. For evaluations: yAxisMetric should be "eval_score". For reliability: yAxisMetric should be "reliability_rate".
## Output format
-Return ONLY valid JSON matching this schema (no markdown, no preamble):
+Return ONLY valid JSON (no markdown, no preamble).
+
+For a single chart, return one object:
{
"chartType": "bar" | "scatter",
"dataSource": "benchmarks" | "evaluations" | "reliability" | "history",
@@ -68,18 +83,25 @@ Return ONLY valid JSON matching this schema (no markdown, no preamble):
"targetInteractivity": number,
"title": "string",
"description": "string"
-}`;
+}
+
+For a comparison of two different models/configs, return an array of 2 objects:
+[{ ... }, { ... }]`;
}
export function buildSummaryPrompt(
- spec: { title: string; yAxisLabel: string; model: string; sequence: string },
+ specs: { title: string; yAxisLabel: string; model: string; sequence: string }[],
dataDescription: string,
): string {
+ const specSummary = specs
+ .map(
+ (s) => `Chart: ${s.title} | Metric: ${s.yAxisLabel} | Model: ${s.model}, Seq: ${s.sequence}`,
+ )
+ .join('\n');
+
return `You are an expert performance analyst. Based on the following benchmark data, provide a concise 2-3 sentence summary highlighting the key takeaway.
-Chart: ${spec.title}
-Metric: ${spec.yAxisLabel}
-Model: ${spec.model}, Sequence: ${spec.sequence}
+${specSummary}
Data:
${dataDescription}
diff --git a/packages/app/src/components/ai-chart/types.ts b/packages/app/src/components/ai-chart/types.ts
index 65c2d7c..7684346 100644
--- a/packages/app/src/components/ai-chart/types.ts
+++ b/packages/app/src/components/ai-chart/types.ts
@@ -1,3 +1,7 @@
+import { Model, Sequence, Precision } from '@/lib/data-mappings';
+import { Y_AXIS_METRICS } from '@/lib/chart-utils';
+import { MODEL_ORDER } from '@/lib/constants';
+
export type AiProvider = 'openai' | 'anthropic' | 'xai' | 'google';
export type AiChartType = 'bar' | 'scatter';
@@ -18,9 +22,76 @@ export interface AiChartSpec {
description: string;
}
+/** The LLM may return an array of up to 2 specs for comparison queries. */
+export type AiLlmResponse = AiChartSpec | AiChartSpec[];
+
export interface AiChartBarPoint {
hwKey: string;
label: string;
value: number;
color: string;
}
+
+// ---------------------------------------------------------------------------
+// Validation whitelists
+// ---------------------------------------------------------------------------
+
+const VALID_CHART_TYPES = new Set(['bar', 'scatter']);
+const VALID_DATA_SOURCES = new Set(['benchmarks', 'evaluations', 'reliability', 'history']);
+const VALID_MODELS = new Set(Object.values(Model));
+const VALID_SEQUENCES = new Set(Object.values(Sequence));
+const VALID_PRECISIONS = new Set(Object.values(Precision));
+const VALID_GPU_BASES = new Set(MODEL_ORDER);
+const VALID_Y_METRICS = new Set([...Y_AXIS_METRICS, 'eval_score', 'reliability_rate']);
+
+/** Validate and clamp an LLM-generated spec to known values. Invalid fields are replaced with safe defaults rather than throwing. */
+export function validateSpec(raw: Record<string, unknown>): AiChartSpec {
+ const chartType = VALID_CHART_TYPES.has(raw.chartType as string)
+ ? (raw.chartType as AiChartType)
+ : 'bar';
+
+ const dataSource = VALID_DATA_SOURCES.has(raw.dataSource as string)
+ ? (raw.dataSource as AiDataSource)
+ : 'benchmarks';
+
+ const model = VALID_MODELS.has(raw.model as string) ? (raw.model as string) : Model.DeepSeek_R1;
+
+ const sequence = VALID_SEQUENCES.has(raw.sequence as string)
+ ? (raw.sequence as string)
+ : Sequence.EightK_OneK;
+
+ const rawPrecisions = Array.isArray(raw.precisions) ? (raw.precisions as string[]) : [];
+ const precisions = rawPrecisions
+ .filter((p) => VALID_PRECISIONS.has(p.toLowerCase()))
+ .map((p) => p.toLowerCase());
+
+ const rawHwKeys = Array.isArray(raw.hardwareKeys) ? (raw.hardwareKeys as string[]) : [];
+ const hardwareKeys = rawHwKeys
+ .filter((k) => VALID_GPU_BASES.has(k.toLowerCase()))
+ .map((k) => k.toLowerCase());
+
+ const yAxisMetric = VALID_Y_METRICS.has(raw.yAxisMetric as string)
+ ? (raw.yAxisMetric as string)
+ : 'y_tpPerGpu';
+
+ const targetInteractivity =
+ typeof raw.targetInteractivity === 'number' &&
+ raw.targetInteractivity > 0 &&
+ raw.targetInteractivity < 1000
+ ? raw.targetInteractivity
+ : 40;
+
+ return {
+ chartType,
+ dataSource,
+ model,
+ sequence,
+ precisions,
+ hardwareKeys,
+ yAxisMetric,
+ yAxisLabel: typeof raw.yAxisLabel === 'string' ? raw.yAxisLabel.slice(0, 100) : yAxisMetric,
+ targetInteractivity,
+ title: typeof raw.title === 'string' ? raw.title.slice(0, 200) : 'AI Generated Chart',
+ description: typeof raw.description === 'string' ? raw.description.slice(0, 500) : '',
+ };
+}
diff --git a/packages/app/src/hooks/api/use-ai-chart.ts b/packages/app/src/hooks/api/use-ai-chart.ts
index 3049889..3e66854 100644
--- a/packages/app/src/hooks/api/use-ai-chart.ts
+++ b/packages/app/src/hooks/api/use-ai-chart.ts
@@ -3,6 +3,7 @@
import { useCallback, useState } from 'react';
import type { AiChartBarPoint, AiChartSpec, AiProvider } from '@/components/ai-chart/types';
+import { validateSpec } from '@/components/ai-chart/types';
import { buildParsePrompt, buildSummaryPrompt } from '@/components/ai-chart/prompt-templates';
import type { InferenceData } from '@/components/inference/types';
import { callLlm } from '@/lib/ai-providers';
@@ -20,11 +21,19 @@ import { getHardwareConfig, getModelSortIndex } from '@/lib/constants';
import chartDefinitions from '@/components/inference/inference-chart-config.json';
-interface AiChartResult {
+// ---------------------------------------------------------------------------
+// Result types
+// ---------------------------------------------------------------------------
+
+export interface AiSingleChartResult {
spec: AiChartSpec;
barData: AiChartBarPoint[];
scatterData: InferenceData[];
colorMap: Record<string, string>;
+}
+
+export interface AiChartResult {
+ charts: AiSingleChartResult[];
summary: string | null;
}
@@ -36,12 +45,19 @@ interface UseAiChartReturn {
reset: () => void;
}
-function parseSpecFromLlm(raw: string): AiChartSpec {
+// ---------------------------------------------------------------------------
+// LLM response parsing
+// ---------------------------------------------------------------------------
+
+function parseSpecsFromLlm(raw: string): AiChartSpec[] {
const cleaned = raw
.replace(/```json\s*/g, '')
.replace(/```/g, '')
.trim();
- return JSON.parse(cleaned);
+ const parsed = JSON.parse(cleaned);
+ const arr = Array.isArray(parsed) ? parsed : [parsed];
+ // Validate each spec and limit to 2
+ return arr.slice(0, 2).map((s: unknown) => validateSpec(s as Record<string, unknown>));
}
// ---------------------------------------------------------------------------
@@ -107,7 +123,6 @@ function buildEvalBarData(
spec: AiChartSpec,
colorMap: Record<string, string>,
): AiChartBarPoint[] {
- // Filter by model, hardware, precision
let filtered = rows.filter((r) => r.model === spec.model || spec.model === '');
if (spec.hardwareKeys.length > 0) {
const allowed = new Set(spec.hardwareKeys);
@@ -121,7 +136,6 @@ function buildEvalBarData(
filtered = filtered.filter((r) => allowed.has(r.precision.toLowerCase()));
}
- // Group by hardware key, take latest date per group, extract score
const groups = new Map();
for (const row of filtered) {
const hwKey = normalizeEvalHardwareKey(row.hardware, row.framework, row.spec_method);
@@ -133,7 +147,6 @@ function buildEvalBarData(
const bars: AiChartBarPoint[] = [];
for (const [hwKey, row] of groups) {
- // GSM8K score is typically in metrics as "gsm8k" or first metric value
const score = row.metrics.gsm8k ?? row.metrics.accuracy ?? Object.values(row.metrics)[0] ?? 0;
if (score <= 0) continue;
@@ -159,7 +172,6 @@ function buildReliabilityBarData(
spec: AiChartSpec,
colorMap: Record<string, string>,
): AiChartBarPoint[] {
- // Filter by hardware
let filtered = rows;
if (spec.hardwareKeys.length > 0) {
const allowed = new Set(spec.hardwareKeys);
@@ -169,7 +181,6 @@ function buildReliabilityBarData(
});
}
- // Aggregate across dates: total successes / total attempts per hardware
const agg = new Map();
for (const row of filtered) {
const hw = row.hardware;
@@ -196,6 +207,87 @@ function buildReliabilityBarData(
return bars;
}
+// ---------------------------------------------------------------------------
+// Resolve a single spec into chart data
+// ---------------------------------------------------------------------------
+
+async function resolveSpec(spec: AiChartSpec): Promise<AiSingleChartResult> {
+ if (spec.dataSource === 'evaluations') {
+ const rows = await fetchEvaluations();
+ const hwKeys = [
+ ...new Set(rows.map((r) => normalizeEvalHardwareKey(r.hardware, r.framework, r.spec_method))),
+ ];
+ const colorMap = generateHighContrastColors(hwKeys, 'dark');
+ const barData = buildEvalBarData(rows, spec, colorMap);
+ // Re-color with final keys
+ const finalKeys = barData.map((b) => b.hwKey);
+ const finalColors = generateHighContrastColors(finalKeys, 'dark');
+ return {
+ spec,
+ barData: barData.map((b) => ({ ...b, color: finalColors[b.hwKey] ?? b.color })),
+ scatterData: [],
+ colorMap: finalColors,
+ };
+ }
+
+ if (spec.dataSource === 'reliability') {
+ const rows = await fetchReliability();
+ const hwKeys = [...new Set(rows.map((r) => r.hardware))];
+ const colorMap = generateHighContrastColors(hwKeys, 'dark');
+ const barData = buildReliabilityBarData(rows, spec, colorMap);
+ const finalKeys = barData.map((b) => b.hwKey);
+ const finalColors = generateHighContrastColors(finalKeys, 'dark');
+ return {
+ spec,
+ barData: barData.map((b) => ({ ...b, color: finalColors[b.hwKey] ?? b.color })),
+ scatterData: [],
+ colorMap: finalColors,
+ };
+ }
+
+ // Benchmarks or History
+ const { isl, osl } = sequenceToIslOsl(spec.sequence);
+ const rows =
+ spec.dataSource === 'history'
+ ? await fetchBenchmarkHistory(spec.model, isl, osl)
+ : await fetchBenchmarks(spec.model);
+
+ const { chartData } = transformBenchmarkRows(rows);
+ let points = chartData[0] ?? [];
+
+ if (spec.hardwareKeys.length > 0) {
+ const allowedGpus = new Set(spec.hardwareKeys);
+ points = points.filter((p) => {
+ const hwKey = p.hwKey ?? '';
+ return allowedGpus.has(hwKey) || [...allowedGpus].some((g) => hwKey.startsWith(g));
+ });
+ }
+ if (spec.precisions.length > 0) {
+ const allowedPrec = new Set(spec.precisions.map((p) => p.toLowerCase()));
+ points = points.filter((p) => p.precision && allowedPrec.has(p.precision.toLowerCase()));
+ }
+
+ if (spec.dataSource !== 'history') {
+ points = points.filter((p) => {
+ const entry = p as any;
+ if (entry.isl != null && entry.osl != null) {
+ return entry.isl === isl && entry.osl === osl;
+ }
+ return true;
+ });
+ }
+
+ const hwKeys = [...new Set(points.map((p) => p.hwKey ?? '').filter(Boolean))];
+ const colorMap = generateHighContrastColors(hwKeys, 'dark');
+
+ return {
+ spec,
+ barData: spec.chartType === 'bar' ? buildBenchmarkBarData(points, spec, colorMap) : [],
+ scatterData: spec.chartType === 'scatter' ? points : [],
+ colorMap,
+ };
+}
+
// ---------------------------------------------------------------------------
// Main hook
// ---------------------------------------------------------------------------
@@ -211,111 +303,54 @@ export function useAiChart(): UseAiChartReturn {
setResult(null);
try {
- // Step 1: Parse prompt into spec
- const rawSpec = await callLlm(provider, apiKey, buildParsePrompt(), prompt);
- const spec = parseSpecFromLlm(rawSpec);
- // Default dataSource for backwards compat
- if (!spec.dataSource) spec.dataSource = 'benchmarks';
-
- let barData: AiChartBarPoint[] = [];
- let scatterData: InferenceData[] = [];
- let hwKeys: string[] = [];
-
- if (spec.dataSource === 'evaluations') {
- // ---- Evaluations ----
- const rows = await fetchEvaluations();
- hwKeys = [
- ...new Set(
- rows.map((r) => normalizeEvalHardwareKey(r.hardware, r.framework, r.spec_method)),
- ),
- ];
- const colorMap = generateHighContrastColors(hwKeys, 'dark');
- barData = buildEvalBarData(rows, spec, colorMap);
-
- if (barData.length === 0) {
- setError('No evaluation data found for the requested configuration.');
- setIsLoading(false);
- return;
- }
-
- hwKeys = barData.map((b) => b.hwKey);
- const finalColorMap = generateHighContrastColors(hwKeys, 'dark');
- barData = barData.map((b) => ({ ...b, color: finalColorMap[b.hwKey] ?? b.color }));
-
- await generateSummary(provider, apiKey, spec, barData, finalColorMap, setResult);
- return;
- }
+ // Step 1: Parse prompt into validated spec(s)
+ const rawResponse = await callLlm(provider, apiKey, buildParsePrompt(), prompt);
+ const specs = parseSpecsFromLlm(rawResponse);
- if (spec.dataSource === 'reliability') {
- // ---- Reliability ----
- const rows = await fetchReliability();
- hwKeys = [...new Set(rows.map((r) => r.hardware))];
- const colorMap = generateHighContrastColors(hwKeys, 'dark');
- barData = buildReliabilityBarData(rows, spec, colorMap);
-
- if (barData.length === 0) {
- setError('No reliability data found for the requested configuration.');
- setIsLoading(false);
- return;
- }
-
- hwKeys = barData.map((b) => b.hwKey);
- const finalColorMap = generateHighContrastColors(hwKeys, 'dark');
- barData = barData.map((b) => ({ ...b, color: finalColorMap[b.hwKey] ?? b.color }));
-
- await generateSummary(provider, apiKey, spec, barData, finalColorMap, setResult);
+ if (specs.length === 0) {
+ setError('Could not parse your request. Try rephrasing.');
+ setIsLoading(false);
return;
}
- // ---- Benchmarks (default) & History ----
- const { isl, osl } = sequenceToIslOsl(spec.sequence);
- const rows =
- spec.dataSource === 'history'
- ? await fetchBenchmarkHistory(spec.model, isl, osl)
- : await fetchBenchmarks(spec.model);
-
- const { chartData } = transformBenchmarkRows(rows);
- let points = chartData[0] ?? [];
-
- // Filter by spec
- if (spec.hardwareKeys.length > 0) {
- const allowedGpus = new Set(spec.hardwareKeys);
- points = points.filter((p) => {
- const hwKey = p.hwKey ?? '';
- return allowedGpus.has(hwKey) || [...allowedGpus].some((g) => hwKey.startsWith(g));
- });
- }
- if (spec.precisions.length > 0) {
- const allowedPrec = new Set(spec.precisions.map((p) => p.toLowerCase()));
- points = points.filter((p) => p.precision && allowedPrec.has(p.precision.toLowerCase()));
- }
-
- // Filter by sequence (for non-history, where all sequences may be returned)
- if (spec.dataSource !== 'history') {
- points = points.filter((p) => {
- const entry = p as any;
- if (entry.isl != null && entry.osl != null) {
- return entry.isl === isl && entry.osl === osl;
- }
- return true;
- });
- }
+ // Step 2: Resolve each spec into chart data (parallel for multi-chart)
+ const charts = await Promise.all(specs.map(resolveSpec));
- if (points.length === 0) {
- setError(
- `No data found for ${spec.model} (${spec.sequence}). Try a different model or configuration.`,
- );
+ // Check if any chart has data
+ const hasData = charts.some((c) => c.barData.length > 0 || c.scatterData.length > 0);
+ if (!hasData) {
+ const models = [...new Set(specs.map((s) => s.model))].join(', ');
+ setError(`No data found for ${models}. Try a different model or configuration.`);
setIsLoading(false);
return;
}
- hwKeys = [...new Set(points.map((p) => p.hwKey ?? '').filter(Boolean))];
- const colorMap = generateHighContrastColors(hwKeys, 'dark');
-
- barData = spec.chartType === 'bar' ? buildBenchmarkBarData(points, spec, colorMap) : [];
- scatterData = spec.chartType === 'scatter' ? points : [];
+ // Step 3: Generate summary (best-effort)
+ let summary: string | null = null;
+ try {
+ const allBars = charts.flatMap((c) => c.barData);
+ const allScatter = charts.flatMap((c) => c.scatterData);
+ const hwKeys = [
+ ...new Set([...allBars.map((b) => b.hwKey), ...allScatter.map((p) => p.hwKey ?? '')]),
+ ].filter(Boolean);
+
+ const dataDesc =
+ allBars.length > 0
+ ? allBars.map((b) => `${b.label}: ${b.value.toFixed(2)}`).join('\n')
+ : `${allScatter.length} data points across ${hwKeys.length} hardware configs`;
+
+ const summaryRaw = await callLlm(
+ provider,
+ apiKey,
+ buildSummaryPrompt(specs, dataDesc),
+ 'Provide the summary.',
+ );
+ summary = summaryRaw.trim();
+ } catch {
+ // Summary generation is non-critical
+ }
- await generateSummary(provider, apiKey, spec, barData, colorMap, setResult, scatterData);
+ setResult({ charts, summary });
} catch (err) {
setError(err instanceof Error ? err.message : 'An unexpected error occurred.');
} finally {
@@ -330,42 +365,3 @@ export function useAiChart(): UseAiChartReturn {
return { result, isLoading, error, generate, reset };
}
-
-async function generateSummary(
- provider: AiProvider,
- apiKey: string,
- spec: AiChartSpec,
- barData: AiChartBarPoint[],
- colorMap: Record,
- setResult: (r: AiChartResult) => void,
- scatterData: InferenceData[] = [],
-) {
- let summary: string | null = null;
- try {
- const hwKeys = [
- ...new Set([...barData.map((b) => b.hwKey), ...scatterData.map((p) => p.hwKey ?? '')]),
- ].filter(Boolean);
- const dataDesc =
- barData.length > 0
- ? barData.map((b) => `${b.label}: ${b.value.toFixed(2)}`).join('\n')
- : `${scatterData.length} data points across ${hwKeys.length} hardware configs`;
-
- const summaryRaw = await callLlm(
- provider,
- apiKey,
- buildSummaryPrompt(spec, dataDesc),
- 'Provide the summary.',
- );
- summary = summaryRaw.trim();
- } catch {
- // Summary generation is non-critical
- }
-
- setResult({
- spec,
- barData,
- scatterData,
- colorMap,
- summary,
- });
-}
diff --git a/packages/app/src/lib/ai-providers.ts b/packages/app/src/lib/ai-providers.ts
index 1c0de60..335b895 100644
--- a/packages/app/src/lib/ai-providers.ts
+++ b/packages/app/src/lib/ai-providers.ts
@@ -132,8 +132,13 @@ export async function callLlm(
const json = await res.json();
if (!res.ok) {
- const msg =
+ const raw =
json?.error?.message ?? json?.error?.type ?? `${provider} request failed (${res.status})`;
+ // Strip anything that looks like an API key to prevent accidental leaks in UI
+ const msg = String(raw)
+ .replace(/sk-[a-zA-Z0-9_-]{10,}/g, '[REDACTED]')
+ .replace(/key-[a-zA-Z0-9_-]{10,}/g, '[REDACTED]')
+ .replace(/Bearer\s+\S+/gi, 'Bearer [REDACTED]');
throw new Error(msg);
}