Исследователи из Университета Стэнфорда представляют FlashFFTConv новую систему искусственного интеллекта для оптимизации преобразований FFT для длинных последовательностей.

Новая система искусственного интеллекта FlashFFTConv для оптимизации преобразований FFT длинных последовательностей исследование Университета Стэнфорда

Рассуждение эффективно на протяжении длинных последовательностей является основной трудностью в машинном обучении. Недавно свертки приобрели значение в качестве основного примитива для моделирования последовательностей, обеспечивая современную производительность в языковом моделировании, анализе временных рядов, компьютерном зрении, моделировании ДНК и других областях. Несмотря на эти впечатляющие результаты и дополнительные преимущества, такие как улучшенная стабильность и лучшая масштабируемость с увеличением длины последовательности, сверточные модели последовательностей все еще значительно медленнее, чем трансформеры.

Одна из основных причин – ненадежная аппаратная поддержка. Свертки для моделирования последовательностей часто используют фильтры, размер которых соответствует длине входной последовательности, в отличие от коротких фильтров, используемых в классических свертках для визуальных приложений. Алгоритм быстрого преобразования Фурье (FFT) для свертки вычисляет свертку между входом u и ядром свертки k, отображая входные и выходные частоты.

Несмотря на алгоритмическую эффективность, алгоритм быстрого преобразования Фурье имеет низкое время работы на современных ускорителях. Однако технологический прогресс в системах позволил трансформерам достичь пределов текущих ускорителей, с использованием до 72% общего количества операций FLOP при использовании FlashAttention-v2.

Для создания возможности работы с длинными последовательностями новое исследование из Стэнфордского университета исследует оптимизацию сверточного алгоритма быстрого преобразования Фурье на современных ускорителях. Исследователи считают, что как развитие систем, таких как FlashAttention привело к более хорошим моделям и новым алгоритмам внимания, оптимизация сверточного алгоритма быстрого преобразования Фурье приведет к новым и лучшим алгоритмам, повышающим качество сверточных моделей последовательности.

Алгоритм быстрого преобразования Фурье может быть легко оптимизирован для коротких последовательностей. Общепринятой практикой является повторное использование ядерных фильтров в нескольких пакетах, что позволяет предварительно вычислить преобразование Фурье фильтра перед его повторным использованием. Таким образом, сверточное преобразование Фурье осуществляется параллельно по пакетам и фильтрам, а слияние ядер позволяет кэшировать промежуточные выходы свертки в SRAM или регистрах.

  1. Однако команда отмечает, что с ростом длины последовательности возникают две основные проблемы. В отношении современных ускорителей свертки Фурье не оптимально используют специализированные блоки умножения матриц.
  2. Во-вторых, с увеличением длины последовательности слияние ядер не работает, так как последовательности становятся слишком длинными для помещения в SRAM, и требуются затратные операции ввода-вывода. Операции дополнения для внесения причинности и преобразования из реального входа/вывода в комплексные промежуточные значения FFT могут еще больше увеличить эти затраты ввода-вывода.

В ответ на это исследователи предлагают FlashFFTConv, новый алгоритм, который использует декомпозицию Монарха Фурье для оптимизации сверточного алгоритма быстрого преобразования Фурье для длинных последовательностей. Благодаря декомпозиции Монарха Фурье порядка p, FFT эффективно может передаваться на аппаратуре, переписывая Фурье в виде серии операций умножения матриц порядка p. Более высокие значения p обеспечивают меньшие затраты FLOP из-за более маленьких матриц, но требуют большего количества ввода-вывода для передачи промежуточных результатов. Следовательно, здесь есть компромисс.

Исследование демонстрирует, как оптимизировать p для затрат FLOP и затраты ввода-вывода на GPU с использованием простой модели затрат, основанной на длине последовательности. В дополнение к облегчению слияния ядер сверточных моделей при большей длине последовательности, такая декомпозиция также снижает объем последовательности, который должен быть сохранен в SRAM. Таким образом, FlashFFTConv может легко обрабатывать последовательности от 256 до 4 миллионов символов. Используя алгоритм быстрого преобразования Фурье для вещественных значений и пропуская части операций умножения матриц при нулевом заполнении входа, FlashFFTConv может уменьшить длину операции FFT в два раза. Наконец, матричный подход к сверточному преобразованию Фурье предоставляет простой интерфейс для реализации двух архитектурных модификаций: частичные свертки, которые обучаются с использованием ядра свертки, которое короче входной последовательности, и разреженные свертки по частоте, которые обнуляют определенные секции ядра в частотном пространстве. Оба подхода могут быть легко реализованы путем исключения частей матричной декомпозиции, что снижает объем памяти и время выполнения, и могут быть рассмотрены как сверточные аналоги разреженного/приближенного внимания в трансформерах.

Исследователи демонстрируют, что FlashFFTConv ускоряет сверточное преобразование Фурье, что приводит к более высокому качеству, более эффективным и более длинным моделям последовательности.

  • FlashFFTConv улучшает качество сверточных моделей последовательности за счет лучшей эффективности: при равном вычислительном бюджете FlashFFTConv позволяет Hyena-GPT-s достичь перплексии на 2,3 пункта лучше и позволяет M2-BERT-base достичь до 3,3 выше среднего значения GLUE-показателя – улучшение производительности, эквивалентное удвоению параметров модели.
  • FlashFFTConv улучшает эффективность сверток до 7,93 и до 5,60 в экономии памяти по сравнению с PyTorch, и эта эффективность сохраняется на протяжении четырех порядков длины последовательности. FlashFFTConv быстрее в терминах времени работы, чем FlashAttention-v2 на длинах последовательностей 2K и более благодаря меньшим затратам FLOP и достигает до 62,3% использования FLOP с начала и до конца, что на 10% меньше, чем FlashAttention-v2.
  • С FlashFFTConv возможны модели более длинных последовательностей. FlashFFTConv произвела единственную модель, способную выполнить задачу Path-512 в длинном бенчмарке арены (длина последовательности 256K) для классификации изображений высокого разрешения. FlashFFTConv – это первая модель, способная вставлять самые длинные гены человека (до 2,3 млн нуклеотидных пар) с разрешением отдельного нуклеотида; она расширяет HyenaDNA до длины последовательности 4M с помощью частичных сверток.

Команда надеется, что FlashFFTConv проложит путь к более широкому использованию сверточных моделей последовательностей, и что полученные уроки приведут к созданию более ресурсоэффективных компьютерных архитектур.