Понимание Flash-Attention и Flash-Attention-2 Путь к увеличению длины контекста языковых моделей
Understanding Flash-Attention and Flash-Attention-2 Path to increasing the context length of language models
Два метода обеспечивают значительные улучшения при обработке более длинных текстовых последовательностей в LLMs.

Недавно я начал информационную рассылку, посвященную искусственному интеллекту, которая уже имеет более 160 000 подписчиков. TheSequence – это информационная рассылка, ориентированная на машинное обучение, которая занимает 5 минут для чтения. Цель – держать вас в курсе машинных проектов, научных статей и концепций. Пожалуйста, попробуйте подписаться ниже:
TheSequence | Jesus Rodriguez | Substack
Лучший источник, чтобы быть в курсе новостей из мира машинного обучения, искусственного интеллекта и данных…
thesequence.substack.com
Масштабирование контекста больших языковых моделей (LLMs) остается одной из самых больших проблем для расширения сферы применения. В последние месяцы мы видели, как компании, такие как Anthropic или OpenAI, увеличивают длину контекста своих моделей до новых высот. Этот тренд, скорее всего, будет продолжаться, но, вероятно, потребуются некоторые научные открытия. Одной из самых интересных работ в этой области была недавно опубликованная работа Стэнфордского университета. Под названием FlashAttention, эта новая техника быстро стала одним из основных механизмов увеличения контекста LLMs. Вторая итерация FlashAttention, FlashAttention-2, была недавно опубликована. В этой статье я хотел бы рассмотреть основы обеих версий.
FashAttention v1
FlashAttention является новаторским алгоритмом в сфере передовых алгоритмов. Этот алгоритм не только переупорядочивает вычисление внимания, но также использует классические техники, такие как тайлинг и повторное вычисление, чтобы достичь замечательного ускорения и существенного снижения использования памяти. Это преобразование позволяет перейти от квадратичного к линейному объему памяти относительно длины последовательности. В большинстве случаев FlashAttention работает неплохо, но есть одно ограничение – он не был настроен для особенно длинных последовательностей, где отсутствует параллелизм.
- Топ важных научных статей по компьютерному зрению на неделю с 4/9 по 10/9
- Автоматическая генерация музыки с использованием глубокого обучения
- От нуля до героя Создайте свою первую модель машинного обучения с помощью PyTorch
При решении задачи обучения больших трансформеров на расширенных последовательностях ключевую роль играют современные техники параллелизма, такие как параллелизм данных, параллелизм конвейеров и параллелизм тензоров. Эти подходы разделяют данные и модели между несколькими GPU, что может привести к крайне малым размерам пакета (например, размер пакета 1 с параллелизмом конвейеров) и скромному количеству головок, обычно в диапазоне от 8 до 12 с параллелизмом тензоров. Именно такую ситуацию FlashAttention и старается оптимизировать.
Для каждой головки внимания FlashAttention использует классические техники тайлинга, чтобы минимизировать чтение и запись памяти. Он перемещает блоки запросов, ключей и значений из основной памяти GPU (HBM) в его быструю кэш-память SRAM. После выполнения вычислений внимания с этим блоком он записывает результат обратно в HBM. Это снижение чтения/записи памяти приводит к существенному ускорению, часто в 2-4 раза по сравнению с исходной скоростью в большинстве случаев.
Первая версия FlashAttention использовала параллелизм по размеру пакета и количеству головок. Те, кто знаком с программированием на CUDA, оценят использование одного блока потоков для обработки каждой головки внимания, что дает в общей сложности batch_size * num_heads блоков потоков. Каждый блок потоков тщательно планируется для выполнения на мультипроцессоре потокового процессора (SM), с A100 GPU, которая имеет 108 этих SM. Это планирование действительно проявляется, когда batch_size * num_heads достигает значительных значений, скажем, больше или равно 80. В таких случаях оно позволяет эффективно использовать практически все вычислительные ресурсы GPU.

Однако, когда речь идет о работе с длинными последовательностями, обычно связанными с малым размером пакета или ограниченным количеством головок, FlashAttention использует другой подход. Теперь он вводит параллелизм по размеру последовательности, что приводит к замечательному ускорению, специально адаптированному для этой конкретной области.
Когда речь идет о обратном проходе, FlashAttention выбирает несколько измененную стратегию параллелизации. Каждый рабочий берет на себя блок столбцов в матрице внимания. Эти рабочие сотрудничают и обмениваются информацией, чтобы агрегировать градиенты, касающиеся запроса, используя атомарные операции. Интересно, что FlashAttention обнаружил, что параллелизация по столбцам работает лучше, чем параллелизация по строкам в этом контексте. Уменьшенное взаимодействие между рабочими оказывается ключевым, поскольку параллелизация по столбцам подразумевает агрегацию градиента запроса, в то время как параллелизация по строкам требует агрегации градиента ключа и значения.

FlashAttention-2
С FlashAttention-2 команда из Стэнфорда внесла тщательные улучшения в первую версию, сосредоточившись на минимизации не матричных операций с плавающей запятой внутри алгоритма. Это изменение имеет глубокое значение в эпоху современных графических процессоров, которые оснащены специализированными вычислительными блоками, такими как Tensor Cores от Nvidia, что существенно ускоряет умножение матриц (matmul).
FlashAttention-2 также пересматривает используемую им онлайн-технику softmax. Цель заключается в упрощении операций масштабирования, проверки границ и причинного маскирования, сохраняя при этом целостность вывода.
В своей первоначальной версии FlashAttention использовал параллелизм как по размеру пакета, так и по количеству голов. Здесь каждая голова внимания обрабатывалась отдельным блоком потоков, что давало в общей сложности (размер пакета * количество голов) блоков потоков. Эти блоки потоков эффективно планировались на мультипроцессоры передачи данных (SM), и у A100 GPU было 108 таких SM. Эта стратегия планирования оказалась наиболее эффективной, когда общее количество блоков потоков было значительным, обычно превышающим 80, так как это позволяло оптимально использовать вычислительные ресурсы GPU.
Для улучшения работы в сценариях с длинными последовательностями, часто сопровождающимися небольшими размерами пакета или ограниченным числом голов, FlashAttention-2 вводит дополнительное измерение параллелизма – параллелизм по длине последовательности. Эта стратегическая адаптация приводит к существенному увеличению скорости в этом конкретном контексте.
Даже внутри каждого блока потоков FlashAttention-2 должен правильно распределить рабочую нагрузку между различными волнами, представляющими группы из 32 потоков, работающих вместе. Обычно используется 4 или 8 волн на блок потоков, и схема разбиения разъясняется ниже. В FlashAttention-2 этот метод разбиения уточняется с целью уменьшения синхронизации и коммуникации между различными волнами, тем самым минимизируя операции чтения и записи общей памяти.

В предыдущей конфигурации FlashAttention разделял K и V на 4 волны, сохраняя доступность Q для всех волн, называемую схемой “sliced-K”. Однако такой подход показал неэффективность, так как все волны должны были записывать свои промежуточные результаты в общую память, синхронизироваться, а затем агрегировать эти результаты. Эти операции с общей памятью становились узким местом производительности FlashAttention в прямом проходе.
В FlashAttention-2 стратегия принимает другой ход. Теперь Q распределяется по 4 волнам, при этом K и V остаются доступными для всех волн. После того, как каждая волна выполняет умножение матрицы для получения среза Q K^T, они просто умножают его на общий срез V, чтобы получить свой собственный срез вывода. Такое расположение позволяет избежать взаимодействия между волнами. Уменьшение операций чтения/записи общей памяти приводит к значительному увеличению скорости.
Ранняя версия FlashAttention поддерживала размеры головок до 128, что достаточно для большинства моделей, но некоторые оставались в стороне. FlashAttention-2 расширяет поддержку для размеров головок до 256, что позволяет использовать модели, такие как GPT-J, CodeGen, CodeGen2 и StableDiffusion 1.x. Теперь эти модели могут использовать FlashAttention-2 для повышения скорости и эффективности памяти.
Кроме того, FlashAttention-2 вводит поддержку множественного запроса внимания (MQA) и группового запроса внимания (GQA). Это специализированные варианты внимания, при которых несколько головок запроса одновременно обращаются к одной головке ключа и значения. Эта стратегическая маневренность направлена на сокращение размера кэша KV во время вывода, что в конечном итоге приводит к значительно более высокой скорости вывода.
Улучшения
Команда из Стэнфорда оценила FlashAttention-2 на разных тестах с заметными улучшениями по сравнению с оригинальной версией и другими альтернативами. Тесты включали различные вариации архитектуры внимания, и результаты были довольно заметными.

FlashAttention и FlashAttention-2 являются двумя фундаментальными техниками, используемыми для масштабирования контекста LLMs. Это исследование представляет собой одно из самых значительных научных прорывов в этой области и оказывает влияние на новые методы, способствующие увеличению емкости LLMs.