Реализация нейронного кодировщика Transformer с нуля с использованием JAX и Haiku 🤖
Создание нейронного кодировщика Transformer с нуля с применением JAX и Haiku 🤖
Понимание фундаментальных строительных блоков трансформеров.
Представленная в 2017 году в статье “Вниманием надо обойтись” [0], архитектура Трансформера, наверное, является одним из самых значимых прорывов в недавней истории глубокого обучения, позволяющим возникновение больших языковых моделей и находящим применение даже в таких областях, как компьютерное зрение.
Последующие лидирующие архитектуры, основанные на рекурренции, такие как сети долгой краткосрочной памяти (LSTM) или блоки с затворами (GRU), Трансформеры вводят понятие само-внимания, совмещенное с архитектурой кодера/декодера.
В этой статье мы пошагово и с нуля реализуем первую половину Трансформера – кодер. Мы будем использовать JAX в качестве основного фреймворка вместе с Haiku, одной из библиотек глубокого обучения DeepMind.
Если вы не знакомы с JAX или вам нужно освежить память о его удивительных функциональностях, я уже рассматривал эту тему в контексте обучения с подкреплением в моей предыдущей статье:
- Руководство по расширенным настройкам ChatGPT – Топ П, Штрафы за частоту, Температура и многое другое
- История открытых LLM-программ начальные годы (Часть первая)
- «Создание Rest API с помощью Go анализ данных для временных рядов»
Векторизация и параллелизация сред RL с помощью JAX: Q-обучение на скорости света⚡
Научитесь векторизировать среду GridWorld и обучать 30 Q-обучаемых агентов параллельно на CPU на 1.8 миллиона шагов в минуту…
towardsdatascience.com
Мы рассмотрим каждый из блоков, составляющих кодер, и научимся эффективно их реализовывать. В частности, контур этой статьи содержит:
- Слой вложения и позиционные кодировки
- Многоголовое внимание
- Соединения остатков и нормализация слоя
- Сети прямой обратной связи по позициям
Отказ от ответственности: эта статья не претендует на полное введение в эти понятия, так как мы сосредоточимся в первую очередь на реализации. При необходимости, пожалуйста, обратитесь к ресурсам в конце этого сообщения.
Как всегда, полностью прокомментированный код для этой статьи, а также иллюстрированные блокноты, доступны на GitHub, не стесняйтесь добавить этот репозиторий в избранное, если вам понравилась статья!
GitHub – RPegoud/jab: Коллекция реализованных на JAX основных моделей глубокого обучения
Коллекция реализованных на JAX основных моделей глубокого обучения – GitHub – RPegoud/jab: Коллекция…
github.com
Основные параметры
Прежде чем мы начнем, нам нужно определить несколько параметров, которые сыграют решающую роль в блоке кодера:
- Длина последовательности (seq_len): количество токенов или слов в последовательности.
- Измерение вложений (embed_dim): размерность вложений, другими словами, количество числовых значений, используемых для описания одного токена или слова.
- Размер пакета (batch_size): размер пакета входных данных, т.е. количество обрабатываемых последовательностей одновременно.
Входные последовательности в нашу модель кодировщика обычно имеют форму (batch_size
, seq_len
). В этой статье мы используем значения batch_size=32
и seq_len=10
, что означает, что наш кодировщик одновременно обрабатывает 32 последовательности из 10 слов.
Уделяя внимание форме наших данных на каждом этапе обработки, мы сможем лучше визуализировать и понять, как данные протекают в блоке кодировщика. Вот краткий обзор нашего кодировщика, начнем снизу с слоя векторного представления и позиционного кодирования:
![Представление блока кодировщика Transformer (сделано автором)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/1*Mfz2VwpV4_pmBgXlRxVSLw.jpeg)
Слой векторного представления и позиционное кодирование
Как уже упоминалось ранее, наша модель берет пакетные последовательности токенов в качестве входных данных. Генерация этих токенов может быть настолько простой, как сбор набора уникальных слов в наборе данных и присвоение каждому из них индекса. Затем мы бы выбрали 32 последовательности из 10 слов и заменили бы каждое слово его индексом в словаре. Эта процедура даст нам массив формы (batch_size
, seq_len
), как и ожидалось.
Теперь мы готовы приступить к нашему кодировщику. Первый шаг – создать «позиционные вложения» для наших последовательностей. Позиционные вложения – это сумма векторных представлений слов и позиционного кодирования.
Векторные представления слов
Векторные представления слов позволяют нам кодировать значение и семантические связи между словами в нашем словаре. В этой статье размерность вложений фиксирована и составляет 64. Это означает, что каждое слово представлено вектором размерности 64, чтобы слова с похожими значениями имели похожие координаты. Более того, мы можем изменять эти векторы для извлечения связей между словами, как показано ниже.
Используя Haiku, создание обучаемых вложений так же просто, как вызов:
hk.Embed(vocab_size, embed_dim)
Эти вложения будут обновляться вместе с другими обучаемыми параметрами во время обучения модели (подробнее об этом чуть позже).
Позиционное кодирование
В отличие от рекуррентных нейронных сетей, трансформеры не могут вывести позицию токена, исходя из общего скрытого состояния, так как у них отсутствуют рекуррентные или сверточные структуры. Поэтому введены позиционные кодирования, векторы, которые передают позицию токена в входной последовательности.
По сути, каждому токену присваивается позиционный вектор, состоящий из чередующихся значений синуса и косинуса. Эти векторы имеют ту же размерность, что и векторные представления слов, чтобы их можно было сложить.
В частности, в оригинальной статье о Transformer используются следующие функции:
![Функции позиционного кодирования (перепечатка из «Attention is all you need», Васвани и др., 2017)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/0*wUaDUdiDqfE48-EZ.png)
Ниже представлены данные, которые помогают нам лучше понять работу позиционных кодировок. Давайте посмотрим на первую строку верхнего графика, мы видим чередующиеся последовательности нулей и единиц. Действительно, строки представляют позицию токена в последовательности (переменная pos
), а столбцы представляют размерность вложения (переменная i
).
Поэтому, когда pos=0
, предыдущие уравнения возвращают sin(0)=0
для четных размерностей вложения и cos(0)=1
для нечетных размерностей.
Кроме того, мы видим, что соседние строки имеют схожие значения, тогда как первая и последняя строки значительно отличаются. Это свойство полезно для модели для оценки расстояния между словами в последовательности, а также их порядка.
Наконец, третий график представляет сумму позиционных кодировок и вложений, которая является результатом блока вложения.
Представление вложений слов и позиционных кодировок, с seq_len=16 и embed_dim=64 (сделано автором)
Используя Haiku, мы определяем слой вложений следующим образом. Подобно другим фреймворкам глубокого обучения, Haiku позволяет нам определить пользовательские модули (здесь hk.Module
) для хранения обучаемых параметров и определения поведения компонентов нашей модели.
Каждому модулю Haiku нужно иметь функцию __init__
и __call__
. Здесь функция call
просто вычисляет вложения с помощью функции hk.Embed
и позиционных кодировок, а затем суммирует их.
Функция позиционного кодирования использует функции JAX, такие как vmap
и lax.cond
для оптимальной производительности. Если вы не знакомы с этими функциями, не стесняйтесь посмотреть мою предыдущую публикацию, где они подробно рассмотрены.
Простыми словами, vmap
позволяет нам определить функцию для одного образца и векторизовать ее, чтобы она могла быть применена к пакетам данных. Параметр in_axes
используется, чтобы указать, что мы хотим итерироваться по первой оси ввода dim
, которая является размерностью вложения. С другой стороны, lax.cond
– это совместимая с XLA версия оператора if/else в Python.
Внимание на себя и многоголовое внимание
Внимание направлено на вычисление важности каждого слова в последовательности, относительно входного слова. Например, в предложении:
“Черный кот прыгнул на диван, улегся и заснул, так как стал уставшим”.
Слово “он” может быть довольно неоднозначным для модели, так как технически оно может относиться как к “коту”, так и к “дивану”. Хорошо обученная модель внимания смогла бы понять, что “он” относится к “коту“, и соответствующим образом присвоила бы значения внимания остальной части предложения.
По сути, значения внимания можно рассматривать как веса, которые описывают важность определенного слова с учетом контекста входного слова. Например, вектор внимания для слова “прыгнул” будет иметь высокие значения для слов, таких как “кот” (что прыгнуло?), “на” и “диван” (куда прыгнуло?) , поскольку эти слова являются существенными для его контекста.
![Визуальное представление вектора внимания (сделано автором)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/1*ySDEvRPl9WfONWRVOey44Q.png)
В статье Transformer внимание вычисляется с использованием Dimensionality Scaled Dot-Product Attention, что можно представить формулой:
![Dimensionality Scaled Dot-Product Attention (воспроизведено из «Внимание - все, что вам нужно», Васвани и др., 2017)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/0*QEogF8q4SRV4uvl7.png)
Здесь Q, K и V обозначают Запросы, Ключи и Значения. Эти матрицы получаются путем умножения векторов весов WQ, WK и WV на позиционные вложения.
Эти имена в основном являются абстракциями, используемыми для понимания того, как информация обрабатывается и взвешивается в блоке внимания. Они отсылают к словарю систем извлечения (например, поиск видео на YouTube).
Вот интуитивное объяснение:
- Запросы: Их можно интерпретировать как “набор вопросов” о всех позициях в последовательности. Например, исследование контекста слова и попытка определить наиболее связанные части последовательности.
- Ключи: Их можно рассматривать как информацию, с которой взаимодействуют запросы. Совместимость между запросом и ключом определяет, насколько важно запросу обратить внимание на соответствующее значение.
- Значения: Сопоставление ключей и запросов позволяет нам решить, какие ключи являются значимыми, а значения – фактическим содержанием, сопоставленным с ключами.
На следующей иллюстрации запрос представляет собой поиск на YouTube, ключи – описания видео и метаданные, а значения – соответствующие видео.
![Интуитивное изображение концепции Запросы, Ключи, Значения (сделано автором)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/1*ubT1YkrPprthMq9pKGbroA.jpeg)
В нашем случае запросы, ключи и значения берутся из одного источника (поскольку они получены из входных последовательностей), отсюда и название самовнимание.
Вычисление оценок внимания обычно выполняется параллельно несколько раз, каждый раз с использованием части вложений. Этот механизм называется “Многоголовое внимание” и позволяет каждой голове параллельно изучать несколько различных представлений данных, что приводит к более устойчивой модели.
Одна голова внимания обычно обрабатывает массивы со формой (batch_size, seq_len, d_k
), где d_k
может быть установлено как отношение количества голов и размерности вложений (d_k = n_heads/embed_dim
). Таким образом, объединение результатов каждой головы удобно дает массив, имеющий форму (batch_size, seq_len, embed_dim
) в качестве ввода.
Вычисление матриц внимания можно разбить на несколько шагов:
- Во-первых, определяем обучаемые векторы весов WQ, WK и WV. Эти векторы имеют формы (
n_heads, embed_dim, d_k
). - Параллельно с этим, мы умножаем позиционные вложения на векторы весов. Мы получаем матрицы Q, K и V с формами (
batch_size, seq_len, d_k
). - Затем мы масштабируем скалярное произведение Q и K (транспонированное). Это масштабирование включает деление результата скалярного произведения на корень из
d_k
и применение функции softmax к строкам матриц. Таким образом, оценки внимания для входного токена (т.е. строки) суммируются до единицы, что помогает предотвратить возможное увеличение значений и замедление вычислений. Выход имеет форму (batch_size, seq_len, seq_len
) - Наконец, мы домножаем результат предыдущей операции на V, получая форму вывода (
batch_size, seq_len, d_k
).
![Визуальное представление операций матрицы внутри блока внимания (сделано автором)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/1*IrLV4_NhMZ8MtWeC9xt0fA.png)
- Выходы каждой головки внимания затем могут быть конкатенированы для формирования матрицы с формой (
batch_size, seq_len, embed_dim
). В статье Transformer также добавляется линейный слой в конце многоhead-блока внимания, чтобы агрегировать и комбинировать изученные представления от всех головок внимания.
![Конкатенация матрицы многоhead-внимания и линейного слоя (сделано автором)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/1*FVx4Hffl1lfZo1Lm2Tbf0w.png)
В Haiku многоhead-блок внимания можно реализовать следующим образом. Функция __call__
следует той же логике, что и граф выше, а методы класса используют утилиты JAX, такие как vmap
(для векторизации операций над разными головками внимания и матрицами) и tree_map
(для применения матричного перемножения к весовым векторам).
Соединения остаточных и нормализация слоя
Как вы могли заметить на графе Transformer, блок многоhead-внимания и нейронная сеть с прямой связью за ними следуют остаточными соединениями и нормализацией слоя.
Остаточные или пропускные соединения
Остаточные соединения – это стандартное решение для решения проблемы исчезающего градиента, которая возникает, когда градиенты становятся слишком маленькими для эффективного обновления параметров модели.
Поскольку этот вопрос сам собой возникает в особенно глубоких архитектурах, остаточные соединения используются в различных сложных моделях, таких как ResNet (Kaiming et al, 2015) в области компьютерного зрения, AlphaZero (Silver et al, 2017) в обучении с подкреплением, и, конечно же, Трансформеры.
На практике остаточные соединения просто перенаправляют вывод определенного слоя к следующему, пропуская один или несколько слоев на пути. Например, остаточное соединение вокруг многоhead-внимания эквивалентно суммированию вывода многоhead-внимания с позиционными вложениями.
Это позволяет градиентам более эффективно протекать через архитектуру во время обратного распространения и обычно может привести к более быстрой сходимости и более стабильному обучению.
![Представление остаточных соединений в Трансформерах (сделано автором)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/1*_BL1Q4sT5kzmBQ7wBl7DnQ.jpeg)
Нормализация слоя
Нормализация слоя помогает гарантировать, что значения, передаваемые через модель, не “взрываются” (стремятся к бесконечности), что может легко происходить в блоках внимания, где каждый прямой проход включает в себя умножение нескольких матриц.
В отличие от нормализации по пакетам, которая нормализует по пакетному измерению, предполагая равномерное распределение, нормализация слоя работает по признакам. Такой подход подходит для пакетов предложений, в которых каждое предложение может иметь уникальные распределения из-за различных значений и словаря.
Нормализуя по признакам, таким как вложения или значения внимания, нормализация слоя стандартизирует данные до единообразного масштаба без смешения отдельных характеристик предложений при сохранении уникального распределения каждого.
![Изображение слоя нормализации в контексте трансформеров (сделано автором)](https://ai.miximages.com/miro.medium.com/v2/resize:fit:640/format:webp/1*NlJu3E6z-fZLExXGvTzcnA.jpeg)
Реализация слоя нормализации довольно проста, мы инициализируем обучаемые параметры альфа и бета и выполняем нормализацию вдоль выбранной оси признаков.
Сеть прямого распространения позиций
Последний компонент кодировщика, о котором нам нужно рассказать, это сеть прямого распространения позиций. Эта полностью связанная сеть берет нормализованные выходы блока внимания в качестве входных данных и используется для введения нелинейности и увеличения ёмкости модели для обучения сложных функций.
Она состоит из двух плотных слоев, разделенных активацией gelu:
После этого блока у нас есть еще одно соединение остатков и нормализация слоя для завершения кодировщика.
Итоги
Вот и всё! Теперь вы должны быть знакомы с основными концепциями кодировщика трансформера. Вот полный код класса кодировщика, обратите внимание, что в Haiku мы присваиваем имя каждому слою, чтобы обучаемые параметры были разделены и легко доступны. Функция __call__
является хорошим резюме различных этапов нашего кодировщика:
Чтобы использовать этот модуль на реальных данных, мы должны применить hk.transform
к функции, инкапсулирующей класс кодировщика. Действительно, вы, возможно, помните, что JAX принимает парадигму функционального программирования, поэтому Haiku следует тем же принципам.
Мы определяем функцию, содержащую экземпляр класса кодировщика, и возвращаем результат прямого прохода. Применение hk.transform
возвращает преобразованный объект со следующими функциями: init
и apply
.
Первая позволяет нам инициализировать модуль с помощью случайного ключа и некоторых фиктивных данных (обратите внимание, что здесь мы передаем массив нулей с формой batch_size, seq_len
), тогда как вторая позволяет обрабатывать реальные данные.
# Примечание: следующие две синтаксические конструкции эквивалентны# 1: Использование transform как декоратора класса@hk.transformdef encoder(x): ... return model(x) encoder.init(...)encoder.apply(...)# 2: Применение transfom отдельноdef encoder(x): ... return model(x)encoder_fn = hk.transform(encoder)encoder_fn.init(...)encoder_fn.apply(...)
В следующей статье мы завершим трансформер архитектурой, добавив декодер, который повторно использует большинство блоков, которые мы уже ввели, и узнаем, как обучать модель для конкретной задачи с использованием Optax!
Заключение
Благодарим вас за чтение до конца, если вас интересует экспериментирование с кодом, вы можете найти его полностью прокомментированным на GitHub, а также с дополнительными подробностями и обзором с использованием игрушечного набора данных.
GitHub – RPegoud/jab: Коллекция изначальных моделей глубокого обучения, реализованных в JAX
Коллекция изначальных моделей глубокого обучения, реализованных в JAX – GitHub – RPegoud/jab: Коллекция…
github.com
Если вы хотите погрузиться глубже в мир трансформеров, в следующем разделе представлены несколько статей, которые помогли мне написать эту статью.
До следующего раза 👋
Ссылки и ресурсы:
[1] Вся ваша потребность в внимании (2017), Васвани и др., Google
[2] Что же представляют собой ключи, запросы и значения в механизмах внимания? (2019) Stack Exchange
[3] Иллюстрированный Трансформер (2018), Джей Аламмар
[4] Нежное введение в позиционное кодирование в моделях Трансформера (2023), Мехрин Саид, Machine Learning Mastery
Источники изображений
- Word embeddings, developers.google.com
- Картинка кота, Карстен Вайнгиарт, Unsplash
- Ландшафт Норвегии, Паскаль Дебруннер, Unsplash
- Картинка собаки, Лоан, Unsplash