JAX · Pallas · XLA
XLA-Fused Sparse Attention Kernel Engine
XSAKE implements block-sparse attention on JAX + Pallas with HADS (Head-Adaptive Dynamic Sparsity), a novel algorithm that gives each attention head its own sparsity ratio based on the Shannon entropy of its attention, measured in a calibration pass. On an NVIDIA T4 it cuts latency by 42.2% at seq=4096. It also reaches the lowest KL divergence from dense attention among the sparse baselines benchmarked.
Key Results · NVIDIA T4 (Kaggle)
42.2% · Latency reduction (seq=4096) · 713 ms → 412 ms
59.9% · HBM memory savings (seq=4096) · 1207 MB → 484 MB
0.089 · KL divergence vs dense · best among sparse methods
+1.8% · Perplexity gap vs dense · after 10k training steps on OpenWebText
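The headline percentages follow directly from the raw T4 measurements, using relative reduction = (before - after) / before. Here is a minimal check in Python. The helper name `reduction_pct` is illustrative and not part of XSAKE.

```python
def reduction_pct(before: float, after: float) -> float:
    """Percentage reduction from `before` to `after`, rounded to one decimal."""
    return round(100.0 * (before - after) / before, 1)


latency_ms = (713.0, 412.0)  # before -> after, seq=4096, as reported above
hbm_mb = (1207.0, 484.0)     # before -> after, seq=4096, as reported above

print(reduction_pct(*latency_ms))  # 42.2
print(reduction_pct(*hbm_mb))      # 59.9
```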
What is HADS?
Standard sparse patterns (BigBird, Longformer) apply the same block mask to every attention head. But attention heads specialize: syntactic heads focus narrowly, while semantic heads attend broadly.
HADS measures each head's Shannon entropy during a calibration pass and assigns per-head sparsity ratios. Focused (low-entropy) heads skip 80–90% of KV blocks, while broad (high-entropy) heads skip only 10–20%.
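Below is a minimal NumPy sketch of the entropy-to-sparsity step. It assumes calibration attention probabilities of shape (heads, q_len, k_len), entropy normalized by log(k_len), and a linear map onto the 10–90% range quoted above. The function names and the linear map are illustrative assumptions, not the code in `hads_pattern.py`. A JAX version would swap `numpy` for `jax.numpy`.

```python
import numpy as np


def head_entropy(attn: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Mean Shannon entropy (nats) per head.

    attn: calibration attention probabilities, shape (heads, q_len, k_len),
          with each row summing to 1.
    """
    ent = -np.sum(attn * np.log(attn + eps), axis=-1)  # (heads, q_len)
    return ent.mean(axis=-1)                             # (heads,)


def entropy_to_sparsity(ent: np.ndarray, k_len: int,
                        s_min: float = 0.1, s_max: float = 0.9) -> np.ndarray:
    """Hypothetical linear map: low entropy -> high sparsity (fraction of KV blocks skipped)."""
    h = np.clip(ent / np.log(k_len), 0.0, 1.0)  # normalized entropy in [0, 1]
    return s_max - (s_max - s_min) * h


k = 64
focused = np.eye(k)[None]                  # every query attends to exactly one key
uniform = np.full((1, k, k), 1.0 / k)      # every query attends to all keys equally
attn = np.concatenate([focused, uniform])  # (2, 64, 64)
print(entropy_to_sparsity(head_entropy(attn), k))  # approximately [0.9, 0.1]
```

A linear map is the simplest monotone choice. Any decreasing function of normalized entropy would preserve the "focused heads skip more" behavior described above.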
# entropy → sparsity (12 heads; H0–H5 shown)
H0  62%
H1  82%  ← syntactic
H2  18%  ← semantic
H3  71%
H4  55%
H5  44%
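One way to turn a per-head ratio into a block mask is sketched below. It makes two hypothetical assumptions: each KV block is scored by the calibration attention mass it receives, and every query block keeps its top (1 - sparsity) fraction of KV blocks. The real selection rule in `hads_pattern.py` may differ, and `block_mask` is an illustrative name.

```python
import numpy as np


def block_mask(attn: np.ndarray, sparsity: float, block: int) -> np.ndarray:
    """Hypothetical per-head block mask from one head's calibration attention.

    attn: (q_len, k_len) attention probabilities for one head.
    Returns a bool (n_q_blocks, n_k_blocks) mask; True = KV block is computed.
    """
    q_len, k_len = attn.shape
    nq, nk = q_len // block, k_len // block
    # Total attention mass in each (query block, KV block) tile.
    tiles = attn[: nq * block, : nk * block].reshape(nq, block, nk, block)
    mass = tiles.sum(axis=(1, 3))                     # (nq, nk)
    keep = max(1, int(round((1.0 - sparsity) * nk)))  # KV blocks kept per query block
    top = np.argsort(-mass, axis=-1)[:, :keep]        # highest-mass KV blocks
    mask = np.zeros((nq, nk), dtype=bool)
    np.put_along_axis(mask, top, True, axis=-1)
    return mask


rng = np.random.default_rng(0)
logits = rng.normal(size=(128, 128))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
m = block_mask(attn, sparsity=0.75, block=16)  # e.g. a focused head
print(m.shape, m.sum(axis=-1))  # (8, 8) [2 2 2 2 2 2 2 2]
```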
HADS vs Baselines (seq=2048)
HADS achieves the lowest KL divergence from dense attention (i.e., the output closest to dense) among the sparse methods compared.
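The page does not say which distributions the KL figure compares. Assume it is KL(P_dense || P_sparse) between the dense and sparse models' next-token distributions, averaged over positions. Under that assumption it can be computed from logits as follows (the helper names are hypothetical).

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def kl_dense_to_sparse(dense_logits: np.ndarray, sparse_logits: np.ndarray) -> float:
    """Mean KL(P_dense || P_sparse) over positions; logits shape (..., vocab)."""
    lp = log_softmax(dense_logits)
    lq = log_softmax(sparse_logits)
    return float(np.mean(np.sum(np.exp(lp) * (lp - lq), axis=-1)))


rng = np.random.default_rng(0)
dense = rng.normal(size=(4, 100))
print(kl_dense_to_sparse(dense, dense))  # 0.0
print(kl_dense_to_sparse(dense, dense + 0.1 * rng.normal(size=dense.shape)))  # small positive value
```

Comparing output distributions avoids a pitfall of comparing attention rows directly. KL(dense || sparse) over attention probabilities is infinite whenever the sparse mask zeroes out a key that dense attention gives nonzero weight.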
Architecture
XSAKE/
├── kernels/pallas/ ← Block-sparse attention (Pallas → Triton/XLA)
│ └── sparse_attention.py # Online softmax, fori_loop over KV blocks
├── kernels/hads/ ← Head-Adaptive Dynamic Sparsity (novel)
│ └── hads_pattern.py # Entropy measurement + block mask construction
├── model/ ← GPT-style transformer (Flax/linen + XSAKE attn)
├── distributed/ ← 2D mesh, pmap, NamedSharding
├── training/ ← AdamW + cosine warmup, Orbax checkpoints
├── benchmarks/ ← Latency / memory / throughput / HADS ablation
├── observability/ ← FastAPI dashboard + Prometheus + W&B
└── showcase/ ← This Next.js app (you are here)
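`sparse_attention.py` is described as an online softmax with a `fori_loop` over KV blocks. The NumPy reference below sketches that technique under several assumptions:

- a single head;
- a sequence length divisible by the block size;
- a boolean block mask in which every query block keeps at least one KV block;
- a plain Python loop standing in for `jax.lax.fori_loop`.

It is a correctness reference, not the Pallas kernel, and it checks itself against dense masked attention.

```python
import numpy as np


def block_sparse_attention(q, k, v, mask, block):
    """Single-head block-sparse attention with an online (streaming) softmax.

    q, k, v: (seq, d); mask: bool (seq // block, seq // block), True = compute tile.
    Assumes every query block keeps at least one KV block (else denom stays 0).
    """
    seq, d = q.shape
    n_blocks = seq // block
    out = np.zeros((seq, d))
    for qb in range(n_blocks):
        qs = q[qb * block:(qb + 1) * block] / np.sqrt(d)
        row_max = np.full((block, 1), -np.inf)   # running max score per query row
        denom = np.zeros((block, 1))             # running softmax denominator
        acc = np.zeros((block, d))               # running unnormalized P @ V
        for kb in range(n_blocks):               # reportedly a fori_loop in the kernel
            if not mask[qb, kb]:
                continue                         # masked tile: skipped entirely
            ks = k[kb * block:(kb + 1) * block]
            vs = v[kb * block:(kb + 1) * block]
            s = qs @ ks.T                        # (block, block) scaled scores
            new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
            p = np.exp(s - new_max)
            rescale = np.exp(row_max - new_max)  # shrink earlier partial sums
            denom = denom * rescale + p.sum(axis=-1, keepdims=True)
            acc = acc * rescale + p @ vs
            row_max = new_max
        out[qb * block:(qb + 1) * block] = acc / denom
    return out


def dense_masked_attention(q, k, v, mask, block):
    """Reference: ordinary softmax attention with the block mask expanded to tokens."""
    tok_mask = np.kron(mask, np.ones((block, block))).astype(bool)
    s = np.where(tok_mask, (q @ k.T) / np.sqrt(q.shape[-1]), -np.inf)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


rng = np.random.default_rng(0)
seq, d, block = 64, 16, 16
q, k, v = (rng.normal(size=(seq, d)) for _ in range(3))
mask = rng.random((seq // block, seq // block)) < 0.5
np.fill_diagonal(mask, True)  # guarantee every query block keeps one KV block
print(np.allclose(block_sparse_attention(q, k, v, mask, block),
                  dense_masked_attention(q, k, v, mask, block)))  # True
```

The running max and rescale factor make the streamed result match a full softmax over the kept blocks. In a real kernel, a skipped tile costs neither FLOPs nor K/V loads, which is what makes block sparsity cheaper than dense attention.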