MISA: 긴 문맥 LLM 추론을 위한 인덱서 희소 어텐션의 혼합
MISA: Mixture of Indexer Sparse Attention for Long-Context LLM Inference
DeepSeek Sparse Attention (DSA)는 학습된 토큰 단위 인덱서를 도입하여 세밀한 추론 시간 희소 어텐션 분야에서 최첨단 기술을 선보입니다. 이 인덱서는 모든 접두사 토큰에 대한 점수를 매기고, 주요 어텐션에 가장 관련성이 높은 토큰을 선택합니다. 표현력을 유지하기 위해, 인덱서는 동일한 선택된 토큰 집합을 공유하는 많은 쿼리 헤드(예: DeepSeek-V3.2의 경우 64개)를 사용합니다. 이러한 멀티 헤드 설계는 특히 긴 문맥에서 인덱서의 주요 비용 요소입니다. 우리는 DSA 인덱서를 대체할 수 있는 MISA (Mixture of Indexer Sparse Attention)를 제안합니다. MISA는 인덱서 헤드를 전문가 혼합 풀로 취급합니다. 경량 라우터는 저렴한 블록 수준 통계를 사용하여 쿼리에 따라 선택된 소수의 활성 헤드 집합을 선택하며, 선택된 헤드만 토큰 수준의 계산 집약적인 점수 매기기를 수행합니다. 이를 통해 원래 인덱서 풀의 다양성을 유지하면서, 각 쿼리에 대한 비용을 모든 헤드를 사용하여 모든 접두사 토큰을 점수 매기는 것에서, 소수의 라우팅된 헤드만 사용하여 점수 매기는 것으로 줄입니다. 또한, 라우팅된 결과를 사용하여 확장된 후보 집합을 유지하고, 원래 DSA 인덱서를 사용하여 다시 순위를 매겨 최종 선택된 토큰을 거의 정확하게 복구하는 MISA의 계층적 변형을 추가적으로 제안합니다. MISA는 추가적인 학습 없이 8개의 활성 헤드만 사용하여 DeepSeek-V3.2 및 GLM-5 모델에서 LongBench 벤치마크에서 밀집 DSA 인덱서와 동일한 성능을 보이면서, 각각 8배와 4배 적은 수의 인덱서 헤드를 사용하고, 평균적으로 HISA보다 더 뛰어난 성능을 보입니다. 또한, 128K 토큰의 문맥에서도 완전히 녹색의 Needle-in-a-Haystack 히트맵을 유지하며, DSA 인덱서가 선택한 토큰의 92% 이상을 각 레이어에서 복구합니다. 당사의 TileLang 커널은 단일 NVIDIA H200 GPU에서 DSA의 원래 인덱서 커널보다 약 3.82배 빠른 속도를 제공합니다.
DeepSeek Sparse Attention (DSA) sets the state of the art for fine-grained inference-time sparse attention by introducing a learned token-wise indexer that scores every prefix token and selects the most relevant ones for the main attention. To remain expressive, the indexer uses many query heads (for example, 64 on DeepSeek-V3.2) that share the same selected token set; this multi-head design is precisely what makes the indexer the dominant cost on long contexts. We propose MISA (Mixture of Indexer Sparse Attention), a drop-in replacement for the DSA indexer that treats its indexer heads as a pool of mixture-of-experts. A lightweight router uses cheap block-level statistics to pick a query-dependent subset of only a few active heads, and only those heads run the heavy token-level scoring. This preserves the diversity of the original indexer pool while reducing the per-query cost from scoring every prefix token with every head to scoring it with only a handful of routed heads, plus a negligible router term computed on a small set of pooled keys. We further introduce a hierarchical variant of MISA that uses the routed pass to keep an enlarged candidate set and then re-ranks it with the original DSA indexer to recover the final selected tokens almost exactly. With only eight active heads and no additional training, MISA matches the dense DSA indexer on LongBench across DeepSeek-V3.2 and GLM-5 while running with eight and four times fewer indexer heads respectively, and outperforms HISA on average. It also preserves fully green Needle-in-a-Haystack heatmaps up to a 128K-token context and recovers more than 92% of the tokens selected by the DSA indexer per layer. Our TileLang kernel delivers roughly a 3.82 times speedup over DSA's original indexer kernel on a single NVIDIA H200 GPU.
No Analysis Report Yet
This paper hasn't been analyzed by Gemini yet.
Log in to request an AI analysis.