Reflective Masking: диффузионные модели учатся исправлять свои ошибки

Reflective Masking: диффузионные модели учатся исправлять свои ошибки

Авторегрессионные языковые модели рассуждают последовательно: сгенерировала токен — добавила в контекст — генерируй следующий. Если модель ошиблась на третьем токене, она может попытаться исправиться позже, но старые ошибки остаются в контексте, как чернильные кляксы на бумаге.

Диффузионные модели работают иначе. Они «разворачивают» текст из шума — за один проход переписывают весь canvas целиком. Это даёт свободу редактировать уже сгенерированные фрагменты, но до недавнего времени у них не было механизма выбирать, что исправлять. Модель либо редактировала всё, либо не редактировала ничего.

Исследователи из University of Maryland, Virginia Tech, Intuit, UC Davis и MBZUAI предложили Reflective Masking (RM) — метод, который превращает масочную диффузионную модель (Mask Diffusion Model, MDM) в многоходового ревизора. Теперь модель сама решает, какие токены достаточно надёжны, чтобы оставить, а какие — перегенерировать.

Проблема: MDMs умеют редактировать, но не умеют выбирать

Масочные диффузионные модели генерируют текст, последовательно раскрывая замаскированные токены. На каждом шаге денойзинга модель «видит» все токены canvas — и может редактировать уже раскрытые. Проблема в том, что решение «что редактировать» жёстко закодировано в архитектуре: модель не может выбрать перегенерировать конкретный токен, если текущий шаг его уже раскрыл.

Результат: MDMs закрепляют ранние ошибки. Модель «утвердила» неверный токен на шаге N, и на шаге N+10 у неё нет механизма вернуться и исправить именно его. Она может только перегенерировать весь canvas заново — что отбрасывает всю накопленную работу.

Это фундаментальное ограничение авторегрессионных моделей тоже существует, но там есть обходные пути: chain-of-thought prompting, self-repair, перегенерация с более строгим семплингом. Для MDMs таких обходных путей не было.

Reflective Masking: три действия на каждом шаге

Reflective Masking вводит три возможных действия для каждой позиции токена на каждом шаге денойзинга:

Reveal — токен достаточно уверен, оставляем как есть. Модель «утверждает» свой текущий прогноз.

Reflectively Mask — модель не уверена в токене, помечает его обратно в MASK и даёт себе ещё одну попытку на следующем шаге. Ключевое отличие от обычного маскирования: решение маскировать принимает сама модель, а не архитектура.

Reserve — токен «замораживается» до следующего шага без изменения. Это даёт модели время «подумать», не принимая окончательного решения.

Получается диффузионный аналог chain-of-thought, только вместо «писать дальше» модель «переписывает избранное». CoT думает, продолжая. RM думает, решая, что исправить.

History Reference: память без параметров

Второй ключевой компонент — History Reference (HR). Проблема многоходовой ревизии в том, что на шаге T модель должна «помнить», что она предсказывала на шагах T-1, T-2, и так далее. Если модель на шаге T-1 уже перегенерировала токен X, а на шаге T она снова видит X в контексте — она не должна пытаться его снова.

History Reference решает это без дополнительных параметров. На каждом шаге денойзинга состояния модели накапливаются и прокладываются к следующему шагу через History Embedding Rotation (HER) — механизм, который «прокручивает» эмбеддинги истории так, что модель всегда видит свой полный трек денойзинга. По сути это parameter-free attention over denoising trajectory: вместо того чтобы хранить все предыдущие hidden states вKV-кеше, модель получает компактное представление истории через ротацию эмбеддингов. Это критично для масштабирования — длинные траектории ревизии не переполняют контекст.

По сути, это параметр-свободная память denoising trajectory. Модель видит, что она уже пыталась предсказать, что уже исправляла, и не повторяет одни и те же ошибки.

Sudoku: от 82% до 93% точности

Первая задача — структурированная коррекция ошибок на судоку. Это идеальный бенчмарк для тестирования ревизии: модель получает сетку 9×9 с 4–20 испорченными ячейками и должна итеративно исправлять ошибки.

Используется крошечная MDM — всего 0.81M параметров, обученная с нуля. Никакого предобученного диффузионного LM, никакой дополнительной архитектуры.

Результаты на валидации по 1000 играм:

Метод Exact Accuracy Valid Rate Replay Mistake % Conflict Cells/board
RM (no HR) 82.4% 86.6% 0.57% 0.578
RM + HR 91.4% (+9.0) 91.8% (+5.2) 0.07% (-0.50) 0.300 (-0.278)
RM + HR + decay 89.4% (+7.0) 89.6% (+3.0) 0.07% (-0.50) 0.362 (-0.216)
RM + HR + decay + HER 93.4% (+11.0) 93.6% (+7.0) 0.03% (-0.54) 0.236 (-0.342)

HR резко сокращает повторяющиеся ошибки — модель перестаёт наступать на одни и те же грабли. HER (History Embedding Rotation) добавляет ещё несколько процентных пунктов, позволяя истории «прокручиваться» через эмбеддинговое пространство.

Сравнение с DiffusionGemma

Интересно сравнить с результатами Google DiffusionGemma на том же судоку. По данным model card DiffusionGemma, точность решения судоку вырастает с 18% при one-shot до 89.5% чистой ревизей за несколько шагов, и с 1.5% до 89.5% после fine-tuning на 4000 шагов.

Reflective Masking достигает 93.4% exact accuracy с 0.81M параметров, обученной с нуля. Это даже выше, чем fine-tuned DiffusionGemma — при том, что модель в ~1000 раз меньше и не требует предобученного диффузионного LM.

Но ключевое отличие: DiffusionGemma работает только с текстом. Reflective Masking одинаково применяется к Sudoku, text reasoning и image editing — модальность, которую DiffusionGemma не поддержирует.

Image editing: локальная ревизия с инструкцией

С 7B multimodal backbone (Lumina-DiMOO) Reflective Masking локализует область редактирования и меняет только её, оставляя остальное изображение нетронутым. Это принципиально отличается от генерации с нуля: модель получает исходное изображение и инструкцию (например, «сделай деревья слева оранжевыми»), находит нужную область и ревизурует только пиксели внутри неё.

Результат — инструкция выполняется точно, а не приблизительно. Baseline-методы (Lumina, Lumina SFT) склонны «интерпретировать» инструкцию слишком широко или пропускать детали. Reflective Masking thanks to selective re-masking работает локально и точно.

Text reasoning: MATH500, MBPP, ARC-Challenge

На открытых задачах — математика и код — RM также показывает улучшения. Бэкбон — LLaDA, диффузионная модель без предобучения на code/math.

Benchmark Base LLaDA Vanilla SFT RM (Ours) Δ vs SFT
MATH500 19.4% 22.4% 24.8% +2.4
MBPP (code) 28.0% 30.6% 39.4% +8.8
ARC-Challenge 73.7% 81.3% 86.1% +4.8

Самый впечатляющий результат — MBPP, где RM даёт +8.8 п.п. против SFT. Это говорит о том, что способность «откатить» неверный код и попробовать заново особенно ценна в задачах генерации, где ошибка в одном месте может сломать всю программу.

На Minerva MATH (по предметным категориям) RM улучшает почти каждую рубрику: от +0.59% в алгебре до +4.26% в теории чисел. Сильнее всего выигрыш на категориях, где есть промежуточные шаги, которые можно проверить и исправить.

Как работает обучение

RM обучается офлайн, без online rollouts. Из чистого target сэмплируется маска, делается один forward pass MDM, и из распределения модели сэмплируются правдоподобные неправильные токены — так строится псевдо-траектория, соответствующая распределению модели. Это важно: в отличие от supervised fine-tuning, где нужны ground-truth правильные траектории, RM генерирует свои собственные «неправильные» примеры и учится на них — какие токены стоит ревизовать, а какие оставлять.

Три per-token loss учат модель трем действиям. Reveal loss: если модель предсказала правильный токен — reward за утверждение. RM loss: если модель предсказала неправильный видимый токен — penalty за отсутствие маскирования, reward за маскирование. Keep loss: если токен правильный и видимый — penalty за ненужное маскирование.

Почему офлайн обучение критично? Online rollouts (когда модель реально генерирует, получает reward, и обновляется) требуют десятки тысяч взаимодействий и значительных вычислительных затрат. Офлайн подход с псевдо-траекториями даёт тот же сигнал за долю времени. История хранится в виде эмбеддингов на каждом шаге — это и есть то, что передаётся в следующий forward pass.

Вся процедура обучения укладывается в ~5 часов на 2×H100. Это существенно дешевле, чем fine-tuning DiffusionGemma на 4000 шагов, который требует значительно больших вычислительных ресурсов.

Почему это важно

Reflective Masking решает фундаментальную проблему MDMs: невозможность выборочной ревизии. До этого диффузионные модели могли редактировать весь canvas, но не могли выбрать, что именно переделать. Теперь маскирование становится решением модели — «доверяю ли я этому токену достаточно, чтобы утвердить его?»

Для практических применений это означает, что диффузионные модели смогут автономно исправлять свои ошибки в задачах программирования, математики, логического вывода — без человеческого вмешательства и без перегенерации с нуля.

Диффузионные языковые модели больше не обречены генерировать текст вслепую. Теперь они могут рассуждать, исправляя себя.

Часто задаваемые вопросы

Чем Reflective Masking отличается от chain-of-thought?

Chain-of-thought в авторегрессионных моделях — это «писать дальше». Модель добавляет новые токены рассуждений, старые ошибки остаются в контексте. Reflective Masking — это «переписывать избранное». Модель стирает неуверенные токены и заменяет их, так что ошибки не накапливаются.

Почему History Reference — это прорыв?

History Reference решает проблему «забывания» в многоходовой ревизии. Без него модель на шаге T не знает, что она предсказывала на шаге T-1. С ним — видит полный трек и не повторяет ошибки. При этом механизм не добавляет параметров: это просто накопление состояний с поворотом эмбеддингов.

Можно ли использовать RM с любой MDM?

Да. Reflective Masking — это post-training метод, который применяется к существующим MDMs без изменения архитектуры. Эксперименты проведены на LLaDA и Lumina-DiMOO, но процедура универсальна.

Итог

Reflective Masking превращает масочные диффузионные модели в многоходовых ревизоров. Модель сама решает, какие токены перегенерировать, на каждом шаге денойзинга. History Reference даёт ей память о прошлых попытках — без параметров и без дополнительных attention sequences.

Результат: 93.4% на судоку с 0.81M параметров (vs 89.5% у fine-tuned DiffusionGemma), улучшения на MATH500, MBPP и ARC-Challenge, и применимость к image editing — модальность, недоступная текстовым диффузионным моделям.

CoT думает, продолжая. RM думает, исправляя.

← Все записи