Внимание Объяснение основополагающих принципов

Основные принципы подробное объяснение

Flash Attention – это эффективная и точная техника ускорения модели Transformer, предложенная в 2022 году. FlashAttention путем восприятия операций чтения и записи в память достигает скорости работы на 2-4 раза быстрее, чем стандартное применение внимания в PyTorch, требуя всего 5-20% памяти.

В этой статье будет объяснены основные принципы Flash Attention, иллюстрирующие, как он достигает ускоренных вычислений и сэкономии памяти, не нарушая точность внимания.

Предварительные знания

Иерархия памяти GPU

Как показано на рисунке 1, память GPU состоит из нескольких модулей памяти с различными размерами и скоростями чтения/записи. Модули меньшего размера имеют более быстрые скорости чтения/записи.

Рисунок 1: Иерархия памяти GPU. Источник: [1]

Для GPU A100 память SRAM распределена по 108 потоковым мультипроцессорам, каждый процессор имеет размер 192K. Это составляет 192 * 108 КБ = 20 МБ. Оперативная память большой пропускной способности (HBM), которую обычно называют видеопамятью, имеет размер 40 ГБ или 80 ГБ.

Пропускная способность чтения/записи SRAM составляет 19 ТБ/с, тогда как пропускная способность чтения/записи HBM составляет всего 1,5 ТБ/с, менее чем 1/10 от SRAM.

Из-за увеличения скорости вычислений относительно скорости памяти операции все больше ограничиваются доступом к памяти (HBM). Поэтому снижение количества операций чтения/записи в HBM и эффективное использование более быстрой SRAM для вычислений имеет важное значение.

Модель выполнения GPU

GPU имеет большое количество потоков для выполнения операций (называемых ядрами). Каждое ядро загружает входные данные из HBM в регистры и SRAM, выполняет вычисления, а затем записывает выходные данные обратно в HBM.

Safe softmax

Для x = [x1, x2, …, xN] процесс вычисления наивного softmax показан в уравнении (1):

Однако на реальном аппаратном уровне диапазон вещественных чисел ограничен. Для float32 и bfloat16, когда x ≥ 89, результатом экспоненциации становится inf, что вызывает проблемы переполнения[3].

Чтобы избежать численного переполнения и обеспечить численную стабильность, обычно при выполнении вычислений вычитают максимальное значение, что известно как “safe