JAX · Pallas · XLA

XLA-Fused Sparse Attention Kernel Engine

XSAKE implements block-sparse attention in JAX + Pallas with HADS (Head-Adaptive Dynamic Sparsity), a novel algorithm that assigns each attention head its own sparsity ratio based on the Shannon entropy of its attention. On an NVIDIA T4 it cuts latency by 42.2% at seq=4096, and it stays closer to dense attention (lower KL divergence) than the fixed-pattern baselines tested.


Key Results · NVIDIA T4 (Kaggle)

42.2% latency reduction (seq=4096): 713 ms → 412 ms
59.9% GPU memory savings (seq=4096): 1207 MB → 484 MB
0.089 KL divergence vs dense (seq=2048), the best among the sparse methods tested
+1.8% perplexity gap vs dense after 10k training steps on OpenWebText
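The two headline percentages follow directly from the raw T4 measurements listed above:

```latex
\text{latency reduction} = \frac{713 - 412}{713} = \frac{301}{713} \approx 0.422
\qquad
\text{memory savings} = \frac{1207 - 484}{1207} = \frac{723}{1207} \approx 0.599
```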

What is HADS?

Standard sparse patterns (BigBird, Longformer) apply the same block mask to every attention head. But attention heads specialize: syntactic heads tend to focus narrowly, while semantic heads attend broadly.

HADS runs a calibration pass, measures the Shannon entropy of each head's attention, and maps it to a per-head sparsity ratio (the fraction of KV blocks skipped): low-entropy, focused heads skip 80–90% of KV blocks, while high-entropy, broad heads skip only 10–20%.

# entropy → sparsity (12 heads; H0–H5 shown)
H0  62%
H1  82%  ← syntactic
H2  18%  ← semantic
H3  71%
H4  55%
H5  44%
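A minimal JAX sketch of the calibration step described above, under stated assumptions: calibration attention probabilities arrive as [batch, heads, q_len, k_len]; entropy is mapped linearly onto the 10–90% sparsity range quoted above (the actual HADS mapping may differ); and each head keeps its highest-scoring KV blocks per query block, ranked by calibration attention mass. All function names are hypothetical, not the API of kernels/hads/hads_pattern.py.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy

# Hypothetical helpers illustrating the HADS idea; not the XSAKE API.


def head_entropy(attn_probs):
    """Mean Shannon entropy (nats) of each head's attention rows.

    attn_probs: [batch, heads, q_len, k_len]; each row sums to 1.
    Returns: [heads]
    """
    # xlogy(p, p) = p * log(p), with 0 * log(0) treated as 0
    row_entropy = -jnp.sum(xlogy(attn_probs, attn_probs), axis=-1)  # [B, H, Q]
    return row_entropy.mean(axis=(0, 2))


def entropy_to_sparsity(entropy, k_len, s_min=0.1, s_max=0.9):
    """Assumed linear map: zero entropy -> s_max, maximal entropy -> s_min."""
    # Uniform attention over k_len keys has entropy log(k_len); normalize to [0, 1]
    h_norm = jnp.clip(entropy / jnp.log(k_len), 0.0, 1.0)
    return s_max - (s_max - s_min) * h_norm


def block_mask_from_scores(block_scores, sparsity):
    """Keep the highest-scoring KV blocks for every query block.

    block_scores: [heads, n_q_blocks, n_kv_blocks] calibration attention mass
    sparsity:     [heads] fraction of KV blocks to skip
    Returns: bool mask [heads, n_q_blocks, n_kv_blocks], True = compute block
    """
    n_kv = block_scores.shape[-1]
    # Blocks to keep per row; always keep at least one so no row is empty
    keep = jnp.clip(jnp.round((1.0 - sparsity) * n_kv), 1, n_kv).astype(jnp.int32)
    # Rank each block within its row (0 = largest score)
    order = jnp.argsort(-block_scores, axis=-1)
    ranks = jnp.argsort(order, axis=-1)
    return ranks < keep[:, None, None]
```

Because the assumed map is monotone, a lower-entropy head always receives a higher sparsity ratio; a different calibration curve could replace `entropy_to_sparsity` without touching the mask step.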

HADS vs Baselines (seq=2048)

Method           Active blocks   Latency (ms)   KL vs dense
Dense            100%            181.2          0.000
Sliding Window   28%             64.3           0.312
Longformer       32%             71.8           0.278
BigBird          41%             84.2           0.198
Random 50%       50%             103.7          0.421
HADS ✓           43%             105.6          0.089

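HADS achieves the lowest KL divergence from dense among the sparse methods tested (0.089 vs 0.198 for the next best, BigBird), at a latency (105.6 ms) close to Random 50% and about 42% below dense.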
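The table does not state which distributions the "KL vs dense" column compares. One plausible reading, assumed here, is the mean KL(P_dense ‖ P_sparse) between the next-token distributions of the dense and sparse models on the same inputs. The sketch below computes that from logits; `mean_kl_from_logits` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def mean_kl_from_logits(dense_logits, sparse_logits):
    """Mean KL(P_dense || P_sparse) over all leading (batch/position) axes.

    dense_logits, sparse_logits: [..., vocab] unnormalized logits.
    """
    log_p = jax.nn.log_softmax(dense_logits, axis=-1)   # reference: dense model
    log_q = jax.nn.log_softmax(sparse_logits, axis=-1)  # approximation: sparse model
    kl = jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
    return kl.mean()
```

Comparing raw attention rows instead would make KL(dense ‖ sparse) infinite wherever a skipped block carries nonzero dense probability, which is one reason full-support output distributions are a natural target for this metric.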

Architecture

XSAKE/
├── kernels/pallas/ ← Block-sparse attention (Pallas → Triton/XLA)
│ └── sparse_attention.py # Online softmax, fori_loop over KV blocks
├── kernels/hads/ ← Head-Adaptive Dynamic Sparsity (novel)
│ └── hads_pattern.py # Entropy measurement + block mask construction
├── model/ ← GPT-style transformer (Flax Linen + XSAKE attention)
├── distributed/ ← 2D mesh, pmap, NamedSharding
├── training/ ← AdamW + cosine warmup, Orbax checkpoints
├── benchmarks/ ← Latency / memory / throughput / HADS ablation
├── observability/ ← FastAPI dashboard + Prometheus + W&B
└── showcase/ ← This Next.js app (you are here)
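The kernel in kernels/pallas/ is described as an online softmax with a fori_loop over KV blocks. The sketch below shows that technique in plain JAX (jax.numpy + lax), not Pallas, for a single head without causal masking. It is a reference for checking outputs, not the fused kernel, and `block_sparse_attention` is a hypothetical name. Inactive KV blocks are skipped with `lax.cond`; whenever a new block raises a row's maximum, the running denominator and accumulator are rescaled so the result matches a full softmax over the active blocks.

```python
import jax
import jax.numpy as jnp
from jax import lax


def block_sparse_attention(q, k, v, block_mask, block=64):
    """Single-head, non-causal block-sparse attention with online softmax.

    q, k, v:    [seq, d], seq divisible by block
    block_mask: [seq // block, seq // block] bool, True = compute this KV block
    Returns:    [seq, d]
    """
    seq, d = q.shape
    nb = seq // block
    scale = d ** -0.5
    qb = q.reshape(nb, block, d)
    kb = k.reshape(nb, block, d)
    vb = v.reshape(nb, block, d)

    def attend_q_block(args):
        qi, q_blk = args

        def update(kj, carry):
            m, l, acc = carry                       # running max, denominator, output
            s = (q_blk @ kb[kj].T) * scale          # [block, block] scores
            m_new = jnp.maximum(m, s.max(axis=-1))
            alpha = jnp.exp(m - m_new)              # rescale earlier partial sums
            p = jnp.exp(s - m_new[:, None])
            return (m_new,
                    alpha * l + p.sum(axis=-1),
                    alpha[:, None] * acc + p @ vb[kj])

        def body(kj, carry):
            # Skip KV blocks that the mask marks inactive
            return lax.cond(block_mask[qi, kj],
                            lambda c: update(kj, c),
                            lambda c: c,
                            carry)

        init = (jnp.full((block,), -jnp.inf, q.dtype),
                jnp.zeros((block,), q.dtype),
                jnp.zeros((block, d), q.dtype))
        _, l, acc = lax.fori_loop(0, nb, body, init)
        # Rows with no active blocks return zeros instead of NaN
        return acc / jnp.where(l == 0, 1.0, l)[:, None]

    # lax.map runs query blocks sequentially, so lax.cond really skips work
    out = lax.map(attend_q_block, (jnp.arange(nb), qb))
    return out.reshape(seq, d)
```

For multi-head inputs, fix `block` with `functools.partial` and `jax.vmap` over a leading head axis, giving each head its own block mask. Under `vmap` the `lax.cond` predicate becomes batched and JAX evaluates both branches, so the compute savings only materialize in a kernel that never loads the skipped blocks.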