Note: This file provides a high-level overview. For complete mathematical exposition with derivations, implementation details, and references, see the dedicated documents:
- `docs/EHAP.md`: EHAP (Fisher EMA, OBD/OBS/Normalized scoring, blockwise exact OBS with Woodbury inverse, gradient-covariance Hessian, iterative pruning, weight compensation)
- `docs/CORING.md`: CORING (N:M structured sparsity, optimal C(M,N) mask selection, iterative swap-refine, absolute-magnitude redistribution, Ampere 2:4 layout support)
Modern large language models (LLMs) like Llama 2, Mistral, and GPT-4 contain billions of parameters. During inference on consumer hardware (the target of Tensorbit Labs: 8–16 GB RAM devices), two bottlenecks dominate:
- Memory bandwidth — reading all weights from RAM/VRAM overwhelms the memory bus.
- Compute throughput — dense matrix multiplications waste cycles on near-zero weights that contribute negligibly to the output.
Structured sparsity addresses both simultaneously. By enforcing a hardware-friendly N:M pattern (e.g., 2:4 — exactly 2 non-zero values in every contiguous group of 4), the GPU can:
- Skip loading pruned weights entirely (2× bandwidth reduction for 2:4).
- Double matrix-multiply throughput via NVIDIA's Sparse Tensor Cores (Ampere/Hopper instruction `mma.sp`).
The challenge is which weights to prune. Random or naive magnitude-based pruning can remove "quiet but load-bearing" weights — parameters with small magnitudes whose removal disproportionately harms model accuracy. The Hessian-aware approach solves this.
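Given a loss function L(w; D) parameterized by weights w and evaluated on dataset D, the local geometry of the loss landscape around the current weight vector is characterized by the Hessian matrix:

$$H = \nabla_w^2 L(w; D), \qquad H_{ij} = \frac{\partial^2 L}{\partial w_i \, \partial w_j}$$

A second-order Taylor expansion reveals how much the loss changes when we perturb the weights by δw:

$$\delta L = \nabla L^\top \delta w + \tfrac{1}{2}\, \delta w^\top H \, \delta w + O(\lVert \delta w \rVert^3)$$

At a local minimum, ∇L = 0, and the loss change is dominated by the quadratic form:

$$\delta L \approx \tfrac{1}{2}\, \delta w^\top H \, \delta w$$

If we prune weight w_j (i.e., set δw_j = -w_j and leave every other weight unchanged), the loss increase is:

$$\delta L_j \approx \tfrac{1}{2}\, w_j^2 \, H_{jj}$$
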
This is the fundamental insight: weights with small w_j^2 · H_jj can be removed
with minimal impact on the loss function. The Hessian diagonal encodes per-weight
"load-bearing capacity."
Storing the full Hessian takes O(N^2) memory and inverting it takes O(N^3) time for an N-parameter model, which is infeasible for billion-parameter LLMs.
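The Empirical Fisher Information Matrix is an approximation that replaces the Hessian with the expected outer product of the loss gradients:

$$F = \mathbb{E}_{x \sim D}\left[\nabla_w L(w; x)\, \nabla_w L(w; x)^\top\right]$$

Under the assumption that the model's output distribution matches the true data distribution, the Fisher is asymptotically equivalent to the Hessian at the optimum. Crucially, we take only the diagonal, the expected squared gradient of each weight, which needs O(N) memory:

$$F_{ii} = \mathbb{E}_{x \sim D}\left[\left(\frac{\partial L(w; x)}{\partial w_i}\right)^2\right]$$

The EHAP (Efficient Hessian-Aware Pruning) sensitivity score for weight w_i is:

$$s_i = w_i^2 \,(F_{ii} + \lambda)$$
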
Where:
| Symbol | Meaning | Default |
|---|---|---|
| w_i | Weight magnitude | — |
| F_ii | Diagonal Fisher Information | Computed from gradients |
| λ (lambda) | Damping factor for numerical stability | 0.01 |
The damping term λ keeps the score from collapsing to zero when F_ii ≈ 0 (weights that receive near-zero gradients but must be kept for architectural reasons, like bias terms or embedding entries). Such weights are then ranked by λ · w_i^2, that is, by magnitude alone.
Weights with the lowest s_i are pruned.
During training or fine-tuning, Fisher information is accumulated incrementally:
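```
Initialize: F = zero vector of size N
For each batch:
    Compute gradient g = ∇L(w)
    Update: F[i] += α · g[i]^2    # α = 1 / accumulation_steps
```
As written, this update is a uniform average of squared gradients over accumulation_steps batches rather than an exponential one. An exponential moving average would instead decay the previous estimate: F[i] = (1 − α) · F[i] + α · g[i]^2.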
The implementation in `include/tensorbit/core/ehap.hpp` (`EHAPPruner::accumulate_fisher`) supports both:
- GPU path: Launches `fisher_accumulate_kernel` (1 thread per element, `fmaf` fused multiply-add for precision).
- CPU path: Element-wise loop with `__restrict__` annotations for autovectorization.
Once Fisher is accumulated, importance scores couple weights with curvature:
```
For each weight w[i]:
    if Fisher diagonal available:
        s[i] = w[i]^2 · (F[i] + damping)
    else (magnitude fallback):
        s[i] = w[i]^2
```
GPU: ehap_importance_kernel — one thread per weight, zero shared memory.
CPU: Plain loop, vectorizable.
Given a target keep ratio r (the fraction of weights to keep, so the resulting sparsity is 1 − r):
- Find the k-th smallest importance score, with k = (1 − r) · N, via `std::nth_element` (O(N) on average).
- All weights with score below this threshold are marked for pruning.
- Output: binary mask M where M[i] = 1 means "keep."
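The thresholding steps above can be sketched as follows. `global_keep_mask` is a hypothetical helper, and rounding k down is an assumption; scores that tie with the threshold are kept, so slightly more than r · N weights may survive.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns mask[i] = 1 (keep) or 0 (prune). The k lowest-scoring weights are
// pruned, with k = floor((1 - keep_ratio) * N) clamped to [0, N].
std::vector<std::uint8_t> global_keep_mask(const std::vector<float>& scores,
                                           double keep_ratio) {
    const std::size_t n = scores.size();
    const double raw = (1.0 - keep_ratio) * static_cast<double>(n);
    const std::size_t k =
        raw <= 0.0 ? std::size_t{0} : std::min(n, static_cast<std::size_t>(raw));

    std::vector<std::uint8_t> mask(n, 1);
    if (k == 0) return mask;  // nothing to prune
    if (k == n) {             // everything is pruned
        std::fill(mask.begin(), mask.end(), std::uint8_t{0});
        return mask;
    }

    // Partition a copy so that work[k] is the k-th smallest score (0-based),
    // i.e. the smallest score that survives.
    std::vector<float> work(scores);
    std::nth_element(work.begin(),
                     work.begin() + static_cast<std::ptrdiff_t>(k),
                     work.end());
    const float threshold = work[k];

    for (std::size_t i = 0; i < n; ++i) {
        mask[i] = scores[i] < threshold ? 0 : 1;
    }
    return mask;
}
```

Working on a copy keeps the original score order intact, which matters because the mask must line up index-for-index with the weights.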
Global (unstructured) sparsity — dropping the k least-important weights wherever they are — produces irregular data access patterns that GPUs cannot accelerate. Every pruned zero still occupies memory and must be skipped at runtime, incurring warp-divergence penalties.
N:M structured sparsity constrains the pattern: divide weights into contiguous groups of M, keep exactly N per group. This guarantees:
- Regular memory access — hardware can predict which elements are zero.
- No warp divergence — all threads in a warp follow the same index pattern.
- 2× throughput on A100/H100 when N=2, M=4 (supported by the `mma.sp` instruction).
Given importance scores s[0..N_elements-1] and N:M pattern:
```
For each group g in [0, N_elements/M):
    base = g · M
    Find top-N indices among s[base .. base+M-1]
    Emit mask byte: bit i = 1 if i is in top-N set
```
The 2:4 kernel (nm_mask_2_4_kernel) is optimized for Ampere's native instruction:
| Threads per block | Shared memory | Per-thread work | Occupancy bottleneck |
|---|---|---|---|
| 256 | 0 bytes | One group per thread | Register pressure (~18 regs) |
Each thread loads 4 importance values into registers, finds the top-2 indices via a fixed comparison tree (fully unrolled, no branches after compilation), and writes a packed byte. This achieves near-100% theoretical occupancy on SM80/SM90.
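To illustrate what a fixed comparison tree for top-2-of-4 can look like, here is a host-side C++ sketch. It is not the code of `nm_mask_2_4_kernel`; the function name and the tie-breaking rule (lower index wins) are assumptions. It uses four value comparisons arranged as a small tournament, written with ternaries that a compiler can typically lower to select instructions.

```cpp
#include <cstdint>

// Returns a 4-bit mask (bit i set = keep s[i]) marking the two largest of
// four scores. Ties are resolved in favor of the lower index.
inline std::uint8_t top2_of_4_mask(const float s[4]) {
    // Round 1: winner and loser inside each pair.
    const int hi01 = (s[1] > s[0]) ? 1 : 0;
    const int lo01 = 1 - hi01;
    const int hi23 = (s[3] > s[2]) ? 3 : 2;
    const int lo23 = 5 - hi23;

    // Round 2: overall winner between the two pair winners.
    const bool left_wins = !(s[hi23] > s[hi01]);
    const int first = left_wins ? hi01 : hi23;
    const int rival = left_wins ? hi23 : hi01;  // lost the final
    const int mate  = left_wins ? lo01 : lo23;  // lost to `first` in round 1

    // Round 3: the runner-up is either the final's loser or the winner's
    // pair-mate; the other pair's loser cannot beat `rival`.
    const bool rival_wins =
        (s[rival] > s[mate]) || (s[rival] == s[mate] && rival < mate);
    const int second = rival_wins ? rival : mate;

    return static_cast<std::uint8_t>((1u << first) | (1u << second));
}
```

For the scores {0.1, 0.9, 0.5, 0.3} the function returns 0b0110, keeping indices 1 and 2.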
The generic kernel (nm_mask_generic_kernel) handles arbitrary N:M patterns for M ≤ 32:
| Threads per block | Shared memory | Per-thread work | Time complexity |
|---|---|---|---|
| M (up to 32) | ~256 bytes (32×float + 32×int) | Rank computation | O(M^2) |
Algorithm:
- Each of M threads loads one importance value into `__shared__ float s_vals[M]`.
- Each thread counts how many elements have a strictly higher value; that count is its rank.
- Tie-breaking: equal values are resolved by lower index winning (deterministic).
- Thread 0 assembles the mask byte from ranks: if rank < N, bit is set.
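The rank computation can be emulated sequentially, with the outer loop standing in for the M threads. This is a sketch, not the kernel itself: `rank_select_mask` is a hypothetical name, and it returns a 32-bit word (an assumption) so that every group size up to M = 32 fits.

```cpp
#include <cstdint>
#include <vector>

// Sequential emulation of rank-based top-N selection within one group.
// rank(i) = number of elements that outrank group[i]; an element outranks
// another if it is strictly larger, or equal with a lower index.
// Assumes group.size() <= 32.
std::uint32_t rank_select_mask(const std::vector<float>& group, int n_keep) {
    const int m = static_cast<int>(group.size());
    std::uint32_t mask = 0;
    for (int i = 0; i < m; ++i) {  // one iteration per "thread"
        int rank = 0;
        for (int j = 0; j < m; ++j) {
            const bool higher = group[j] > group[i];
            const bool tie_lower_index = (group[j] == group[i]) && (j < i);
            rank += (higher || tie_lower_index) ? 1 : 0;
        }
        if (rank < n_keep) {
            mask |= (std::uint32_t{1} << i);
        }
    }
    return mask;
}
```

Because the tie-break makes all ranks distinct, exactly `n_keep` bits are set whenever `n_keep` does not exceed the group size.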
For non-GPU execution or double-precision tensors:
- Copy M elements per group into a vector of `std::pair<float, int>` (value, original index).
- Use `std::nth_element` to partition the top-N to the front (O(M) on average per group, and M is small).
- Assemble mask byte.
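A minimal sketch of this CPU fallback is shown below. The name `nm_masks_cpu` is hypothetical, and the sketch assumes M ≤ 8 (so one byte holds a group's mask) and a score count divisible by M.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Builds one mask byte per group of m scores, keeping the n highest-scoring
// entries of each group. Assumes n <= m <= 8 and scores.size() % m == 0.
std::vector<std::uint8_t> nm_masks_cpu(const std::vector<float>& scores,
                                       std::size_t n, std::size_t m) {
    using Entry = std::pair<float, int>;  // (value, original index)
    // Higher value first; equal values fall back to the lower index.
    auto better = [](const Entry& a, const Entry& b) {
        return a.first > b.first ||
               (a.first == b.first && a.second < b.second);
    };

    std::vector<std::uint8_t> masks(scores.size() / m, 0);
    std::vector<Entry> group(m);
    for (std::size_t g = 0; g < masks.size(); ++g) {
        for (std::size_t i = 0; i < m; ++i) {
            group[i] = Entry{scores[g * m + i], static_cast<int>(i)};
        }
        // After this call the first n entries are the top-n (in any order).
        std::nth_element(group.begin(),
                         group.begin() + static_cast<std::ptrdiff_t>(n),
                         group.end(), better);
        for (std::size_t i = 0; i < n; ++i) {
            masks[g] |= static_cast<std::uint8_t>(1u << group[i].second);
        }
    }
    return masks;
}
```

Carrying the original index in the pair is what allows the mask bits to be set correctly after `std::nth_element` has reordered the group.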
The mask is applied element-by-element:
```
For each weight w[i]:
    group  = i / M
    offset = i % M
    if mask_byte[group] bit offset == 0:
        w[i] = 0
```
GPU: apply_mask_kernel — 1 thread per element, one division + one bit test + one
conditional store per thread. Divergence is bounded because both paths (keep vs prune)
are trivial.
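On the CPU, the same per-element logic reduces to the loop below. It is a sketch with a hypothetical function name, assuming one mask byte per group (so M ≤ 8) and a weight count divisible by M.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Zeroes every weight whose bit in its group's mask byte is 0.
// Assumes mask_bytes.size() == w.size() / m and m <= 8.
void apply_nm_mask(std::vector<float>& w,
                   const std::vector<std::uint8_t>& mask_bytes,
                   std::size_t m) {
    for (std::size_t i = 0; i < w.size(); ++i) {
        const std::size_t group = i / m;
        const std::size_t offset = i % m;
        const bool keep = ((mask_bytes[group] >> offset) & 1u) != 0;
        if (!keep) {
            w[i] = 0.0f;
        }
    }
}
```

With masks {0b0110, 0b0011} and M = 4, the weights {1, 2, 3, 4, 5, 6, 7, 8} become {0, 2, 3, 0, 5, 6, 0, 0}.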
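The number of kept (and therefore pruned) weights is computed analytically, with no runtime overhead:

$$\text{kept} = \frac{N_\text{elements}}{M} \cdot N, \qquad \text{pruned} = N_\text{elements} - \text{kept}$$

This is exact because `validate_config` ensures the tensor size is divisible by M.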
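As a small illustration of the closed-form count, the sketch below computes the number of kept weights and rejects sizes that are not divisible by M. The check only stands in for what `validate_config` is described as guaranteeing; the function name and the exception type are assumptions.

```cpp
#include <cstddef>
#include <stdexcept>

// Number of weights kept by an N:M pattern: (n_elements / m) * n.
// Throws if the pattern is invalid or the size is not a multiple of m.
std::size_t kept_count(std::size_t n_elements, std::size_t n, std::size_t m) {
    if (m == 0 || n > m || n_elements % m != 0) {
        throw std::invalid_argument("invalid N:M pattern or tensor size");
    }
    return (n_elements / m) * n;
}
```

For 16 elements with a 2:4 pattern this gives 8 kept weights, and the pruned count follows as 16 − 8 = 8.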
```
.safetensors ──► [EHAP] ──────► [CORING] ──► .tb file
(dense weights)     │               │        (pruned + masks)
                    │               │
             Fisher diagonal   N:M bitmask
             (O(N) memory)     (N/M bytes)
```
- Monotonicity: If w_a has higher importance than w_b, CORING will never prefer w_b over w_a within the same group. The superposition of EHAP importance and CORING structural constraints is monotonic.
- Sparsity Guarantee: After the pipeline, every contiguous group of M weights contains exactly N non-zero values. The ratio N/M is the fraction of weights kept, so the structural sparsity of the model is 1 − N/M (50% for 2:4).
- Memory Footprint: During pruning, peak memory is O(N) for weights + O(N) for Fisher diagonal + O(N/M) for masks. For a 7B parameter model:
  - FP32 weights: 28 GB
  - FP32 Fisher: 28 GB (during pruning only)
  - 2:4 mask: 1.75 GB
  - Total peak: ~58 GB (fits on a single A100-80GB)
- Numerical Stability: The damping term λ prevents the importance score from collapsing to zero when Fisher information is near-zero (common in embedding layers or frozen parameters).
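The memory figures for the 7B example can be reproduced with a few lines of arithmetic. The sketch below assumes 4-byte values for both weights and Fisher entries, one mask byte per group of M weights, and GB meaning 10^9 bytes; the struct and function names are illustrative.

```cpp
#include <cstdint>

// Byte counts for the three buffers held during pruning.
struct PruningFootprint {
    std::uint64_t weights_bytes;
    std::uint64_t fisher_bytes;
    std::uint64_t mask_bytes;
    std::uint64_t total() const {
        return weights_bytes + fisher_bytes + mask_bytes;
    }
};

// Assumes one mask byte per group of m weights and n_params divisible by m.
PruningFootprint pruning_footprint(std::uint64_t n_params, std::uint64_t m,
                                   std::uint64_t bytes_per_value = 4) {
    return {n_params * bytes_per_value,  // weights
            n_params * bytes_per_value,  // Fisher diagonal
            n_params / m};               // masks
}
```

Calling `pruning_footprint(7000000000ULL, 4)` gives 28 GB + 28 GB + 1.75 GB = 57.75 GB, which rounds to the ~58 GB quoted above. Gradients and activations needed to compute the Fisher estimate are not included in this count.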
Consider two weights in the same 2:4 group:
| Weight | \|w\| | F_ii | w^2 · F_ii | Magnitude rank | EHAP rank |
|---|---|---|---|---|---|
| w_a | 0.05 | 100.0 | 0.25 | 2nd (pruned) | 1st (kept) |
| w_b | 0.10 | 1.0 | 0.01 | 1st (kept) | 2nd (pruned) |
Magnitude pruning keeps w_b (larger |w|) and prunes w_a. But w_a's high Fisher value (100.0) indicates that small changes to w_a cause large changes in loss — it is load-bearing. Pruning it would tank accuracy.
EHAP correctly identifies w_a as critical despite its small magnitude. This is the essence of Hessian-aware pruning: coupling size (magnitude) with sensitivity (curvature) to make informed decisions.
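As a worked check, the table's score column omits the damping term. Plugging the default λ = 0.01 into the damped score s_i = w_i^2 · (F_ii + λ) changes the numbers only slightly and leaves the ranking intact:

$$
\begin{aligned}
s_a &= 0.05^2 \cdot (100.0 + 0.01) = 0.250025 \\
s_b &= 0.10^2 \cdot (1.0 + 0.01) = 0.0101
\end{aligned}
$$

So w_a still scores roughly 25 times higher than w_b, even though its magnitude is half as large.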
| Kernel | Arithmetic Intensity | Bound By | Theoretical Throughput |
|---|---|---|---|
| fisher_accumulate | 2 FLOP/element | Memory (HBM2e, 2 TB/s) | 500M elements/ms |
| ehap_importance | 3 FLOP/element | Memory | 333M elements/ms |
| nm_mask_2_4 | 4 FLOP/element | Compute (312 TFLOPS) | 78B elements/s |
| nm_mask_generic | O(M^2) FLOP/element | Compute | M=4: 1.3B elements/s |
| apply_mask | 0 FLOP/element (pure I/O) | Memory | 500M elements/ms |
By these estimates, the mask kernels are compute-bound on A100 rather than memory-bound, so they do not compete with weight I/O for HBM bandwidth during the pruning phase.
| Kernel | Shared Memory/Block | Max Blocks/SM (A100, 164 KB L1/SHMEM) |
|---|---|---|
| nm_mask_generic | 256 bytes (s_vals[32] + s_ranks[32]) | 256+ (not a limiting factor) |
| All others | 0 bytes | Limited by registers/threadblocks only |
- LeCun, Denker, & Solla (1990). "Optimal Brain Damage." Classical OBD framework using diagonal Hessian.
- Hassibi & Stork (1993). "Optimal Brain Surgeon." Full Hessian inverse for pruning. Impractical for LLMs but the theoretical foundation.
- Theis et al. (2018). "Faster Gaze Prediction with Dense Networks and Fisher Pruning." Introduced diagonal Fisher for modern CNNs.
- NVIDIA (2021). "Accelerating Inference with Sparsity Using the NVIDIA Ampere Architecture." 2:4 sparse tensor core specification.
- Mishra et al. (2021). "Accelerating Sparse Deep Neural Networks." 2:4 fine-grained structured sparsity and its training workflow (hardware motivation for CORING).