diff --git a/src/components/visualizers/AnimationPanel.tsx b/src/components/visualizers/AnimationPanel.tsx index fb7f4d9..6538d80 100644 --- a/src/components/visualizers/AnimationPanel.tsx +++ b/src/components/visualizers/AnimationPanel.tsx @@ -187,17 +187,19 @@ function AnimationPanel() { {/* Routing Histogram - shows expert scores and top-K selection */}
-

- {animationState.expertScores.length > 0 - ? `Router Scores for "${inputTokens[0] || 'token'}" (Softmax)` - : 'Router Scores (Softmax)'} - {animationState.currentStep === 'selecting' && animationState.expertScores.length > 0 && ` - Top-${topK} Selected`} -

+

+ {animationState.expertScores.length > 0 + ? <>Router Scores for {inputTokens[0] || 'token'} + : 'Router Scores'} + {animationState.currentStep === 'selecting' && animationState.expertScores.length > 0 && ` - Top-${topK} Selected`} +

{experts.map((expert, idx) => { const score = animationState.expertScores[idx] || 0 const isSelected = animationState.selectedExperts.includes(idx) const showSelection = ['selecting', 'routing', 'complete'].includes(animationState.currentStep) + const maxScore = Math.max(...animationState.expertScores, 0.01) + const scaledHeight = (score / maxScore) * 100 return (
@@ -205,7 +207,7 @@ function AnimationPanel() {
) => void incrementExpertLoad: (expertId: number) => void + completeExpertBatch: (expertId: number) => void // Actions - Playback control play: () => void @@ -101,6 +102,8 @@ export const useSimulationStore = create((set, get) => ({ ...expert, loadCount: 0, isActive: false, + batchStartTime: null, + batchProcessingTime: null, })) set({ experts: resetExperts, @@ -166,7 +169,8 @@ export const useSimulationStore = create((set, get) => ({ // Route the token to experts const topK = useMoeStore.getState().topK newToken = routeToken(newToken, experts, topK) - + newToken.status = 'routing' + // Record routing decisions newToken.targetExperts.forEach((expertId, index) => { get().recordRouting({ @@ -175,66 +179,63 @@ export const useSimulationStore = create((set, get) => ({ weight: newToken.routingWeights[index], timestamp: Date.now(), }) - - // Mark expert as active - get().updateExpert(expertId, { isActive: true }) }) set({ tokens: [...tokens, newToken] }) - // Calculate processing time based on expert load (realistic!) - // Base time: 3 seconds - // + 0.5s for each token already on the same experts - // + random jitter (±10%) - const expertLoads = newToken.targetExperts.map(expertId => { - // Count how many other tokens are currently processing on this expert - const tokensOnExpert = get().tokens.filter( - t => t.status === 'processing' && t.targetExperts.includes(expertId) - ).length - return tokensOnExpert - }) - - // Use the maximum load across all target experts - const maxLoad = Math.max(...expertLoads, 0) - const baseProcessingTime = 3000 // 3 seconds - const loadPenalty = maxLoad * 500 // +0.5s per token on same expert - const jitter = (Math.random() - 0.5) * 0.2 // ±10% random variance - const processingTime = Math.round((baseProcessingTime + loadPenalty) * (1 + jitter)) - - // Auto-progress token through states + const routingDelay = 800 // Time for lines to draw setTimeout(() => { + // Check if token still exists + const token = get().tokens.find(t => t.id === tokenId) + if (!token) return + get().updateToken(tokenId, { status: 'processing', ffnStage: 'input' }) - // Progress through FFN stages once, then stop at output + token.targetExperts.forEach((expertId) => { + const expert = get().experts.find(e => e.id === expertId) + if (expert && !expert.isActive) { + const batchSize = get().tokens.filter(t => + (t.status === 'routing' || t.status === 'processing') && + t.targetExperts.includes(expertId) + ).length + + // Processing time scales significantly with batch size + // Base: 2 seconds, +1.5s per additional token + // 1 token = 2s, 2 tokens = 7s, 3 tokens = 12s, 4 tokens = 17s + const baseTime = 2000 + const timePerToken = 500 + const processingTime = baseTime + (batchSize - 1) * timePerToken + + get().updateExpert(expertId, { + isActive: true, + batchStartTime: Date.now(), + batchProcessingTime: processingTime + }) + + setTimeout(() => { + get().completeExpertBatch(expertId) + }, processingTime) + } + }) + + // Progress token through FFN stages for visualization + const avgProcessingTime = 3000 const ffnStages: Array<'ffn1' | 'relu' | 'ffn2' | 'output'> = ['ffn1', 'relu', 'ffn2', 'output'] - const stageTime = processingTime / (ffnStages.length + 1) + const stageTime = avgProcessingTime / (ffnStages.length + 1) + ffnStages.forEach((stage, index) => { setTimeout(() => { - get().updateToken(tokenId, { ffnStage: stage }) + const currentToken = get().tokens.find(t => t.id === tokenId) + if (currentToken && currentToken.status === 'processing') { + get().updateToken(tokenId, { ffnStage: stage }) + } }, stageTime * (index + 1)) }) - - // After processing, mark as complete and remove - setTimeout(() => { - get().updateToken(tokenId, { status: 'complete', ffnStage: 'output' }) - - // Increment expert load and deactivate - newToken.targetExperts.forEach(expertId => { - get().incrementExpertLoad(expertId) - get().updateExpert(expertId, { isActive: false }) - }) - - // Remove token after brief delay - setTimeout(() => { - get().removeToken(tokenId) - get().updateStats() - }, 1000) // Show complete state for 1 second - }, processingTime) - }, 100) + }, routingDelay) }, // Update a specific token @@ -271,6 +272,45 @@ export const useSimulationStore = create((set, get) => ({ })) }, + // Complete all tokens in an expert's batch + completeExpertBatch: expertId => { + const { tokens } = get() + + const batchTokens = tokens.filter( + t => t.status === 'processing' && t.targetExperts.includes(expertId) + ) + + // Mark all tokens as complete + batchTokens.forEach(token => { + get().updateToken(token.id, { status: 'complete', ffnStage: 'output' }) + }) + + // Deactivate expert + get().updateExpert(expertId, { + isActive: false, + batchStartTime: null, + batchProcessingTime: null, + loadCount: get().experts.find(e => e.id === expertId)!.loadCount + batchTokens.length + }) + + setTimeout(() => { + const completeTokens = get().tokens.filter(t => t.status === 'complete') + + completeTokens.forEach(token => { + // Check if ALL of this token's experts are done processing + const allExpertsDone = token.targetExperts.every(expId => { + const expert = get().experts.find(e => e.id === expId) + return expert?.isActive === false + }) + + if (allExpertsDone) { + get().removeToken(token.id) + } + }) + get().updateStats() + }, 1000) + }, + // Playback controls play: () => set({ isPlaying: true }), pause: () => set({ isPlaying: false }), diff --git a/src/types/moe.types.ts b/src/types/moe.types.ts index 059f16f..6f7f45b 100644 --- a/src/types/moe.types.ts +++ b/src/types/moe.types.ts @@ -8,11 +8,13 @@ export interface Position { export interface Expert { id: number name: string - specialization: string // e.g., "Math", "Language", "Science" + specialization: string // e.g., "Grammar", "Noun", "Verb" color: string // Hex color for visualization position: Position // Where to draw it loadCount: number // How many tokens this expert has processed isActive: boolean // Currently processing? + batchStartTime: number | null // When the current batch started processing + batchProcessingTime: number | null // How long this batch will take (ms) } // Status of a token as it moves through the system diff --git a/src/utils/moeInitialization.ts b/src/utils/moeInitialization.ts index 6c1851f..d3f6bd0 100644 --- a/src/utils/moeInitialization.ts +++ b/src/utils/moeInitialization.ts @@ -23,6 +23,8 @@ export function initializeExperts(count: number): Expert[] { position: { x, y }, loadCount: 0, isActive: false, + batchStartTime: null, + batchProcessingTime: null, }) }