Дистилляция табличных моделей: как сжать ИИ-врача в 26 раз без потери точности
Больница хочет предсказать риск повторной госпитализации по табличным данным электронной карты — возраст, анализы, диагнозы. TabPFNv2.6 даёт AUC 0.870, но для каждого пациента требует запуска нейросети на GPU с контекстной выборкой обучающих примеров. Аппаратное обеспечение стоит десятки тысяч долларов, латентность — полсекунды, а в реанимации каждая миллисекунда на счету. Что если от 90% этой точности можно получить, запуская обычный LightGBM на CPU за 7 мс?
Что такое табличные foundation-модели
Табличные foundation-модели (TFM) — это нейросети, которые обучаются на огромном корпусе табличных датасетов и затем адаптируются к новой задаче без дообучения весов, просто подавая обучающие примеры в контекст. TabPFN, TabICL, TabDPT, LimiX и Orion-MSP показывают на малых выборках (до 10 тысяч строк) результаты, сопоставимые или лучше тщательно настроенного XGBoost — и делают это из коробки, без перебора гиперпараметров.
Проблема в том, как они это делают. TFM на инференсе условно запускают forward-pass через большую трансформерную сеть, причём обучающие данные становятся частью входа. Это означает GPU, высокое энергопотребление, сложность интеграции в существующие конвейеры и невозможность работы в air-gapped средах больниц с жёсткими требованиями к хранению данных.
Почему простое копирование не работает
Первое, что приходит в голову — distillation: взять предсказания TFM, обучить на них компактную модель. Но TFM устроены иначе, чем обычные teacher-модели. Они условятся на обучающей выборке прямо во время инференса, поэтому если скормить teacher'у те же данные, на которых потом обучается student, происходит утечка контекста — teacher «видит» правильные ответы и выдаёт завышенно уверенные soft labels. Результат: student переобучается, а на реальных данных точность падает.
Авторы работы решают это через stratified out-of-fold labeling. Данные разбиваются на K фолдов. Для каждого фолда teacher обучается на остальных K−1 фолдах и генерирует soft labels только для «своего» фолда. Так teacher никогда не видит таргетов тех примеров, для которых выдаёт предсказания, и утечка контекста исчезает.
Как устроена дистилляция под капотом
Student обучается на смеси двух потерь: KL-дивергенции между soft labels teacher'а и предсказаниями student'а, а также кросс-энтропии с истинными hard labels. Вес α=0.7 отдаёт приоритет soft labels. Ключевой трюк — адаптивная температура: для каждого примера teacher вычисляет энтропию своего предсказания, и уверенные примеры получают низкую температуру (T≈1), а неуверенные — высокую (T≈5). Это позволяет student'у лучше учиться на «сложных» примерах, где teacher сам сомневается. Для tree-based student'ов soft loss сводится к взвешенному MSE на логитах soft labels, что делает обучение стабильным и быстрым.
Дополнительно вводится confidence weighting: примеры с промежуточной энтропией получают больший вес, чем крайне уверенные или крайне неуверенные. Это стабилизирует обучение и предотвращает доминирование лёгких примеров. Вместе эти техники — out-of-fold labeling, adaptive temperature и confidence weighting — образуют метод, который authors называют leakage-aware distillation: дистилляция, защищённая от утечки контекста.
Числа: что получилось на 19 медицинских датасетах
Эксперименты проводились на 19 датасетах из кардиологии, онкологии, нефрологии и реаниматологии — от 299 до 9105 строк, от 4 до 40 признаков. Рассматривались 6 TFM-teacher'ов (TabPFNv2.5, TabPFNv2.6, TabICLv2, TabDPT, LimiX, Orion-MSP v1.5) и 4 семейства student'ов: LightGBM, CatBoost, XGBoost и MLP. Каждый результат усреднён по 5 прогонам с фиксированным препроцессингом через библиотеку TabTune, чтобы исключить влияние разных пайплайнов обработки данных.
Результаты по точности (AUC) оказались student-зависимыми. LightGBM удерживает от 98.8% до 99.6% AUC teacher'а, причём в некоторых случаях distilled student превосходит teacher — мягкие метки работают как регуляризатор. CatBoost и XGBoost показывают схожую картину, иногда достигая 100% retention. MLP же отстаёт: 89.6–90.8% retention, и при этом остаётся ниже baseline LightGBM без дистилляции.
Почему деревья побеждают? Табличные данные в здравоохранении гетерогенны: возраст в годах, уровень глюкозы в ммоль/л, бинарные флаги диагнозов — всё вперемешку. Деревья естественно справляются с разными масштабами без предобработки, а MLP требует тщательной регуляризации, которой на малых выборках недостаточно. На трёх датасетах с менее чем 750 примерами и 4–8 признаками MLP показал высокую вариативность между запусками и иногда коллапсировал к случайным предсказаниям, несмотря на встроенный детектор коллапса с перезапуском при повышенном dropout.
Скорость и память: от GPU к CPU
Самый быстрый TFM-teacher, TabICLv2, даёт латентность 187 мс на GPU и пропускную способность 3221 предсказания в секунду. Самый медленный, TabDPT, требует 4 секунды на одно предсказание. Distilled LightGBM на CPU работает за 7 мс — в 26 раз быстрее TabICLv2 и в 570 раз быстрее TabDPT. MLP-student достигает 3.7–3.8 мс, что даёт ускорение до 49 раз относительно TabICLv2 и пропускную способность до 172 тысяч предсказаний в секунду на одном CPU-ядре.
Это меняет экономику развёртывания. Batch-скоринг миллионов записей, потоковая обработка данных с мониторов ICU, интеграция в существующие EHR-системы — всё это становится возможным без покупки GPU-кластеров и без выхода данных за периметр больницы.
Калибровка и справедливость: не только точность
В медицине точность — не единственный критерий. Врачу нужны откалиброванные вероятности: если модель говорит «риск 80%», то примерно в 8 случаях из 10 пациент действительно должен иметь негативный исход. И модель не должна систематически занижать риск для одних демографических групп и завышать для других. Модель с ECE 0.090, как baseline LightGBM, систематически переоценивает свою уверенность — врач, полагающийся на такие вероятности, принимает неверные решения о госпитализации или назначении терапии.
Distilled LightGBM сохраняет калибровку teacher'ов: ECE (Expected Calibration Error) около 0.058–0.063, что существенно лучше baseline LightGBM (0.090). Temperature scaling даёт лишь небольшой дополнительный прирост, то есть distilled student уже хорошо откалиброван из коробки. MLP же показывает ECE выше 0.12 до temperature scaling и требует пост-обработки. Интересно, что baseline GBDT-модели (XGBoost, LightGBM) показывают наибольшее абсолютное снижение ECE от temperature scaling (0.090 → 0.069), что согласуется с известной склонностью деревьев к переуверенности в листьях. Любое развёртывание, использующее сырые GBDT-оценки как вероятности риска, нуждается в калибровке.
По справедливости картина двойственная. Distilled LightGBM уменьшает demographic parity gap в среднем на 0.013 и equality of opportunity gap на 0.014 — soft labels в меньшем гипотезном пространстве работают как регуляризатор справедливости. MLP уменьшает demographic parity сильнее (на 0.034), но при этом увеличивает equality of opportunity gap на 0.041. Для практики это означает: если нужна стабильная справедливость по обоим метрикам, выбирайте LightGBM. При этом authors предупреждают, что EO-регрессия у MLP — реальная проблема, и аудит справедливости нужно проводить на реальных чувствительных атрибутах конкретного развёртывания, а не экстраполировать из таблицы.
Мульти-teacher: усреднение не помогает
Логично было бы предположить, что усреднение soft labels от нескольких TFM даст student'у больше информации, чем один teacher. Эксперименты показывают обратное: лучший multi-teacher LGBM не превосходит лучшего single-teacher даже на третьем знаке после запятой. Причина в том, что TFM-teacher'ы близки по точности (AUC 0.864–0.875), поэтому потенциал ансамбля мал. Равновесное усреднение при этом трактует расхождения teacher'ов как полезную информацию, даже когда они ближе к шуму. Взвешенная по точности схема могла бы дать прирост, но это оставлено для будущих работ.
Ограничения и когда метод не сработает
MLP-student нестабилен на малых низкоразмерных датасетах (менее 750 примеров, 4–8 признаков). На трёх таких датасетах он показал высокую вариативность между запусками и иногда коллапсировал к случайным предсказаниям. Рекомендация авторов: использовать LGBM по умолчанию, а MLP — только когда латентность менее 2 мс является жёстким требованием.
Мультиклассовая дистилляция пока экспериментальна. На датасете с тремя классами TabPFNv2.5 даёт AUC 0.985, а distilled LGBM — всего 0.749. Проблема в том, что soft-label distillation с LGBM-регрессором ломается при более чем двух классах. Правильный фикс — мультиномиальная формулировка потерь, но она ещё не реализована.
Часто задаваемые вопросы
Почему не использовать сразу XGBoost или LightGBM без дистилляции?
Baseline LightGBM на этих датасетах даёт AUC 0.853 — хуже любого TFM. Дистилляция переносит в student не только предсказания, но и регуляризацию через мягкие метки, что позволяет distilled LGBM достигать 0.862–0.865 и иногда превосходить teacher. Это бесплатный прирост точности без перебора гиперпараметров.
Нужен ли GPU для обучения distilled модели?
Для генерации soft labels teacher'ом — да, GPU нужен. Но это одноразовый офлайн-этап. После того как distilled student обучен, он работает на CPU и не требует ни GPU, ни облачной инфраструктуры. Для больницы это означает: заплатить за GPU-время один раз при обучении, а затем работать на существующих серверах.
Сохраняется ли приватность данных при дистилляции?
Да, приватность сохраняется лучше, чем при прямом использовании TFM. Teacher обрабатывает данные внутри периметра, генерирует soft labels, и только эти вероятности (а не сырые данные) используются для обучения student'а. При этом authors отмечают риск bias inheritance: если teacher содержит демографические искажения, student их унаследует. Рекомендуется аудит справедливости на реальных чувствительных атрибутах конкретного развёртывания.
Итог
Работа показывает, что высокая стоимость инференса TFM — это не фундаментальное ограничение их точности, а следствие способа инференса: контекстно-зависимых GPU forward-pass'ов. Out-of-fold дистилляция переносит эту стоимость с этапа эксплуатации на этап обучения, сохраняя 90–99% точности, калибровку и справедливость. Для медицинских систем, где критичны латентность, стоимость инфраструктуры и приватность данных, это открывает практический путь от дорогих foundation-моделей к лёгким моделям, которые работают на любом CPU. Практическая рекомендация авторов: использовать LGBM-student по умолчанию, а MLP — только когда sub-2 мс латентность является жёстким требованием.