
A rigorous 2x3 factorial comparison of neural network architectures: KAN vs MLP feedforward layers combined with Transformer vs Mamba sequence models. Investigates whether KAN advantages stem from B-spline activations or network topology.


Mamba-KAN

A Rigorous Factorial Comparison of Neural Network Architectures

Python 3.9+ · PyTorch 2.0+ · License: MIT · CI · Code style: black

Documentation | Quick Start | Results | Citation


Overview

This project implements a comprehensive 2×3 factorial experiment comparing neural network architectures, investigating the interplay between feedforward components and sequence modeling approaches.

Research Question

Do Kolmogorov-Arnold Networks (KANs) outperform MLPs because of their learnable B-spline activation functions, or because of their unique network topology?

Following Wu et al. (2024), we isolate these effects by including MLP+B-spline baselines alongside Transformer and Mamba sequence models.
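Concretely, a "B-spline activation" is a weighted sum of B-spline basis functions whose weights are trained, replacing a fixed nonlinearity like ReLU or GELU. A minimal pure-Python sketch of the Cox-de Boor recursion behind this idea (illustrative only; the project uses the efficient-kan implementation, not this code):

```python
def bspline_basis(i, k, knots, x):
    """Cox-de Boor recursion for the degree-k B-spline basis function B_{i,k}(x)."""
    if k == 0:
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    total = 0.0
    denom = knots[i + k] - knots[i]
    if denom > 0:
        total += (x - knots[i]) / denom * bspline_basis(i, k - 1, knots, x)
    denom = knots[i + k + 1] - knots[i + 1]
    if denom > 0:
        total += (knots[i + k + 1] - x) / denom * bspline_basis(i + 1, k - 1, knots, x)
    return total


def spline_activation(x, coeffs, knots, degree=2):
    """A 'learnable activation': weighted sum of B-spline basis functions.
    In a KAN, coeffs are trained parameters (one spline per edge)."""
    return sum(c * bspline_basis(i, degree, knots, x)
               for i, c in enumerate(coeffs))


knots = list(range(8))               # uniform knot vector t_0..t_7
coeffs = [0.5, -1.0, 2.0, 0.3, 1.5]  # len(knots) - degree - 1 = 5 weights
y = spline_activation(3.5, coeffs, knots)
```

Inside the interior knot span the basis functions form a partition of unity, which is what keeps the learned activation well-scaled.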


Architecture

┌─────────────────────────────────────────────────────────────────┐
│                    2×3 FACTORIAL DESIGN                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   Feedforward Type          Sequence Model                      │
│   ════════════════         ══════════════                       │
│                                                                 │
│   ┌─────────────┐          ┌─────────────┐                      │
│   │  MLP        │──────────│ Transformer │ ─► mlp_transformer   │
│   │  (ReLU/GELU)│          │ (Attention) │                      │
│   └─────────────┘          └─────────────┘                      │
│         │                        │                              │
│         │                        │                              │
│   ┌─────────────┐          ┌─────────────┐                      │
│   │  MLP +      │──────────│ Transformer │ ─► bspline_transformer│
│   │  B-spline   │          │ (Attention) │                      │
│   └─────────────┘          └─────────────┘                      │
│         │                        │                              │
│         │                        │                              │
│   ┌─────────────┐          ┌─────────────┐                      │
│   │  Full KAN   │──────────│ Transformer │ ─► kan_transformer   │
│   │  (Learnable)│          │ (Attention) │                      │
│   └─────────────┘          └─────────────┘                      │
│         │                        │                              │
│         │                  ┌─────────────┐                      │
│         └──────────────────│   Mamba     │ ─► *_mamba variants  │
│                            │   (SSM)     │                      │
│                            └─────────────┘                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Six Model Variants

| Variant | Feedforward | Sequence | Purpose |
|---------|-------------|----------|---------|
| mlp_transformer | MLP (ReLU/GELU) | Attention | Baseline |
| bspline_transformer | MLP + B-spline | Attention | Isolate activation effect |
| kan_transformer | Full KAN | Attention | Full KAN architecture |
| mlp_mamba | MLP (ReLU/GELU) | SSM | Mamba baseline |
| bspline_mamba | MLP + B-spline | SSM | Activation + SSM |
| kan_mamba | Full KAN | SSM | Full KAN + SSM (novel) |
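Because the design is factorial, each variant name is just the concatenation of its two factors. A hypothetical helper illustrating the decomposition (the names mirror the table above; this is not the repository's API):

```python
# Factorial decomposition of variant names (illustrative; the real model
# classes live under mamba_kan/models/).
FEEDFORWARD = {"mlp": "MLP (ReLU/GELU)", "bspline": "MLP + B-spline", "kan": "Full KAN"}
SEQUENCE = {"transformer": "Attention", "mamba": "SSM"}


def parse_variant(name):
    """Split a name like 'kan_mamba' into its two factorial factors."""
    ffn, seq = name.rsplit("_", 1)
    if ffn not in FEEDFORWARD or seq not in SEQUENCE:
        raise ValueError(f"unknown variant: {name}")
    return FEEDFORWARD[ffn], SEQUENCE[seq]


variants = [f"{f}_{s}" for f in FEEDFORWARD for s in SEQUENCE]  # all six cells
```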

Key Results

60 experiments completed: 3 models × 2 tasks × 10 seeds, run on an NVIDIA H100 (80 GB)

Results Visualization

Experiment Results

Accuracy Comparison (mean ± std over 10 seeds; lower is better)

| Task | MLP Transformer | KAN Transformer | B-spline Transformer |
|------|-----------------|-----------------|----------------------|
| Symbolic Regression | 0.0077 ± 0.0023 | 0.0080 ± 0.0028 | 0.0082 ± 0.0038 |
| Language Modeling | 10.8366 ± 0.0007 | 10.8373 ± 0.0006 | 10.8363 ± 0.0017 |

Training Speed Comparison

| Model | Speed (steps/s) | Time per Experiment | Slowdown vs MLP |
|-------|-----------------|---------------------|-----------------|
| MLP Transformer | 92.4 | 26 s | 1.0× (baseline) |
| KAN Transformer | 52.0 | 50 s | 1.78× slower |
| B-spline Transformer | 19.6 | 633 s | 4.72× slower |

Key Findings

| Metric | MLP | KAN | B-spline |
|--------|-----|-----|----------|
| Accuracy | Best | ~Equal | ~Equal |
| Speed | Fastest | 1.78× slower | 4.72× slower |
| Recommendation | Use this | If interpretability needed | Not recommended |

Model Comparison

Conclusions

  1. All models perform similarly on accuracy - differences are within statistical noise
  2. MLP wins on speed - fastest training with best or equal accuracy
  3. KAN is practical with efficient-kan - only 1.78× slower (vs 60,000× with pykan)
  4. B-spline provides no benefit - slowest model without accuracy gains

Quick Start

Installation

# Clone the repository
git clone https://github.com/stchakwdev/Mamba_KAN.git
cd Mamba_KAN

# Create environment
conda create -n mamba_kan python=3.10
conda activate mamba_kan

# Install PyTorch (adjust CUDA version as needed)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install package with development dependencies
pip install -e ".[dev]"

Run Quick Validation

# Quick test (1 model, 1 seed, 100 steps)
make train

# Or directly:
python scripts/run_experiment.py \
    --model mlp_transformer \
    --task symbolic \
    --seeds 1 \
    --max-steps 100 \
    --no-wandb

Run Full Comparison

# Full factorial experiment (all models, 10 seeds)
python scripts/run_full_comparison.py \
    --tasks symbolic special_functions timeseries language long_context \
    --seeds 10 \
    --max-steps 10000 \
    --output-dir ./results \
    --generate-report

Benchmark Tasks

1. Symbolic Regression (KAN-favorable)

Tests function approximation with learnable activations:

  • Basic functions: sin, cos, exp, log, sqrt
  • Special functions: Bessel (J0, J1), Legendre (P2, P3, P4)
  • Deep compositions: sin(exp(cos(x))), nested trigonometrics
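A minimal sketch of how such a regression dataset could be sampled (the function names, ranges, and `make_dataset` helper are illustrative assumptions; the project's actual datasets live in mamba_kan/data/datasets.py):

```python
import math
import random

# Illustrative target family, including a deep composition.
TARGETS = {
    "sin": math.sin,
    "exp": math.exp,
    "sin_exp_cos": lambda x: math.sin(math.exp(math.cos(x))),
}


def make_dataset(fn_name, n=256, lo=-2.0, hi=2.0, seed=0):
    """Sample n (x, f(x)) pairs uniformly on [lo, hi]."""
    rng = random.Random(seed)
    f = TARGETS[fn_name]
    return [(x, f(x)) for x in (rng.uniform(lo, hi) for _ in range(n))]


data = make_dataset("sin_exp_cos")
```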

2. Time Series Forecasting

Tests temporal pattern recognition:

  • Patterns: sine with trend, multi-seasonal, chaotic, AR process
  • Sequence lengths: 128, 256, 512
  • Prediction horizons: 10, 20 steps
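The windowing behind these settings can be sketched as follows (`sine_with_trend` and `make_windows` are illustrative helpers under assumed defaults, not the project's API):

```python
import math


def sine_with_trend(n, period=50, slope=0.01, amp=1.0):
    """Synthetic series: linear trend plus one seasonal component."""
    return [slope * t + amp * math.sin(2 * math.pi * t / period) for t in range(n)]


def make_windows(series, seq_len=128, horizon=10):
    """Slice a series into (input window, forecast target) training pairs."""
    return [(series[i:i + seq_len], series[i + seq_len:i + seq_len + horizon])
            for i in range(len(series) - seq_len - horizon + 1)]


pairs = make_windows(sine_with_trend(300), seq_len=128, horizon=10)
```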

3. Language Modeling (Mamba-favorable)

Tests long-range dependency modeling:

  • Standard sequences: 256, 512 tokens
  • Long-context: 2048, 4096 tokens
  • Mamba's O(n) complexity provides significant advantage
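A back-of-envelope illustration of why that matters (constant factors ignored; `d_state=16` is an assumed typical SSM state size, not a measured value from this project):

```python
def attention_ops(n, d_model):
    """Self-attention is dominated by the n x n score matrix: ~n^2 * d ops."""
    return n * n * d_model


def ssm_ops(n, d_model, d_state=16):
    """A selective-SSM scan touches each token once: ~n * d * d_state ops."""
    return n * d_model * d_state


for n in (256, 512, 2048, 4096):
    ratio = attention_ops(n, 512) / ssm_ops(n, 512)
    print(f"n={n:5d}  attention/SSM cost ratio ~ {ratio:.0f}x")
```

Under these assumptions the cost ratio grows linearly with sequence length, which is why the long-context settings (2048 and 4096 tokens) are the Mamba-favorable regime.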

Project Structure

mamba_kan/
├── models/                    # Model implementations
│   ├── base.py               # Task-aware base class
│   ├── mlp_transformer.py    # MLP-Transformer (baseline)
│   ├── bspline_transformer.py # B-spline activation baseline
│   ├── kan_transformer.py    # Full KAN-Transformer
│   ├── *_mamba.py            # Mamba variants
│   └── components/
│       ├── bspline_mlp.py    # Learnable B-spline activation
│       ├── kan_layers.py     # KAN building blocks
│       ├── mamba_layers.py   # Mamba with B-spline support
│       └── transformer_layers.py
├── training/
│   ├── trainer.py            # PyTorch Lightning module
│   ├── scheduler.py          # Learning rate schedules
│   └── callbacks.py          # Training monitoring
├── analysis/
│   └── statistics.py         # Friedman, Wilcoxon, bootstrap CI
├── visualization/            # Plotting and dashboards
│   ├── plots.py              # Training curves, comparisons
│   ├── heatmaps.py           # Statistical visualizations
│   ├── animations.py         # GIF generation
│   └── dashboard.py          # Interactive HTML reports
├── data/
│   └── datasets.py           # All benchmark datasets
└── configs/
    └── base_config.py        # Configuration system

scripts/
├── run_experiment.py         # Single experiment runner
├── run_full_comparison.py    # Full factorial comparison
├── generate_assets.py        # Generate README visualizations
└── runpod_setup.sh          # Cloud GPU setup

Statistical Analysis

The project implements rigorous statistical testing following Demšar (2006):

  • Friedman Test: Non-parametric comparison across multiple classifiers
  • Wilcoxon Signed-Rank: Pairwise post-hoc comparisons
  • Holm-Bonferroni Correction: Multiple comparison adjustment
  • Bootstrap Confidence Intervals: Effect size uncertainty quantification

from mamba_kan.analysis import run_full_analysis, print_analysis_summary

results = {
    'mlp_transformer': [0.52, 0.51, 0.53, ...],  # Loss per seed
    'bspline_transformer': [0.48, 0.47, 0.49, ...],
    'kan_transformer': [0.45, 0.44, 0.46, ...],
    # ... other models
}

analysis = run_full_analysis(results)
print_analysis_summary(analysis)
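Of the four procedures, the Holm-Bonferroni step is simple enough to sketch directly (an illustrative helper, not the project's analysis API):

```python
def holm_bonferroni(pvals, alpha=0.05):
    """Holm step-down correction: test p-values in ascending order against
    alpha / (m - rank); once one comparison fails, all larger p-values fail."""
    m = len(pvals)
    order = sorted(range(m), key=lambda i: pvals[i])
    reject = [False] * m
    for rank, i in enumerate(order):
        if pvals[i] <= alpha / (m - rank):
            reject[i] = True
        else:
            break  # remaining (larger) p-values cannot be rejected
    return reject


# three pairwise comparisons: thresholds are 0.05/3, 0.05/2, 0.05/1
print(holm_bonferroni([0.01, 0.04, 0.03]))
```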

Hardware Requirements

| Configuration | Specification |
|---------------|---------------|
| Minimum | NVIDIA GPU, 8 GB VRAM, 16 GB RAM |
| Recommended | RTX 3080+ or A100, 32 GB RAM |
| Full experiments | H100, 80 GB VRAM |

Cloud Deployment

# RunPod H100 setup
chmod +x scripts/runpod_setup.sh
./scripts/runpod_setup.sh

# Run full experiment suite
python scripts/run_full_comparison.py --task all --seeds 10

Development

# Install dev dependencies
pip install -e ".[dev]"

# Run tests
make test

# Run linting
make lint

# Format code
make format

# Generate visualizations from results
make visualize


References

Resources

  • efficient-kan - Fast KAN implementation (used in this project, ~250× faster than pykan)
  • pykan - Official KAN implementation
  • mamba-ssm - Official Mamba implementation
  • awesome-kan - Comprehensive KAN resources

Citation

@misc{mamba_kan_2025,
    title={Mamba-KAN: A Factorial Comparison of Neural Network Architectures},
    author={Samuel T. Chakwera},
    year={2025},
    url={https://github.com/stchakwdev/Mamba_KAN},
    note={Investigating whether KAN advantages stem from B-spline activations or network topology}
}

License

MIT License - see LICENSE for details.



Made with PyTorch Lightning and scientific rigor
