Реализация нейронного кодировщика Transformer с нуля с использованием JAX и Haiku 🤖

Создание нейронного кодировщика Transformer с нуля с применением JAX и Haiku 🤖

Понимание фундаментальных строительных блоков трансформеров.

Трансформеры в стиле Эдварда Хоппера (сгенерировано Dall.E 3)

Представленная в 2017 году в статье “Вниманием надо обойтись” [0], архитектура Трансформера, наверное, является одним из самых значимых прорывов в недавней истории глубокого обучения, позволяющим возникновение больших языковых моделей и находящим применение даже в таких областях, как компьютерное зрение.

Последующие лидирующие архитектуры, основанные на рекурренции, такие как сети долгой краткосрочной памяти (LSTM) или блоки с затворами (GRU), Трансформеры вводят понятие само-внимания, совмещенное с архитектурой кодера/декодера.

В этой статье мы пошагово и с нуля реализуем первую половину Трансформера – кодер. Мы будем использовать JAX в качестве основного фреймворка вместе с Haiku, одной из библиотек глубокого обучения DeepMind.

Если вы не знакомы с JAX или вам нужно освежить память о его удивительных функциональностях, я уже рассматривал эту тему в контексте обучения с подкреплением в моей предыдущей статье:

Векторизация и параллелизация сред RL с помощью JAX: Q-обучение на скорости света⚡

Научитесь векторизировать среду GridWorld и обучать 30 Q-обучаемых агентов параллельно на CPU на 1.8 миллиона шагов в минуту…

towardsdatascience.com

Мы рассмотрим каждый из блоков, составляющих кодер, и научимся эффективно их реализовывать. В частности, контур этой статьи содержит:

  • Слой вложения и позиционные кодировки
  • Многоголовое внимание
  • Соединения остатков и нормализация слоя
  • Сети прямой обратной связи по позициям

Отказ от ответственности: эта статья не претендует на полное введение в эти понятия, так как мы сосредоточимся в первую очередь на реализации. При необходимости, пожалуйста, обратитесь к ресурсам в конце этого сообщения.

Как всегда, полностью прокомментированный код для этой статьи, а также иллюстрированные блокноты, доступны на GitHub, не стесняйтесь добавить этот репозиторий в избранное, если вам понравилась статья!

GitHub – RPegoud/jab: Коллекция реализованных на JAX основных моделей глубокого обучения

Коллекция реализованных на JAX основных моделей глубокого обучения – GitHub – RPegoud/jab: Коллекция…

github.com

Основные параметры

Прежде чем мы начнем, нам нужно определить несколько параметров, которые сыграют решающую роль в блоке кодера:

  • Длина последовательности (seq_len): количество токенов или слов в последовательности.
  • Измерение вложений (embed_dim): размерность вложений, другими словами, количество числовых значений, используемых для описания одного токена или слова.
  • Размер пакета (batch_size): размер пакета входных данных, т.е. количество обрабатываемых последовательностей одновременно.

Входные последовательности в нашу модель кодировщика обычно имеют форму (batch_size, seq_len). В этой статье мы используем значения batch_size=32 и seq_len=10, что означает, что наш кодировщик одновременно обрабатывает 32 последовательности из 10 слов.

Уделяя внимание форме наших данных на каждом этапе обработки, мы сможем лучше визуализировать и понять, как данные протекают в блоке кодировщика. Вот краткий обзор нашего кодировщика, начнем снизу с слоя векторного представления и позиционного кодирования:

Представление блока кодировщика Transformer (сделано автором)

Слой векторного представления и позиционное кодирование

Как уже упоминалось ранее, наша модель берет пакетные последовательности токенов в качестве входных данных. Генерация этих токенов может быть настолько простой, как сбор набора уникальных слов в наборе данных и присвоение каждому из них индекса. Затем мы бы выбрали 32 последовательности из 10 слов и заменили бы каждое слово его индексом в словаре. Эта процедура даст нам массив формы (batch_size, seq_len), как и ожидалось.

Теперь мы готовы приступить к нашему кодировщику. Первый шаг – создать «позиционные вложения» для наших последовательностей. Позиционные вложения – это сумма векторных представлений слов и позиционного кодирования.

Векторные представления слов

Векторные представления слов позволяют нам кодировать значение и семантические связи между словами в нашем словаре. В этой статье размерность вложений фиксирована и составляет 64. Это означает, что каждое слово представлено вектором размерности 64, чтобы слова с похожими значениями имели похожие координаты. Более того, мы можем изменять эти векторы для извлечения связей между словами, как показано ниже.

Пример аналогий, полученных из векторных представлений слов (изображение с developers.google.com)

Используя Haiku, создание обучаемых вложений так же просто, как вызов:

hk.Embed(vocab_size, embed_dim)

Эти вложения будут обновляться вместе с другими обучаемыми параметрами во время обучения модели (подробнее об этом чуть позже).

Позиционное кодирование

В отличие от рекуррентных нейронных сетей, трансформеры не могут вывести позицию токена, исходя из общего скрытого состояния, так как у них отсутствуют рекуррентные или сверточные структуры. Поэтому введены позиционные кодирования, векторы, которые передают позицию токена в входной последовательности.

По сути, каждому токену присваивается позиционный вектор, состоящий из чередующихся значений синуса и косинуса. Эти векторы имеют ту же размерность, что и векторные представления слов, чтобы их можно было сложить.

В частности, в оригинальной статье о Transformer используются следующие функции:

Функции позиционного кодирования (перепечатка из «Attention is all you need», Васвани и др., 2017)

Ниже представлены данные, которые помогают нам лучше понять работу позиционных кодировок. Давайте посмотрим на первую строку верхнего графика, мы видим чередующиеся последовательности нулей и единиц. Действительно, строки представляют позицию токена в последовательности (переменная pos), а столбцы представляют размерность вложения (переменная i).

Поэтому, когда pos=0, предыдущие уравнения возвращают sin(0)=0 для четных размерностей вложения и cos(0)=1 для нечетных размерностей.

Кроме того, мы видим, что соседние строки имеют схожие значения, тогда как первая и последняя строки значительно отличаются. Это свойство полезно для модели для оценки расстояния между словами в последовательности, а также их порядка.

Наконец, третий график представляет сумму позиционных кодировок и вложений, которая является результатом блока вложения.

Представление вложений слов и позиционных кодировок, с seq_len=16 и embed_dim=64 (сделано автором)

Используя Haiku, мы определяем слой вложений следующим образом. Подобно другим фреймворкам глубокого обучения, Haiku позволяет нам определить пользовательские модули (здесь hk.Module) для хранения обучаемых параметров и определения поведения компонентов нашей модели.

Каждому модулю Haiku нужно иметь функцию __init__и __call__. Здесь функция call просто вычисляет вложения с помощью функции hk.Embed и позиционных кодировок, а затем суммирует их.

Функция позиционного кодирования использует функции JAX, такие как vmapи lax.condдля оптимальной производительности. Если вы не знакомы с этими функциями, не стесняйтесь посмотреть мою предыдущую публикацию, где они подробно рассмотрены.

Простыми словами, vmapпозволяет нам определить функцию для одного образца и векторизовать ее, чтобы она могла быть применена к пакетам данных. Параметр in_axes используется, чтобы указать, что мы хотим итерироваться по первой оси ввода dim, которая является размерностью вложения. С другой стороны, lax.cond – это совместимая с XLA версия оператора if/else в Python.

Внимание на себя и многоголовое внимание

Внимание направлено на вычисление важности каждого слова в последовательности, относительно входного слова. Например, в предложении:

“Черный кот прыгнул на диван, улегся и заснул, так как стал уставшим”.

Слово “он” может быть довольно неоднозначным для модели, так как технически оно может относиться как к “коту”, так и к “дивану”. Хорошо обученная модель внимания смогла бы понять, что “он” относится к “коту“, и соответствующим образом присвоила бы значения внимания остальной части предложения.

По сути, значения внимания можно рассматривать как веса, которые описывают важность определенного слова с учетом контекста входного слова. Например, вектор внимания для слова “прыгнул” будет иметь высокие значения для слов, таких как “кот” (что прыгнуло?), “на” и “диван” (куда прыгнуло?) , поскольку эти слова являются существенными для его контекста.

Визуальное представление вектора внимания (сделано автором)

В статье Transformer внимание вычисляется с использованием Dimensionality Scaled Dot-Product Attention, что можно представить формулой:

Dimensionality Scaled Dot-Product Attention (воспроизведено из «Внимание - все, что вам нужно», Васвани и др., 2017)

Здесь Q, K и V обозначают Запросы, Ключи и Значения. Эти матрицы получаются путем умножения векторов весов WQ, WK и WV на позиционные вложения.

Эти имена в основном являются абстракциями, используемыми для понимания того, как информация обрабатывается и взвешивается в блоке внимания. Они отсылают к словарю систем извлечения (например, поиск видео на YouTube).

Вот интуитивное объяснение:

  • Запросы: Их можно интерпретировать как “набор вопросов” о всех позициях в последовательности. Например, исследование контекста слова и попытка определить наиболее связанные части последовательности.
  • Ключи: Их можно рассматривать как информацию, с которой взаимодействуют запросы. Совместимость между запросом и ключом определяет, насколько важно запросу обратить внимание на соответствующее значение.
  • Значения: Сопоставление ключей и запросов позволяет нам решить, какие ключи являются значимыми, а значения – фактическим содержанием, сопоставленным с ключами.

На следующей иллюстрации запрос представляет собой поиск на YouTube, ключи – описания видео и метаданные, а значения – соответствующие видео.

Интуитивное изображение концепции Запросы, Ключи, Значения (сделано автором)

В нашем случае запросы, ключи и значения берутся из одного источника (поскольку они получены из входных последовательностей), отсюда и название самовнимание.

Вычисление оценок внимания обычно выполняется параллельно несколько раз, каждый раз с использованием части вложений. Этот механизм называется “Многоголовое внимание” и позволяет каждой голове параллельно изучать несколько различных представлений данных, что приводит к более устойчивой модели.

Одна голова внимания обычно обрабатывает массивы со формой (batch_size, seq_len, d_k), где d_k может быть установлено как отношение количества голов и размерности вложений (d_k = n_heads/embed_dim). Таким образом, объединение результатов каждой головы удобно дает массив, имеющий форму (batch_size, seq_len, embed_dim) в качестве ввода.

Вычисление матриц внимания можно разбить на несколько шагов:

  • Во-первых, определяем обучаемые векторы весов WQ, WK и WV. Эти векторы имеют формы (n_heads, embed_dim, d_k).
  • Параллельно с этим, мы умножаем позиционные вложения на векторы весов. Мы получаем матрицы Q, K и V с формами (batch_size, seq_len, d_k).
  • Затем мы масштабируем скалярное произведение Q и K (транспонированное). Это масштабирование включает деление результата скалярного произведения на корень из d_k и применение функции softmax к строкам матриц. Таким образом, оценки внимания для входного токена (т.е. строки) суммируются до единицы, что помогает предотвратить возможное увеличение значений и замедление вычислений. Выход имеет форму (batch_size, seq_len, seq_len)
  • Наконец, мы домножаем результат предыдущей операции на V, получая форму вывода (batch_size, seq_len, d_k).
Визуальное представление операций матрицы внутри блока внимания (сделано автором)
  • Выходы каждой головки внимания затем могут быть конкатенированы для формирования матрицы с формой (batch_size, seq_len, embed_dim). В статье Transformer также добавляется линейный слой в конце многоhead-блока внимания, чтобы агрегировать и комбинировать изученные представления от всех головок внимания.
Конкатенация матрицы многоhead-внимания и линейного слоя (сделано автором)

В Haiku многоhead-блок внимания можно реализовать следующим образом. Функция __call__ следует той же логике, что и граф выше, а методы класса используют утилиты JAX, такие как vmap (для векторизации операций над разными головками внимания и матрицами) и tree_map (для применения матричного перемножения к весовым векторам).

Соединения остаточных и нормализация слоя

Как вы могли заметить на графе Transformer, блок многоhead-внимания и нейронная сеть с прямой связью за ними следуют остаточными соединениями и нормализацией слоя.

Остаточные или пропускные соединения

Остаточные соединения – это стандартное решение для решения проблемы исчезающего градиента, которая возникает, когда градиенты становятся слишком маленькими для эффективного обновления параметров модели.

Поскольку этот вопрос сам собой возникает в особенно глубоких архитектурах, остаточные соединения используются в различных сложных моделях, таких как ResNet (Kaiming et al, 2015) в области компьютерного зрения, AlphaZero (Silver et al, 2017) в обучении с подкреплением, и, конечно же, Трансформеры.

На практике остаточные соединения просто перенаправляют вывод определенного слоя к следующему, пропуская один или несколько слоев на пути. Например, остаточное соединение вокруг многоhead-внимания эквивалентно суммированию вывода многоhead-внимания с позиционными вложениями.

Это позволяет градиентам более эффективно протекать через архитектуру во время обратного распространения и обычно может привести к более быстрой сходимости и более стабильному обучению.

Представление остаточных соединений в Трансформерах (сделано автором)

Нормализация слоя

Нормализация слоя помогает гарантировать, что значения, передаваемые через модель, не “взрываются” (стремятся к бесконечности), что может легко происходить в блоках внимания, где каждый прямой проход включает в себя умножение нескольких матриц.

В отличие от нормализации по пакетам, которая нормализует по пакетному измерению, предполагая равномерное распределение, нормализация слоя работает по признакам. Такой подход подходит для пакетов предложений, в которых каждое предложение может иметь уникальные распределения из-за различных значений и словаря.

Нормализуя по признакам, таким как вложения или значения внимания, нормализация слоя стандартизирует данные до единообразного масштаба без смешения отдельных характеристик предложений при сохранении уникального распределения каждого.

Изображение слоя нормализации в контексте трансформеров (сделано автором)

Реализация слоя нормализации довольно проста, мы инициализируем обучаемые параметры альфа и бета и выполняем нормализацию вдоль выбранной оси признаков.

Сеть прямого распространения позиций

Последний компонент кодировщика, о котором нам нужно рассказать, это сеть прямого распространения позиций. Эта полностью связанная сеть берет нормализованные выходы блока внимания в качестве входных данных и используется для введения нелинейности и увеличения ёмкости модели для обучения сложных функций.

Она состоит из двух плотных слоев, разделенных активацией gelu:

После этого блока у нас есть еще одно соединение остатков и нормализация слоя для завершения кодировщика.

Итоги

Вот и всё! Теперь вы должны быть знакомы с основными концепциями кодировщика трансформера. Вот полный код класса кодировщика, обратите внимание, что в Haiku мы присваиваем имя каждому слою, чтобы обучаемые параметры были разделены и легко доступны. Функция __call__ является хорошим резюме различных этапов нашего кодировщика:

Чтобы использовать этот модуль на реальных данных, мы должны применить hk.transform к функции, инкапсулирующей класс кодировщика. Действительно, вы, возможно, помните, что JAX принимает парадигму функционального программирования, поэтому Haiku следует тем же принципам.

Мы определяем функцию, содержащую экземпляр класса кодировщика, и возвращаем результат прямого прохода. Применение hk.transform возвращает преобразованный объект со следующими функциями: init и apply.

Первая позволяет нам инициализировать модуль с помощью случайного ключа и некоторых фиктивных данных (обратите внимание, что здесь мы передаем массив нулей с формой batch_size, seq_len), тогда как вторая позволяет обрабатывать реальные данные.

# Примечание: следующие две синтаксические конструкции эквивалентны# 1: Использование transform как декоратора класса@hk.transformdef encoder(x):  ...  return model(x)  encoder.init(...)encoder.apply(...)# 2: Применение transfom отдельноdef encoder(x):  ...  return model(x)encoder_fn = hk.transform(encoder)encoder_fn.init(...)encoder_fn.apply(...)

В следующей статье мы завершим трансформер архитектурой, добавив декодер, который повторно использует большинство блоков, которые мы уже ввели, и узнаем, как обучать модель для конкретной задачи с использованием Optax!

Заключение

Благодарим вас за чтение до конца, если вас интересует экспериментирование с кодом, вы можете найти его полностью прокомментированным на GitHub, а также с дополнительными подробностями и обзором с использованием игрушечного набора данных.

GitHub – RPegoud/jab: Коллекция изначальных моделей глубокого обучения, реализованных в JAX

Коллекция изначальных моделей глубокого обучения, реализованных в JAX – GitHub – RPegoud/jab: Коллекция…

github.com

Если вы хотите погрузиться глубже в мир трансформеров, в следующем разделе представлены несколько статей, которые помогли мне написать эту статью.

До следующего раза 👋

Ссылки и ресурсы:

[1] Вся ваша потребность в внимании (2017), Васвани и др., Google

[2] Что же представляют собой ключи, запросы и значения в механизмах внимания? (2019) Stack Exchange

[3] Иллюстрированный Трансформер (2018), Джей Аламмар

[4] Нежное введение в позиционное кодирование в моделях Трансформера (2023), Мехрин Саид, Machine Learning Mastery

Источники изображений