Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/components/visualizers/AnimationPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -187,25 +187,27 @@ function AnimationPanel() {
{/* Routing Histogram - shows expert scores and top-K selection */}
<div className={styles.routingHistogram}>
<div className={styles.histogramContent}>
<h4 className={styles.histogramTitle}>
{animationState.expertScores.length > 0
? `Router Scores for "${inputTokens[0] || 'token'}" (Softmax)`
: 'Router Scores (Softmax)'}
{animationState.currentStep === 'selecting' && animationState.expertScores.length > 0 && ` - Top-${topK} Selected`}
</h4>
<h4 className={styles.histogramTitle}>
{animationState.expertScores.length > 0
? <>Router Scores for <strong>{inputTokens[0] || 'token'}</strong></>
: 'Router Scores'}
{animationState.currentStep === 'selecting' && animationState.expertScores.length > 0 && ` - Top-${topK} Selected`}
</h4>
<div className={styles.barsContainer}>
{experts.map((expert, idx) => {
const score = animationState.expertScores[idx] || 0
const isSelected = animationState.selectedExperts.includes(idx)
const showSelection = ['selecting', 'routing', 'complete'].includes(animationState.currentStep)
const maxScore = Math.max(...animationState.expertScores, 0.01)
const scaledHeight = (score / maxScore) * 100

return (
<div key={expert.id} className={styles.barColumn}>
<div className={styles.barWrapper}>
<div
className={`${styles.bar} ${showSelection && isSelected ? styles.barSelected : ''}`}
style={{
height: `${score * 100}%`,
height: scaledHeight,
backgroundColor: expert.color,
opacity: showSelection && !isSelected ? 0.3 : 0.8
}}
Expand Down
130 changes: 85 additions & 45 deletions src/store/simulationStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ interface SimulationStore {
// Actions - Expert management
updateExpert: (expertId: number, updates: Partial<Expert>) => void
incrementExpertLoad: (expertId: number) => void
completeExpertBatch: (expertId: number) => void

// Actions - Playback control
play: () => void
Expand Down Expand Up @@ -101,6 +102,8 @@ export const useSimulationStore = create<SimulationStore>((set, get) => ({
...expert,
loadCount: 0,
isActive: false,
batchStartTime: null,
batchProcessingTime: null,
}))
set({
experts: resetExperts,
Expand Down Expand Up @@ -166,7 +169,8 @@ export const useSimulationStore = create<SimulationStore>((set, get) => ({
// Route the token to experts
const topK = useMoeStore.getState().topK
newToken = routeToken(newToken, experts, topK)

newToken.status = 'routing'

// Record routing decisions
newToken.targetExperts.forEach((expertId, index) => {
get().recordRouting({
Expand All @@ -175,66 +179,63 @@ export const useSimulationStore = create<SimulationStore>((set, get) => ({
weight: newToken.routingWeights[index],
timestamp: Date.now(),
})

// Mark expert as active
get().updateExpert(expertId, { isActive: true })
})

set({ tokens: [...tokens, newToken] })

// Calculate processing time based on expert load (realistic!)
// Base time: 3 seconds
// + 0.5s for each token already on the same experts
// + random jitter (±10%)
const expertLoads = newToken.targetExperts.map(expertId => {
// Count how many other tokens are currently processing on this expert
const tokensOnExpert = get().tokens.filter(
t => t.status === 'processing' && t.targetExperts.includes(expertId)
).length
return tokensOnExpert
})

// Use the maximum load across all target experts
const maxLoad = Math.max(...expertLoads, 0)
const baseProcessingTime = 3000 // 3 seconds
const loadPenalty = maxLoad * 500 // +0.5s per token on same expert
const jitter = (Math.random() - 0.5) * 0.2 // ±10% random variance
const processingTime = Math.round((baseProcessingTime + loadPenalty) * (1 + jitter))

// Auto-progress token through states
const routingDelay = 800 // Time for lines to draw
setTimeout(() => {
// Check if token still exists
const token = get().tokens.find(t => t.id === tokenId)
if (!token) return

get().updateToken(tokenId, {
status: 'processing',
ffnStage: 'input'
})

// Progress through FFN stages once, then stop at output
token.targetExperts.forEach((expertId) => {
const expert = get().experts.find(e => e.id === expertId)
if (expert && !expert.isActive) {
const batchSize = get().tokens.filter(t =>
(t.status === 'routing' || t.status === 'processing') &&
t.targetExperts.includes(expertId)
).length

// Processing time scales significantly with batch size
// Base: 2 seconds, +1.5s per additional token
// 1 token = 2s, 2 tokens = 7s, 3 tokens = 12s, 4 tokens = 17s
const baseTime = 2000
const timePerToken = 500
const processingTime = baseTime + (batchSize - 1) * timePerToken

get().updateExpert(expertId, {
isActive: true,
batchStartTime: Date.now(),
batchProcessingTime: processingTime
})

setTimeout(() => {
get().completeExpertBatch(expertId)
}, processingTime)
}
})

// Progress token through FFN stages for visualization
const avgProcessingTime = 3000
const ffnStages: Array<'ffn1' | 'relu' | 'ffn2' | 'output'> =
['ffn1', 'relu', 'ffn2', 'output']
const stageTime = processingTime / (ffnStages.length + 1)
const stageTime = avgProcessingTime / (ffnStages.length + 1)

ffnStages.forEach((stage, index) => {
setTimeout(() => {
get().updateToken(tokenId, { ffnStage: stage })
const currentToken = get().tokens.find(t => t.id === tokenId)
if (currentToken && currentToken.status === 'processing') {
get().updateToken(tokenId, { ffnStage: stage })
}
}, stageTime * (index + 1))
})

// After processing, mark as complete and remove
setTimeout(() => {
get().updateToken(tokenId, { status: 'complete', ffnStage: 'output' })

// Increment expert load and deactivate
newToken.targetExperts.forEach(expertId => {
get().incrementExpertLoad(expertId)
get().updateExpert(expertId, { isActive: false })
})

// Remove token after brief delay
setTimeout(() => {
get().removeToken(tokenId)
get().updateStats()
}, 1000) // Show complete state for 1 second
}, processingTime)
}, 100)
}, routingDelay)
},

// Update a specific token
Expand Down Expand Up @@ -271,6 +272,45 @@ export const useSimulationStore = create<SimulationStore>((set, get) => ({
}))
},

// Complete all tokens in an expert's batch
completeExpertBatch: expertId => {
const { tokens } = get()

const batchTokens = tokens.filter(
t => t.status === 'processing' && t.targetExperts.includes(expertId)
)

// Mark all tokens as complete
batchTokens.forEach(token => {
get().updateToken(token.id, { status: 'complete', ffnStage: 'output' })
})

// Deactivate expert
get().updateExpert(expertId, {
isActive: false,
batchStartTime: null,
batchProcessingTime: null,
loadCount: get().experts.find(e => e.id === expertId)!.loadCount + batchTokens.length
})

setTimeout(() => {
const completeTokens = get().tokens.filter(t => t.status === 'complete')

completeTokens.forEach(token => {
// Check if ALL of this token's experts are done processing
const allExpertsDone = token.targetExperts.every(expId => {
const expert = get().experts.find(e => e.id === expId)
return expert?.isActive === false
})

if (allExpertsDone) {
get().removeToken(token.id)
}
})
get().updateStats()
}, 1000)
},

// Playback controls
play: () => set({ isPlaying: true }),
pause: () => set({ isPlaying: false }),
Expand Down
4 changes: 3 additions & 1 deletion src/types/moe.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ export interface Position {
export interface Expert {
id: number
name: string
specialization: string // e.g., "Math", "Language", "Science"
specialization: string // e.g., "Grammar", "Noun", "Verb"
color: string // Hex color for visualization
position: Position // Where to draw it
loadCount: number // How many tokens this expert has processed
isActive: boolean // Currently processing?
batchStartTime: number | null // When the current batch started processing
batchProcessingTime: number | null // How long this batch will take (ms)
}

// Status of a token as it moves through the system
Expand Down
2 changes: 2 additions & 0 deletions src/utils/moeInitialization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ export function initializeExperts(count: number): Expert[] {
position: { x, y },
loadCount: 0,
isActive: false,
batchStartTime: null,
batchProcessingTime: null,
})
}

Expand Down