Усиление обучения с помощью обратной связи человека (RLHF).

Преобразование обучения с помощью обратной связи от человека (RLHF).

Простое объяснение

Возможно, вы слышали об этой технике, но не полностью понимаете ее, особенно часть про PPO. Это объяснение может помочь.

Мы сосредоточимся на моделях языка текст-в-текст 📝, таких как GPT-3, BLOOM и T5. Модели, такие как BERT, которые представляют только кодировщик, не рассматриваются.

Этот блог является адаптацией гиста от того же автора.

Обучение с подкреплением с помощью обратной связи от человека (RLHF) успешно применяется в ChatGPT, что объясняет его значительный рост популярности. 📈

Алгоритм RLHF особенно полезен в двух сценариях 🌟:

  • Невозможно создать хорошую функцию потерь (например, как рассчитать метрику для измерения того, был ли забавным вывод модели?)
  • Вы хотите обучать модель с использованием производственных данных, но вы не можете легко разметить эти данные (например, как получить размеченные производственные данные от ChatGPT? Кто-то должен написать правильный ответ, который ChatGPT должен был бы подать)

Алгоритм RLHF ⚙️:

  1. Предварительное обучение языковой модели (LM)
  2. Обучение модели вознаграждения
  3. Дообучение LM с использованием RL

1 — Предварительное обучение языковой модели (LM)

На этом этапе вам нужно либо обучить языковую модель с нуля, либо использовать готовую модель, например, GPT-3.

После получения предварительно обученной языковой модели вы также можете выполнить дополнительный необязательный шаг, называемый Супервизированное дообучение (STF). Это просто получение некоторых пар текста, размеченных человеком, и дообучение у вас уже имеющейся языковой модели. СТФ считается высококачественной инициализацией для RLHF.

По окончании этого этапа мы получаем обученную LM, которая является нашей основной моделью, той, которую мы хотим дообучить с помощью RLHF.

Рисунок 1: Наша предварительно обученная языковая модель.

2 — Обучение модели вознаграждения

На этом этапе мы хотим собрать набор данных троек (входной текст, выходной текст, вознаграждение).

На рисунке 2 представлена схема сбора данных: используйте входные текстовые данные (если это производственные данные, лучше), пропустите их через вашу модель и позвольте человеку оценить созданный выходной текст.

Рисунок 2: Схема сбора данных для обучения модели вознаграждения.

Вознаграждение обычно представляет собой целое число от 0 до 5, но может быть простым 0/1 в виде 👍/👎.

Рисунок 3: Простой сбор вознаграждения в виде 👍/👎 в ChatGPT.
Рисунок 4: Более полный опыт сбора вознаграждения: модель выводит два текста, и человек должен выбрать лучший, а также оценить в целом с комментариями.

С помощью этого нового набора данных мы обучим другую модель языка, чтобы получать (ввод, вывод) текст и возвращать скалярное вознаграждение! Это будет наша модель вознаграждения.

Основная цель здесь – использовать модель вознаграждения для имитации маркировки вознаграждения человека и, таким образом, иметь возможность проводить тренировку RLHF автономно, без участия человека.

Рисунок 5: Обученная модель вознаграждения, которая будет имитировать вознаграждения, полученные от людей.

3 — Калибровка модели языка с помощью RL

Именно на этом этапе происходит настоящая магия и вступает RL в игру.

Цель этого этапа – использовать вознаграждения, предоставленные моделью вознаграждения, для обучения основной модели – вашей обученной модели языка. Однако, поскольку вознаграждение не будет дифференцируемым, нам потребуется использовать RL, чтобы сможем построить потери, которые мы сможем обратно распространить к модели языка.

Рисунок 6: Калибровка основной модели языка с использованием модели вознаграждения и вычисления потерь PPO.

На начальном этапе конвейера мы сделаем точную копию нашей модели языка и заморозим обучаемые веса. Эта копия модели поможет предотвратить полное изменение весов обучаемой модели языка и начать выводить бессмысленный текст, чтобы запутать модель вознаграждения.

Вот почему мы вычисляем потери Кульбака-Лейблера между вероятностями вывода текста как у замороженной, так и у незамороженной модели языка.

Эти потери КЛ добавляются к вознаграждению, полученному от модели вознаграждения. Фактически, если вы тренируете модель во время производства (онлайн-обучение), вы можете заменить эту модель вознаграждения прямым полученным от оценки человека. 💡

Имея ваше вознаграждение и потерю КЛ, теперь мы можем применить RL, чтобы сделать потерю вознаграждения дифференцируемой.

Почему вознаграждение не является дифференцируемым? Потому что оно было рассчитано с помощью модели вознаграждения, которая получала входной текст. Этот текст получается путем декодирования вероятностей вывода текста модели языка. Этот процесс декодирования не является дифференцируемым.

Чтобы сделать потерю дифференцируемой, в конечном итоге вступает в игру Proximal Policy Optimization (PPO)! Давайте подробнее рассмотрим.

Рисунок 7: Приближение к обновлению RL - вычисление потерь PPO.

Алгоритм PPO вычисляет потери (которые будут использоваться для небольшого обновления модели языка) следующим образом:

  1. Установите “Начальные вероятности” равными “Новым вероятностям” для инициализации.
  2. Вычислите отношение между новыми и начальными вероятностями вывода текста.
  3. Вычислите потерю по формуле loss = -min(ratio * R, clip(ratio, 0.8, 1.2) * R), где R – это вознаграждение + КЛ (или взвешенное среднее, например, 0.8 * вознаграждение + 0.2 * КЛ), ранее рассчитанные, а clip(ratio, 0.8, 1.2) просто ограничивает соотношение в диапазоне 0.8 <= ratio <= 1.2. Обратите внимание, что 0.8/1.2 – это просто общепринятые значения гиперпараметров, которые здесь упрощены. Также обратите внимание, что мы хотим максимизировать вознаграждение, поэтому добавляем знак минус - для минимизации отрицания потерь с помощью градиентного спуска.
  4. Обновите веса модели языка, обратно распространяя потери.
  5. Вычислите “Новые вероятности” (т.е. новые вероятности вывода текста) с обновленной моделью языка.
  6. Повторите с шага 2 до N раз (обычно N=4).

Вот и все, вот как вы используете RLHF в языковых моделях текста-к-тексту!

Вещи могут стать более сложными, потому что существуют и другие потери, которые вы можете добавить к этой базовой потере, которую я представил, но это основная реализация.