Трансформеры оказали революционное влияние на развитие искусственного интеллекта, предложив новый подход к обработке последовательностей через механизм многоголового внимания. Принцип действия трансформера базируется на способности одновременно учитывать информацию из разных позиций входной последовательности, что обеспечивает значительный рост точности и гибкости моделей. Однако при всей своей мощности стандартный механизм внимания сталкивается с серьезной проблемой – его вычислительная и памятьемкая сложность растут квадратично с увеличением длины последовательности. То есть удвоение длины входных данных приводит к увеличению требуемого объема ресурсов в четыре раза, что существенно ограничивает возможности обработки длинных документов, высокоразрешенных изображений и прочих объемных данных. Для решения этой задачи исследователи применяли различные техники, в том числе разреженное и низкоранговое внимание, которые позволяли уменьшить нагрузку, но зачастую за счет потери точности и ограниченных приростов в скорости.
Революционным шагом стало появление Flash Attention в 2022 году, разработанного Три Дао и его командой. В отличие от наивного подхода, при котором создается огромная матрица оценок внимания, Flash Attention использует стратегию, оптимизированную с точки зрения ввода-вывода. Вместо того чтобы постоянно записывать и читать большие объемы данных из основной памяти GPU, алгоритм структурирует вычисления так, чтобы минимизировать операции чтения/записи и максимально задействовать быструю SRAM-кэш-память на чипе. Это достигается за счет разбиения вычислений на плитки и слияния нескольких операций в единый шаг. В результате Flash Attention обеспечивает ускорение в 2-4 раза и сокращение использования памяти в 10-20 раз по сравнению с традиционным подходом, при этом сохраняя полную точность и не прибегая к приближениям.
Данные преимущества не только снизили время обучения, но и позволили трансформерам работать с ранее непосильной длиной контекста, достигающей 16 тысяч и даже 64 тысяч токенов. Например, в бенчмарке Path-X, основанном на 16-тысячном входе, стандартные модели трансформеров показывали лишь случайную точность, тогда как благодаря Flash Attention удалось преодолеть этот порог и добиться значительного улучшения результата. С тех пор прогресс не остановился: появились преемники – Flash Attention 2 и Flash Attention 3, оптимизированные для новых архитектур GPU, таких как NVIDIA Hopper. Вторая версия удвоила скорость по сравнению с оригиналом, открыв новые горизонты для масштабирования моделей с миллионами токенов контекста. С открытым исходным кодом и значимыми преимуществами Flash Attention быстро получил поддержку в экосистемах машинного обучения.
Популярная библиотека PyTorch с версии 2.2 интегрировала нативную поддержку Flash Attention 2, автоматически используя ускоренную операцию scaled_dot_product_attention при работе на GPU NVIDIA. С выпуском PyTorch 2.3 эта функциональность была расширена и на ROCm-бэкенд, тем самым предоставив высокоскоростное и эффективное внимание на оборудовании AMD. Поначалу Flash Attention был реализован в виде CUDA-кернела, что делало его эксклюзивным для NVIDIA.
Однако вскоре он появился и в ROCm-экосистеме AMD. Сегодня на базе AMD Instinct MI300X доступны сразу несколько вариантов реализации вычисления scaled dot product attention. Среди них выделяются ROCm/Flash Attention 2 с Composable Kernel (CK) и Triton-бэкендами, причем Triton одна из наиболее полнофункциональных реализаций, включая поддержку низкой точности FP8. Кроме того, PyTorch 2.6.
0 предлагает FlexAttention, а ROCm TransformerEngine предоставляет собственный механизм внимания через te.DotProductAttention. Все эти варианты позволяют разработчикам выбирать оптимальный баланс скорости, качества и простоты интеграции. Для сравнения эффективности была проведена серия экспериментов на реальной задаче – обучении модели nanoGPT Карпафти на AMD Instinct MI300X. В качестве тестовых условий использовалась длина блока 1024 с батч-сайзом 64, что обеспечивало загрузку памяти устройства и достаточный запас ресурса.
В общей сложности были протестированы семь вариантов механизмов внимания: Flash Attention 2 Triton в базовой конфигурации, с автотьюнингом и в FP8 варианте, Flex Attention из PyTorch, Transformer Engine Triton, традиционный наивный расчет и интегрированная в PyTorch scaled_dot_product_attention (скорее всего с CK-бэкендом). При этом особое внимание уделялось не только производительности, но и влиянию на качество обучения, измеряемому потерями (loss). Результаты показали, что вне зависимости от варианта, большинство решений сохраняли сопоставимый уровень потерь, за исключением TransformerEngine, где наблюдалось заметное ухудшение. Это сразу ограничивает практическую полезность данной реализации. Что касается скорости, наиболее быстрым оказался Flash Attention 2 на Triton с FP8, значительно превосходящий наивный baseline.
Рядовые scaled_dot_product_attention и Flash Attention 2 Triton без FP8 также показывали хорошие результаты, при этом автотьюнинг добавлял преимущество по времени. Flex Attention и Transformer Engine Triton обеспечивали лишь незначительный прирост производительности, при этом последний ухудшал качество модели. Анализ потребления памяти подтвердил ожидаемую схему: наивный подход требует значительно больше видеопамяти, Transformer Engine расходует меньше, но при этом уступает более оптимизированным вариантам, которые демонстрируют примерно одинаковый и более эффективный расход памяти. Значительная экономия в VRAM при использовании Flash Attention важна для тренировки длинных контекстов и крупных моделей. Трассировка активности памяти HBM выявила, что менее эффективные реализации тратят много времени на операции ввода-вывода, тогда как оптимизированные ядра максимально используют кэш и вычислительные ресурсы GPU.
Итогом сравнений стал рейтинг вариантов с учетом производительности, простоты установки и полноты функционала. Flash Attention 2 Triton FP8 был признан лучшим решением благодаря высокой скорости, поддержке различных настроек, включая Generalized Query Attention (GQA) и ALiBi, а также удобству администрирования. Второе место заняли базовые и autotuned версии Flash Attention 2 Triton, которые также демонстрируют надежность и хорошую скорость. Далее следуют стандартные средства PyTorch scaled_dot_product_attention, подходящие для большинства сценариев, но не предоставляющие расширенные возможности по ALiBi. Flex Attention получил негативную оценку из-за минимального прироста и сложности использования, а Transformer Engine Triton – из-за ухудшения качества и слабой производительности.
Без сомнения, Flash Attention на ROCm открывает новые горизонты для обучения и внедрения трансформеров на AMD железе. Выдающаяся оптимизация, экономия памяти и высокая скорость формируют фундамент для дальнейшего масштабирования моделей и работы с экстремально длинными контекстами, что становится особенно актуально в эпоху больших языковых моделей с миллионами токенов. Несмотря на некоторые трудности с документацией и настройками FP8, поддержка AMD и активное сообщество помогают быстро решать возникающие вопросы. В целом, интеграция Flash Attention в ROCm и PyTorch предоставляет разработчикам мощный инструмент для эффективного обучения трансформеров вне зависимости от выбора аппаратного обеспечения, способствуя росту экосистемы открытого и доступного ИИ. Нельзя не отметить, что дальнейшее развитие технологий внимания и их адаптация под разные архитектуры GPU будут ключевыми для достижения новых высот в области искусственного интеллекта и машинного обучения.
Применение Flash Attention и его преемников на AMD платформе – яркий пример того, как инновационные алгоритмы и программные решения могут значительно повысить эффективность и доступность современных ИИ-систем.