DashAttention: адаптивная разреженная иерархическая атенция для длинных контекстов
Что если можно было бы сократить вычисления внимания на 75%, не потеряв при этом в точности? Новая работа исследователей из Tsinghua University, Carnegie Mellon, University of Edinburgh и Sapienza University of Rome показывает, что это не фантастика. Их метод DashAttention пересматривает сам принцип иерархической разреженной атенции, заменяя жёсткий top-k отбор на адаптивное α-entmax преобразование — и добивается при этом ускорения до 3.36× над FlashAttention-3.
Что такое DashAttention
DashAttention (Differentiable and Adaptive Sparse Hierarchical Attention) — это многостадийный механизм внимания, который сохраняет точность полной softmax-атенции при высокой разреженности. В отличие от существующих иерархических методов вроде NSA и InfLLMv2, DashAttention не фиксирует количество релевантных токенов заранее, а адаптивно выбирает переменное число ключ-значение (KV) блоков для каждого запроса.
Ключевая инновация — использование α-entmax вместо top-k на этапе грубого роутинга. Поскольку α-entmax полностью дифференцируем, вся иерархия остаётся сквозно обучаемой: градиенты свободно текут от тонкой softmax-атенции второй стадии обратно к решениям грубого роутинга первой стадии. Это решает фундаментальную проблему существующих методов, где жёсткий top-k разрывает дифференцируемый путь и навязывает фиксированный бюджет внимания независимо от содержания запроса.
Почему top-k разрушает градиенты
Современные иерархические методы разреженной атенции работают по двухстадийной схеме: сначала грубые оценки внимания отбирают top-k KV-блоков, затем к отобранным токенам применяется точная softmax-атенция. Этот подход интуитивно разумен, но несёт три скрытых дефекта.
Во-первых, top-k предполагает, что для любого запроса релевантно одинаковое количество токенов — это фиксированный бюджет, не зависящий от сложности задачи. В реальности одни запросы требуют широкого контекста, а другие — узкой концентрации на конкретном фрагменте. Во-вторых, операция top-k не дифференцируема, что перекрывает градиентный поток между грубой и тонкой стадиями. В-третьих, существующие методы агрегируют головы внимания через softmax перед top-k отбором, что позволяет дисперсии (размытию вероятностей) возвращаться на этапе агрегации.
DashAttention устраняет все три проблемы единым архитектурным решением. Вместо top-k метод использует α-entmax — адаптивно разреженное распределение, чья поддержка (количество ненулевых элементов) определяется самими входными данными, а веса остаются гладкими и дифференцируемыми. Это означает, что модель сама решает, сколько чанков ей нужно — иногда двадцать, иногда два — и учится этому решению напрямую через градиентный спуск, без промежуточных эвристик.
Трёхстадийная архитектура
DashAttention декомпозирует внимание на три последовательных этапа, каждый из которых решает свою задачу эффективности.
Стадия 0 строит краткие описания (summaries) для каждого чанка из 64 токенов, используя локальную SDPA-атенцию внутри чанка. Это интерпретируется как обучаемая summary-голова, которая читает каждый фрагмент изнутри — выразительнее mean-pooling в MoBA и InfLLMv2 и проще MLP-компрессии в NSA. Здесь важен архитектурный выбор: вместо того чтобы сжимать чанк внешним MLP или усреднять, модель учится «резюмировать» чанк через внутреннее внимание, что даёт контекстуализированные представления, а не статические эмбеддинги.
Стадия 1 применяет α-entmax к чанковым score-векторам, получая адаптивно разреженное распределение вероятностей по чанкам. Параметр α контролирует степень разреженности: при высоком α распределение становится щедрым (много чанков получают ненулевой вес), при низком — жёстко разреженным. Критически важно, что α-entmax выдаёт гладкие вероятности с дифференцируемой поддержкой, в отличие от жёсткого отсечения top-k. Практически это реализовано через AdaSplash-2 — оптимизированное GPU-ядро, которое вычисляет α-entmax с защищёнными higher-order обновлениями, избегая нестабильности при малых градиентах.
Стадия 2 выполняет softmax-атенцию на полном разрешении, но только внутри отобранных чанков. При этом логиты смещаются (bias) на основе весов роутера из Стадии 1, что сохраняет информацию о глобальной релевантности чанка даже при локальном softmax. Вся реализация совместима с FlashAttention: Стадия 1 использует AdaSplash-2, Стадия 2 — стандартный FlashAttention-ядро с битовой маской активных чанков. Битовая маска упаковывается в 64-битные слова, что позволяет ядру за один проход обходить только активные чанки, не материализуя индексы явно — в отличие от InfLLMv2, которая тратит время на per-query индексацию между стадиями оценки и внимания.
Теория: почему non-dispersive важно
Недавние теоретические работы показали, что softmax-атенция страдает от дисперсии (dispersion) в длинных контекстах: энтропия распределения внимания растёт как O(log n) с длиной последовательности, делая моделирование дальних зависимостей всё более неточным. Представьте себе фонарь, который вместо узкого луча раскидывает свет по всей комнате — чем больше комната, тем тусклее каждая точка. Разреженные методы с top-k ограничивают энтропию константой O(log k), но существующие иерархические схемы компрометтируют это свойство на этапе агрегации голов.
Авторы DashAttention формализуют это наблюдение и доказывают, что их метод является non-dispersive — энтропия не растёт с длиной контекста, поскольку агрегация голов выполняется через entmax, а не softmax. В отличие от softmax, который всегда даёт положительный вес каждому элементу, entmax обнуляет нерелевантные веса полностью, предотвращая «засветку» от второстепенных токенов. Это теоретическое гарантия напрямую транслируется в практическое преимущество: на задачах длинного контекста DashAttention устойчиво превосходит базовые линии, особенно в режимах высокой разреженности.
Эксперименты: точность и скорость
Оценка проводилась на моделях MiniCPM-4 размером 1B, 3B и 8B параметров после долгого дообучения на контекстах 16K токенов. Все три модели обучались на кластере из 32 GPU NVIDIA A800 с процессорами Intel Xeon Platinum 8470, используя фреймворк Megatron и смешанную точность bfloat16. Для долгого контекстного дообучения использовался набор InfLLM-5B с длиной контекста 16K, после чего следовал короткий этап supervised fine-tuning на оригинальных данных MiniCPM-4. Для сравнения использовались полная атенция (FullAttn), NSA и InfLLMv2 при 75% разреженности. Тестирование выполнялось на бенчмарках RULER и HELMET — стандартных наборах для оценки длинного контекста.
На RULER DashAttention показал результаты на уровне полной атенции и заметно превзошёл NSA и InfLLMv2, особенно на сложных retrieval-задачах. На HELMET картина аналогична: при низкой и умеренной разреженности DashAttention слегка превышает полную атенцию, а при высокой разреженности разрыв с базовыми линиями резко увеличивается. При ~90% разреженности DashAttention сохраняет 39.4% общей точности, превосходя InfLLMv2 на 9 процентных пунктов и NSA на 19 пунктов.
Скорость инференса измерялась на одном GPU NVIDIA GH200 с CUDA 13.2, PyTorch 2.10.0 и Triton 3.6.0. DashAttention оказался быстрее всех конкурентов на каждой точке работы: ускорение над плотным FlashAttention-3 колеблется от 1.34× до 3.09× в зависимости от длины последовательности и разреженности. Преимущество над InfLLMv2 и NSA максимально при плотных настройках, где накладные расходы top-k амортизируются хуже всего, и сохраняется при любой разреженности. В режиме декодирования с batch size 24 DashAttention достигает 3.36× ускорения над FlashAttention-3 и 1.35× над InfLLMv2.
Интересно, что на коротких задачах DashAttention не уступает полной атенции и даже слегка превосходит её в некоторых случаях — вероятно, благодаря регуляризующему эффекту разреженности, который подавляет шумные ассоциации в плотном внимании. На стандартных бенчмарках общего назначения (MMLU, GSM8K, MATH, HumanEval) DashAttention показал среднюю точность 59.5% против 59.5% у полной атенции, 59.2% у NSA и 59.1% у InfLLMv2 — то есть разреженность не ухудшила способности модели к рассуждениям и кодированию.
Динамическое распределение разреженности
Одно из самых интересных свойств DashAttention — способность динамически перераспределять вычислительный бюджет между слоями в зависимости от входных данных. Анализ разреженности по слоям на входе длиной 16K из RULER-SG1 показывает, что нижние слои модели склонны к более плотному вниманию (собирают глобальный контекст), а верхние слои — к более разреженному (фокусируются на специфических деталях).
Это поведение возникает естественно из геометрии score-векторов α-entmax, без явного контроля. В отличие от NSA и InfLLMv2, где разреженность фиксирована гиперпараметрами, DashAttention адаптирует глубину внимания к структуре конкретного входа — и делает это отдельно для каждой головы в каждом слое.
Почему это важно сейчас
Контекстные окна современных LLM выросли от 4K до 2M токенов за последние два года, но вычислительная сложность внимания растёт квадратично. Разреженные методы — не роскошь, а необходимость для экономически осмысленного развёртывания длинноконтекстных моделей. Однако существующие решения компрометтируют либо точность (жёсткая фиксация бюджета), либо обучаемость (разрыв градиентов), либо эффективность (накладные расходы top-k).
DashAttention демонстрирует, что эти компромиссы не фундаментальны. Полная дифференцируемость, адаптивная разреженность и non-dispersive свойства сочетаются в единой архитектуре, которая одновременно проще в обучении, точнее в задачах и быстрее в инференсе. Особенно впечатляет доминирование DashAttention на Pareto-фронтире точность-разреженность: метод не просто быстрее или точнее, а строго лучше по обоим критериям одновременно.
Для практиков важен и инженерный аспект: реализация на Triton означает, что DashAttention может быть интегрирован в существующие пайплайны без радикальной перестройки инфраструктуры. Единственное существенное ограничение на текущий момент — отсутствие интеграции в production-серверы вроде vLLM и SGLang, но авторы открыто заявляют о планах по устранению этого пробела.
Часто задаваемые вопросы
Можно ли использовать DashAttention с существующими моделями?
Да, архитектура совместима с любой моделью на базе GQA (Grouped Query Attention). Авторы интегрировали DashAttention в MiniCPM-4 через долгое дообучение на 16K контекстах, но теоретически метод можно адаптировать и к другим базовым моделям с минимальными изменениями конфигурации.
Что такое α-entmax и почему его раньше не использовали для атенции?
α-entmax — это параметризованное обобщение softmax, которое выдаёт разреженные вероятностные распределения с адаптивной поддержкой. Он известен в сообществе structured prediction с 2019 года, но его применение для иерархической атенции требовало совместимой GPU-реализации и теоретического обоснования non-dispersion. DashAttention предоставляет и то, и другое.
Каковы ограничения метода?
Ядра DashAttention пока не интегрированы в production-фреймворки вроде vLLM и SGLang — это оставлено для будущей работы. Кроме того, текущая оценка ограничена архитектурой MiniCPM-4; поведение на других моделях (Llama, Qwen, Mistral) требует дополнительной валидации. Наконец, α-entmax добавляет один гиперпараметр (фактор γ), который нужно подбирать под целевую разреженность.
Итог
DashAttention пересматривает фундаментальное предположение иерархической разреженной атенции: что количество релевантных токенов можно зафиксировать заранее. Заменяя жёсткий top-k на адаптивный α-entmax, метод достигает полной дифференцируемости, динамического распределения разреженности и теоретически обоснованного non-dispersive поведения. Практический результат — точность полной атенции при 75% разреженности и ускорение до 3.36× над FlashAttention-3. Для разработчиков и исследователей, работающих с длинными контекстами, это означает, что компромисс между скоростью и качеством становится заметно мягче.