Reflective Masking: диффузионные модели учатся исправлять свои ошибки
Авторегрессионные языковые модели рассуждают последовательно: сгенерировала токен — добавила в контекст — генерируй следующий. Если модель ошиблась на третьем токене, она может попытаться исправиться позже, но старые ошибки остаются в контексте, как чернильные кляксы на бумаге.
Диффузионные модели работают иначе. Они «разворачивают» текст из шума — за один проход переписывают весь canvas целиком. Это даёт свободу редактировать уже сгенерированные фрагменты, но до недавнего времени у них не было механизма выбирать, что исправлять. Модель либо редактировала всё, либо не редактировала ничего.
Исследователи из University of Maryland, Virginia Tech, Intuit, UC Davis и MBZUAI предложили Reflective Masking (RM) — метод, который превращает масочную диффузионную модель (Mask Diffusion Model, MDM) в многоходового ревизора. Теперь модель сама решает, какие токены достаточно надёжны, чтобы оставить, а какие — перегенерировать.
Проблема: MDMs умеют редактировать, но не умеют выбирать
Масочные диффузионные модели генерируют текст, последовательно раскрывая замаскированные токены. На каждом шаге денойзинга модель «видит» все токены canvas — и может редактировать уже раскрытые. Проблема в том, что решение «что редактировать» жёстко закодировано в архитектуре: модель не может выбрать перегенерировать конкретный токен, если текущий шаг его уже раскрыл.
Результат: MDMs закрепляют ранние ошибки. Модель «утвердила» неверный токен на шаге N, и на шаге N+10 у неё нет механизма вернуться и исправить именно его. Она может только перегенерировать весь canvas заново — что отбрасывает всю накопленную работу.
Это фундаментальное ограничение авторегрессионных моделей тоже существует, но там есть обходные пути: chain-of-thought prompting, self-repair, перегенерация с более строгим семплингом. Для MDMs таких обходных путей не было.
Reflective Masking: три действия на каждом шаге
Reflective Masking вводит три возможных действия для каждой позиции токена на каждом шаге денойзинга:
Reveal — токен достаточно уверен, оставляем как есть. Модель «утверждает» свой текущий прогноз.
Reflectively Mask — модель не уверена в токене, помечает его обратно в MASK и даёт себе ещё одну попытку на следующем шаге. Ключевое отличие от обычного маскирования: решение маскировать принимает сама модель, а не архитектура.
Reserve — токен «замораживается» до следующего шага без изменения. Это даёт модели время «подумать», не принимая окончательного решения.
Получается диффузионный аналог chain-of-thought, только вместо «писать дальше» модель «переписывает избранное». CoT думает, продолжая. RM думает, решая, что исправить.
History Reference: память без параметров
Второй ключевой компонент — History Reference (HR). Проблема многоходовой ревизии в том, что на шаге T модель должна «помнить», что она предсказывала на шагах T-1, T-2, и так далее. Если модель на шаге T-1 уже перегенерировала токен X, а на шаге T она снова видит X в контексте — она не должна пытаться его снова.
History Reference решает это без дополнительных параметров. На каждом шаге денойзинга состояния модели накапливаются и прокладываются к следующему шагу через History Embedding Rotation (HER) — механизм, который «прокручивает» эмбеддинги истории так, что модель всегда видит свой полный трек денойзинга. По сути это parameter-free attention over denoising trajectory: вместо того чтобы хранить все предыдущие hidden states вKV-кеше, модель получает компактное представление истории через ротацию эмбеддингов. Это критично для масштабирования — длинные траектории ревизии не переполняют контекст.
По сути, это параметр-свободная память denoising trajectory. Модель видит, что она уже пыталась предсказать, что уже исправляла, и не повторяет одни и те же ошибки.
Sudoku: от 82% до 93% точности
Первая задача — структурированная коррекция ошибок на судоку. Это идеальный бенчмарк для тестирования ревизии: модель получает сетку 9×9 с 4–20 испорченными ячейками и должна итеративно исправлять ошибки.
Используется крошечная MDM — всего 0.81M параметров, обученная с нуля. Никакого предобученного диффузионного LM, никакой дополнительной архитектуры.
Результаты на валидации по 1000 играм:
| Метод | Exact Accuracy | Valid Rate | Replay Mistake % | Conflict Cells/board |
|---|---|---|---|---|
| RM (no HR) | 82.4% | 86.6% | 0.57% | 0.578 |
| RM + HR | 91.4% (+9.0) | 91.8% (+5.2) | 0.07% (-0.50) | 0.300 (-0.278) |
| RM + HR + decay | 89.4% (+7.0) | 89.6% (+3.0) | 0.07% (-0.50) | 0.362 (-0.216) |
| RM + HR + decay + HER | 93.4% (+11.0) | 93.6% (+7.0) | 0.03% (-0.54) | 0.236 (-0.342) |
HR резко сокращает повторяющиеся ошибки — модель перестаёт наступать на одни и те же грабли. HER (History Embedding Rotation) добавляет ещё несколько процентных пунктов, позволяя истории «прокручиваться» через эмбеддинговое пространство.
Сравнение с DiffusionGemma
Интересно сравнить с результатами Google DiffusionGemma на том же судоку. По данным model card DiffusionGemma, точность решения судоку вырастает с 18% при one-shot до 89.5% чистой ревизей за несколько шагов, и с 1.5% до 89.5% после fine-tuning на 4000 шагов.
Reflective Masking достигает 93.4% exact accuracy с 0.81M параметров, обученной с нуля. Это даже выше, чем fine-tuned DiffusionGemma — при том, что модель в ~1000 раз меньше и не требует предобученного диффузионного LM.
Но ключевое отличие: DiffusionGemma работает только с текстом. Reflective Masking одинаково применяется к Sudoku, text reasoning и image editing — модальность, которую DiffusionGemma не поддержирует.
Image editing: локальная ревизия с инструкцией
С 7B multimodal backbone (Lumina-DiMOO) Reflective Masking локализует область редактирования и меняет только её, оставляя остальное изображение нетронутым. Это принципиально отличается от генерации с нуля: модель получает исходное изображение и инструкцию (например, «сделай деревья слева оранжевыми»), находит нужную область и ревизурует только пиксели внутри неё.
Результат — инструкция выполняется точно, а не приблизительно. Baseline-методы (Lumina, Lumina SFT) склонны «интерпретировать» инструкцию слишком широко или пропускать детали. Reflective Masking thanks to selective re-masking работает локально и точно.
Text reasoning: MATH500, MBPP, ARC-Challenge
На открытых задачах — математика и код — RM также показывает улучшения. Бэкбон — LLaDA, диффузионная модель без предобучения на code/math.
| Benchmark | Base LLaDA | Vanilla SFT | RM (Ours) | Δ vs SFT |
|---|---|---|---|---|
| MATH500 | 19.4% | 22.4% | 24.8% | +2.4 |
| MBPP (code) | 28.0% | 30.6% | 39.4% | +8.8 |
| ARC-Challenge | 73.7% | 81.3% | 86.1% | +4.8 |
Самый впечатляющий результат — MBPP, где RM даёт +8.8 п.п. против SFT. Это говорит о том, что способность «откатить» неверный код и попробовать заново особенно ценна в задачах генерации, где ошибка в одном месте может сломать всю программу.
На Minerva MATH (по предметным категориям) RM улучшает почти каждую рубрику: от +0.59% в алгебре до +4.26% в теории чисел. Сильнее всего выигрыш на категориях, где есть промежуточные шаги, которые можно проверить и исправить.
Как работает обучение
RM обучается офлайн, без online rollouts. Из чистого target сэмплируется маска, делается один forward pass MDM, и из распределения модели сэмплируются правдоподобные неправильные токены — так строится псевдо-траектория, соответствующая распределению модели. Это важно: в отличие от supervised fine-tuning, где нужны ground-truth правильные траектории, RM генерирует свои собственные «неправильные» примеры и учится на них — какие токены стоит ревизовать, а какие оставлять.
Три per-token loss учат модель трем действиям. Reveal loss: если модель предсказала правильный токен — reward за утверждение. RM loss: если модель предсказала неправильный видимый токен — penalty за отсутствие маскирования, reward за маскирование. Keep loss: если токен правильный и видимый — penalty за ненужное маскирование.
Почему офлайн обучение критично? Online rollouts (когда модель реально генерирует, получает reward, и обновляется) требуют десятки тысяч взаимодействий и значительных вычислительных затрат. Офлайн подход с псевдо-траекториями даёт тот же сигнал за долю времени. История хранится в виде эмбеддингов на каждом шаге — это и есть то, что передаётся в следующий forward pass.
Вся процедура обучения укладывается в ~5 часов на 2×H100. Это существенно дешевле, чем fine-tuning DiffusionGemma на 4000 шагов, который требует значительно больших вычислительных ресурсов.
Почему это важно
Reflective Masking решает фундаментальную проблему MDMs: невозможность выборочной ревизии. До этого диффузионные модели могли редактировать весь canvas, но не могли выбрать, что именно переделать. Теперь маскирование становится решением модели — «доверяю ли я этому токену достаточно, чтобы утвердить его?»
Для практических применений это означает, что диффузионные модели смогут автономно исправлять свои ошибки в задачах программирования, математики, логического вывода — без человеческого вмешательства и без перегенерации с нуля.
Диффузионные языковые модели больше не обречены генерировать текст вслепую. Теперь они могут рассуждать, исправляя себя.
Часто задаваемые вопросы
Чем Reflective Masking отличается от chain-of-thought?
Chain-of-thought в авторегрессионных моделях — это «писать дальше». Модель добавляет новые токены рассуждений, старые ошибки остаются в контексте. Reflective Masking — это «переписывать избранное». Модель стирает неуверенные токены и заменяет их, так что ошибки не накапливаются.
Почему History Reference — это прорыв?
History Reference решает проблему «забывания» в многоходовой ревизии. Без него модель на шаге T не знает, что она предсказывала на шаге T-1. С ним — видит полный трек и не повторяет ошибки. При этом механизм не добавляет параметров: это просто накопление состояний с поворотом эмбеддингов.
Можно ли использовать RM с любой MDM?
Да. Reflective Masking — это post-training метод, который применяется к существующим MDMs без изменения архитектуры. Эксперименты проведены на LLaDA и Lumina-DiMOO, но процедура универсальна.
Итог
Reflective Masking превращает масочные диффузионные модели в многоходовых ревизоров. Модель сама решает, какие токены перегенерировать, на каждом шаге денойзинга. History Reference даёт ей память о прошлых попытках — без параметров и без дополнительных attention sequences.
Результат: 93.4% на судоку с 0.81M параметров (vs 89.5% у fine-tuned DiffusionGemma), улучшения на MATH500, MBPP и ARC-Challenge, и применимость к image editing — модальность, недоступная текстовым диффузионным моделям.
CoT думает, продолжая. RM думает, исправляя.