diff --git a/src/components/layout/LandingPage.tsx b/src/components/layout/LandingPage.tsx index adffcf7..fa77e7e 100644 --- a/src/components/layout/LandingPage.tsx +++ b/src/components/layout/LandingPage.tsx @@ -12,6 +12,7 @@ function LandingPage() { MoE Visualizer diff --git a/src/components/layout/VisualizerPage.module.css b/src/components/layout/VisualizerPage.module.css index e9bb2dd..4229faa 100644 --- a/src/components/layout/VisualizerPage.module.css +++ b/src/components/layout/VisualizerPage.module.css @@ -48,6 +48,7 @@ .nav { display: flex; gap: var(--spacing-lg); + align-items: center; } .nav a { @@ -61,6 +62,30 @@ color: var(--color-text); } +.metricsButton { + padding: var(--spacing-xs) var(--spacing-md); + background: linear-gradient(135deg, var(--color-primary), var(--color-secondary)); + color: white; + font-size: 0.875rem; + font-weight: 600; + border: none; + border-radius: var(--radius-md); + cursor: pointer; + transition: all var(--transition-fast); + display: flex; + align-items: center; + gap: var(--spacing-xs); +} + +.metricsButton:hover { + transform: translateY(-2px); + box-shadow: var(--shadow-glow); +} + +.metricsButton:active { + transform: translateY(0); +} + /* Main Content */ .main { flex: 1; diff --git a/src/components/layout/VisualizerPage.tsx b/src/components/layout/VisualizerPage.tsx index f647d13..4ad1198 100644 --- a/src/components/layout/VisualizerPage.tsx +++ b/src/components/layout/VisualizerPage.tsx @@ -1,12 +1,17 @@ +import { useState } from 'react' import { Link } from 'react-router-dom' import { useMoeStore } from '../../store/moeStore' import { useSimulationStore } from '../../store/simulationStore' import ExpertNetwork from '../visualizers/ExpertNetwork' import AnimationPanel from '../visualizers/AnimationPanel' import StatusLegend from '../common/StatusLegend' +import { MetricsPanel } from '../visualizers/MetricsPanel' import styles from './VisualizerPage.module.css' function VisualizerPage() { + // Metrics panel state + const [isMetricsPanelOpen, setIsMetricsPanelOpen] = useState(false) + // Get values and setters from the store const numExperts = useMoeStore(state => state.numExperts) const topK = useMoeStore(state => state.topK) @@ -33,6 +38,13 @@ function VisualizerPage() { @@ -117,6 +129,12 @@ function VisualizerPage() { + + {/* Metrics Sidebar */} + setIsMetricsPanelOpen(false)} + /> ) } diff --git a/src/components/visualizers/AnimationPanel.tsx b/src/components/visualizers/AnimationPanel.tsx index 7299e1f..fa740b2 100644 --- a/src/components/visualizers/AnimationPanel.tsx +++ b/src/components/visualizers/AnimationPanel.tsx @@ -11,7 +11,7 @@ function AnimationPanel() { const topK = useMoeStore(state => state.topK) const addToken = useSimulationStore(state => state.addToken) - const MAX_TOKENS = 20 + const MAX_TOKENS = 50 const [input, setInput] = useState('') const setAnimationState = useSimulationStore(state => state.setAnimationState) diff --git a/src/components/visualizers/ExpertNetwork.tsx b/src/components/visualizers/ExpertNetwork.tsx index 6466d6b..43d55e6 100644 --- a/src/components/visualizers/ExpertNetwork.tsx +++ b/src/components/visualizers/ExpertNetwork.tsx @@ -46,7 +46,7 @@ function ExpertNetwork() { if (!expert) return null const weight = token.routingWeights[index] - const strokeWidth = 1 + weight * 3 // 1-4px based on weight + const strokeWidth = 1 + weight * 2 // Add curve based on index to separate overlapping lines const curveOffset = (index - (token.targetExperts.length - 1) / 2) * 15 diff --git a/src/components/visualizers/MetricsPanel.module.css b/src/components/visualizers/MetricsPanel.module.css new file mode 100644 index 0000000..944aa17 --- /dev/null +++ b/src/components/visualizers/MetricsPanel.module.css @@ -0,0 +1,212 @@ +.backdrop { + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: rgba(0, 0, 0, 0.5); + backdrop-filter: blur(4px); + z-index: 1000; + animation: fadeIn 0.3s ease-out; +} + +@keyframes fadeIn { + from { + opacity: 0; + } + to { + opacity: 1; + } +} + +.panel { + position: fixed; + top: 0; + right: 0; + bottom: 0; + width: 400px; + max-width: 90vw; + background: var(--color-surface); + border-left: 1px solid var(--color-surface-light); + box-shadow: var(--shadow-xl); + z-index: 1001; + display: flex; + flex-direction: column; + transform: translateX(100%); + transition: transform 0.3s ease-out; +} + +.panel.open { + transform: translateX(0); +} + +.header { + display: flex; + justify-content: space-between; + align-items: center; + padding: var(--spacing-lg); + border-bottom: 1px solid var(--color-surface-light); + background: var(--color-background); +} + +.header h2 { + margin: 0; + font-size: 1.5rem; + color: var(--color-text); + font-weight: 600; +} + +.closeButton { + width: 32px; + height: 32px; + border-radius: var(--radius-sm); + border: none; + background: var(--color-surface-light); + color: var(--color-text); + font-size: 1.5rem; + line-height: 1; + cursor: pointer; + transition: all var(--transition-fast); + display: flex; + align-items: center; + justify-content: center; + flex-shrink: 0; +} + +.closeButton:hover { + background: var(--color-surface); + color: var(--color-primary); + transform: scale(1.1); +} + +.content { + flex: 1; + padding: var(--spacing-xl) var(--spacing-lg); + overflow-y: auto; +} + +/* Sections */ +.section { + margin-bottom: var(--spacing-2xl); +} + +.sectionTitle { + font-size: 1.125rem; + font-weight: 600; + color: var(--color-text); + margin: 0 0 var(--spacing-md) 0; + padding-bottom: var(--spacing-xs); + border-bottom: 2px solid var(--color-surface-light); +} + +/* Metric Cards */ +.metricCard { + background: var(--color-background); + border: 1px solid var(--color-surface-light); + border-radius: var(--radius-md); + padding: var(--spacing-md); + margin-bottom: var(--spacing-md); +} + +.metricHeader { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: var(--spacing-xs); +} + +.metricLabel { + font-size: 1.2rem; + font-weight: 600; + color: var(--color-text-secondary); +} + +.metricValue { + font-size: 1.25rem; + font-weight: 700; + color: var(--color-text); + font-family: monospace; +} + +.metricValueBadge { + font-size: 1.25rem; + font-weight: 700; + color: white; + font-family: monospace; + padding: var(--spacing-xs) var(--spacing-md); + border-radius: var(--radius-md); + transition: background-color 0.3s ease; +} + +.metricDescription { + font-size: 0.8rem; + color: var(--color-text-secondary); + font-style: italic; +} + +.metricDescription sub { + font-size: 0.65rem; +} + +/* Expert Utilization Bars */ +.expertBars { + display: flex; + flex-direction: column; + gap: var(--spacing-sm); +} + +.expertBar { + display: flex; + flex-direction: column; + gap: 4px; +} + +.expertBarHeader { + display: flex; + align-items: center; + gap: var(--spacing-xs); + font-size: 0.875rem; +} + +.expertDot { + width: 10px; + height: 10px; + border-radius: 50%; + flex-shrink: 0; +} + +.expertName { + color: var(--color-text-secondary); + font-weight: 500; + flex: 1; +} + +.expertValue { + color: var(--color-text); + font-weight: 700; + font-family: monospace; + font-size: 0.8rem; +} + +.expertBarTrack { + width: 100%; + height: 6px; + background: var(--color-surface-light); + border-radius: 3px; + overflow: hidden; +} + +.expertBarFill { + height: 100%; + transition: width 0.3s ease; + border-radius: 3px; +} + +/* Mobile responsiveness */ +@media (max-width: 768px) { + .panel { + width: 100%; + max-width: 100vw; + } +} + diff --git a/src/components/visualizers/MetricsPanel.tsx b/src/components/visualizers/MetricsPanel.tsx new file mode 100644 index 0000000..0252eff --- /dev/null +++ b/src/components/visualizers/MetricsPanel.tsx @@ -0,0 +1,121 @@ +import React from 'react'; +import { useSimulationStore } from '../../store/simulationStore'; +import styles from './MetricsPanel.module.css'; + +interface MetricsPanelProps { + isOpen: boolean; + onClose: () => void; +} + +export const MetricsPanel: React.FC = ({ isOpen, onClose }) => { + const stats = useSimulationStore(state => state.stats); + const experts = useSimulationStore(state => state.experts); + + const getAuxLossStatus = (value: number) => { + if (value < 1.2) return { color: '#10b981', label: 'Excellent' }; + if (value < 1.5) return { color: '#f59e0b', label: 'Moderate' }; + return { color: '#ef4444', label: 'Poor' }; + }; + + const getImbalanceStatus = (value: number) => { + if (value < 0.3) return { color: '#10b981', label: 'Excellent' }; + if (value < 0.6) return { color: '#f59e0b', label: 'Moderate' }; + return { color: '#ef4444', label: 'Poor' }; + }; + + return ( + <> + {/* Backdrop */} + {isOpen && ( +
+ )} + + {/* Sidebar Panel */} +
+
+

Metrics

+ +
+ +
+ {/* Load Balancing Section */} +
+

Load Balancing

+ + {/* Auxiliary Loss */} +
+
+ Auxiliary Loss + + {stats.auxiliaryLoss.toFixed(4)} + +
+
+ N × Σ(fe × Pe) - Routing-dispatch mismatch. Ideal = 1.0 +
+
+ + {/* Load Imbalance Factor */} +
+
+ Load Imbalance Factor + + {stats.loadImbalanceFactor.toFixed(3)} + +
+
+ Coefficient of Variation (CV) = σ / μ. Ideal = 0.0 +
+
+
+ + {/* Expert Utilization Section */} +
+

Expert Utilization %

+
+ {experts.map((expert, idx) => ( +
+
+ + E{expert.id} + + {stats.expertUtilization[idx]?.toFixed(1) || 0}% + +
+
+
+
+
+ ))} +
+
+
+
+ + ); +}; + diff --git a/src/store/simulationStore.ts b/src/store/simulationStore.ts index 4c7df35..62f5716 100644 --- a/src/store/simulationStore.ts +++ b/src/store/simulationStore.ts @@ -62,6 +62,10 @@ const initialStats: MoeStats = { maxExpertLoad: 0, minExpertLoad: 0, isBalanced: true, + auxiliaryLoss: 0, + loadImbalanceFactor: 0, + expertUtilization: [], + tokensPerExpert: [], } const initialAnimationState: AnimationState = { @@ -131,7 +135,7 @@ export const useSimulationStore = create((set, get) => ({ const existingPositions = tokens.map(t => t.position) const centerX = 450 const centerY = 325 - const maxRadius = Math.min(300, 100 + tokens.length * 5) + const maxRadius = Math.min(250, 100 + tokens.length * 5) const minDistance = 32 // Reduced from 40 to fit smaller circles let position = { x: centerX, y: centerY } @@ -178,11 +182,15 @@ export const useSimulationStore = create((set, get) => ({ expertId, weight: newToken.routingWeights[index], timestamp: Date.now(), + gatingProbabilities: newToken.gatingProbabilities, }) }) set({ tokens: [...tokens, newToken] }) + // Update stats after adding token + get().updateStats() + const routingDelay = 800 // Time for lines to draw setTimeout(() => { // Check if token still exists @@ -277,20 +285,27 @@ export const useSimulationStore = create((set, get) => ({ const { tokens } = get() const batchTokens = tokens.filter( - t => t.status === 'processing' && t.targetExperts.includes(expertId) + t => (t.status === 'routing' || t.status === 'processing') && + t.targetExperts.includes(expertId) ) - // Mark all tokens as complete - batchTokens.forEach(token => { - get().updateToken(token.id, { status: 'complete', ffnStage: 'output' }) - }) - - // Deactivate expert + const currentLoad = get().experts.find(e => e.id === expertId)!.loadCount get().updateExpert(expertId, { isActive: false, batchStartTime: null, batchProcessingTime: null, - loadCount: get().experts.find(e => e.id === expertId)!.loadCount + batchTokens.length + loadCount: currentLoad + batchTokens.length + }) + + batchTokens.forEach(token => { + const allExpertsDone = token.targetExperts.every(expId => { + const expert = get().experts.find(e => e.id === expId) + return expert && (!expert.isActive || expId === expertId) + }) + + if (allExpertsDone) { + get().updateToken(token.id, { status: 'complete', ffnStage: 'output' }) + } }) setTimeout(() => { @@ -338,6 +353,47 @@ export const useSimulationStore = create((set, get) => ({ expert => Math.abs(expert.loadCount - avgLoad) <= avgLoad * 0.2 ) + // Calculate auxiliary loss: L_aux = N * Σ(f_e * P_e) + // where: + // - N = number of experts + // - f_e = fraction of tokens dispatched to expert e + // - P_e = average gating probability for expert e + let auxiliaryLoss = 0 + if (totalLoad > 0 && routingHistory.length > 0) { + const avgGatingProbs = new Array(experts.length).fill(0) + const tokenGatingProbs = new Map() + routingHistory.forEach(decision => { + if (!tokenGatingProbs.has(decision.tokenId) && decision.gatingProbabilities) { + tokenGatingProbs.set(decision.tokenId, decision.gatingProbabilities) + } + }) + tokenGatingProbs.forEach(probs => { + probs.forEach((prob, expertIdx) => { + avgGatingProbs[expertIdx] += prob + }) + }) + const numTokens = tokenGatingProbs.size + avgGatingProbs.forEach((sum, idx) => { + avgGatingProbs[idx] = sum / numTokens + }) + + const dispatchFractions = loads.map(load => load / totalLoad) + const dotProduct = dispatchFractions.reduce((sum, f_e, expertIdx) => { + const P_e = avgGatingProbs[expertIdx] + return sum + (f_e * P_e) + }, 0) + auxiliaryLoss = experts.length * dotProduct + } + + // CV = (standard deviation) / mean + const variance = loads.reduce((sum, load) => sum + Math.pow(load - avgLoad, 2), 0) / experts.length + const stdDev = Math.sqrt(variance) + const loadImbalanceFactor = avgLoad > 0 ? stdDev / avgLoad : 0 + const expertUtilization = totalLoad > 0 + ? loads.map(load => (load / totalLoad) * 100) + : loads.map(() => 0) + const tokensPerExpert = [...loads] + set({ stats: { totalTokensProcessed: routingHistory.length, @@ -345,6 +401,10 @@ export const useSimulationStore = create((set, get) => ({ maxExpertLoad: maxLoad, minExpertLoad: minLoad, isBalanced, + auxiliaryLoss, + loadImbalanceFactor, + expertUtilization, + tokensPerExpert, }, }) }, diff --git a/src/types/moe.types.ts b/src/types/moe.types.ts index 6f7f45b..04f9e1a 100644 --- a/src/types/moe.types.ts +++ b/src/types/moe.types.ts @@ -8,13 +8,13 @@ export interface Position { export interface Expert { id: number name: string - specialization: string // e.g., "Grammar", "Noun", "Verb" - color: string // Hex color for visualization - position: Position // Where to draw it - loadCount: number // How many tokens this expert has processed - isActive: boolean // Currently processing? - batchStartTime: number | null // When the current batch started processing - batchProcessingTime: number | null // How long this batch will take (ms) + specialization: string + color: string + position: Position + loadCount: number + isActive: boolean + batchStartTime: number | null + batchProcessingTime: number | null } // Status of a token as it moves through the system @@ -26,22 +26,24 @@ export type FFNStage = 'input' | 'ffn1' | 'relu' | 'ffn2' | 'output' | null // Represents a token (input) being processed export interface Token { id: string - content: string // What the token represents - position: Position // Current position for animation - targetExperts: number[] // IDs of experts this token routes to - routingWeights: number[] // Weight for each target expert (sums to 1.0) + content: string + position: Position + targetExperts: number[] + routingWeights: number[] + gatingProbabilities: number[] // Softmax probabilities for ALL experts (for aux loss) status: TokenStatus - timestamp: number // When it was created - ffnStage?: FFNStage // Current FFN processing stage (if processing) - processingExpertId?: number // Which expert is currently processing this token + timestamp: number + ffnStage?: FFNStage + processingExpertId?: number } // A routing decision made by the gating network export interface RoutingDecision { tokenId: string expertId: number - weight: number // 0.0 - 1.0 + weight: number timestamp: number + gatingProbabilities: number[] } // Statistics about the MoE system @@ -50,7 +52,11 @@ export interface MoeStats { avgExpertUtilization: number maxExpertLoad: number minExpertLoad: number - isBalanced: boolean // Whether load is evenly distributed + isBalanced: boolean + auxiliaryLoss: number + loadImbalanceFactor: number + expertUtilization: number[] + tokensPerExpert: number[] } // Animation step for educational visualization @@ -65,9 +71,9 @@ export type AnimationStep = export interface AnimationState { currentStep: AnimationStep currentTokenIndex: number - expertScores: number[] // Scores for first token (histogram display) - selectedExperts: number[] // Top-K expert IDs for first token (histogram display) - allSelectedExperts: number[] // All unique experts across all tokens (main view) + expertScores: number[] + selectedExperts: number[] + allSelectedExperts: number[] isPlaying: boolean } diff --git a/src/utils/moeInitialization.ts b/src/utils/moeInitialization.ts index d3f6bd0..5e864ad 100644 --- a/src/utils/moeInitialization.ts +++ b/src/utils/moeInitialization.ts @@ -41,6 +41,7 @@ export function generateToken(id: string): Token { position: { x: 450, y: 325 }, // Start at center targetExperts: [], routingWeights: [], + gatingProbabilities: [], // Will be populated during routing status: 'idle', timestamp: Date.now(), } diff --git a/src/utils/routing.ts b/src/utils/routing.ts index 0690630..4228411 100644 --- a/src/utils/routing.ts +++ b/src/utils/routing.ts @@ -38,6 +38,17 @@ export function selectTopK(scores: number[], k: number): number[] { return topK.map(pair => pair.index) } +/** + * Compute softmax probabilities from raw scores + * This gives us the gating probabilities (P) for auxiliary loss + */ +export function softmax(scores: number[]): number[] { + const maxScore = Math.max(...scores) + const expScores = scores.map(score => Math.exp(score - maxScore)) + const sumExp = expScores.reduce((acc, val) => acc + val, 0) + return expScores.map(exp => exp / sumExp) +} + /** * Normalize weights so they sum to 1.0 */ @@ -62,13 +73,16 @@ export function normalizeWeights(scores: number[], selectedIndices: number[]): n * Updates the token with target experts and weights */ export function routeToken(token: Token, experts: Expert[], topK: number): Token { - // Compute scores + // Compute raw scores const scores = computeGatingScores(token, experts) + // Compute softmax probabilities (P_e in auxiliary loss formula) + const gatingProbabilities = softmax(scores) + // Select top K experts const targetExperts = selectTopK(scores, topK) - // Normalize weights + // Normalize weights for selected experts const routingWeights = normalizeWeights(scores, targetExperts) // Update token @@ -76,6 +90,7 @@ export function routeToken(token: Token, experts: Expert[], topK: number): Token ...token, targetExperts, routingWeights, + gatingProbabilities, // Store full softmax for auxiliary loss calculation status: 'routing', } }