Векторизуйте и параллельно обрабатывайте среды RL с помощью JAX обучение с применением Q-обучения со скоростью света ⚡

Улучшите обработку среды RL с помощью векторизации и параллельного обучения JAX, используя Q-обучение со скоростью света ⚡

В этой статье мы узнаем, как векторизовать среду RL и тренировать 30 агентов Q-learning параллельно на ЦП, с частотой в 1,8 миллиона итераций в секунду.

Изображение от Google DeepMind на Unsplash

В предыдущей истории мы познакомились с Temporal-Difference Learning, в частности с Q-learning, в контексте GridWorld.

Обучение методом временной разности и важность исследования: Иллюстрированный руководитель

Сравнение методов моделирования (Q-learning) и моделирования на основе моделей (Dyna-Q и Dyna-Q+) в задаче динамического грида.

towardsdatascience.com

Хотя этая реализация служила цели показать различия в производительности и механизмах исследования этих алгоритмов, она работала ужасно медленно.

Фактически, среда и агенты были в основном написаны на Numpy, которая не является стандартом в области RL, хотя облегчает понимание и отладку кода.

В этой статье мы увидим, как масштабировать эксперименты RL путем векторизации среды и без проблем параллелить тренировку десятков агентов, используя JAX. В частности, в этой статье рассматриваются следующие темы:

  • Основы JAX и полезные возможности для RL
  • Векторизованная среда и причины их высокой производительности
  • Реализация среды, политики и агента Q-learning в JAX
  • Тренировка с одним агентом
  • Как параллелить тренировку агента и насколько это просто!

Весь код, представленный в этой статье, доступен на GitHub:

GitHub – RPegoud/jax_rl: Реализация алгоритмов RL и векторизованных сред в JAX

Реализация алгоритмов RL и векторизованных сред в JAX – GitHub – RPegoud/jax_rl: Реализация алгоритмов RL…

github.com

Основы JAX

JAX – это еще один фреймворк глубокого обучения на языке Python, разработанный Google и широко используемый такими компаниями, как DeepMind.

«JAX – это Autograd (автоматическое дифференцирование) и XLA (ускоренная линейная алгебра, компилятор TensorFlow), объединенные для высокопроизводительных численных вычислений». – Официальная документация

В отличие от того, с чем большинство разработчиков Python привыкли работать, JAX не основан на парадигме объектно-ориентированного программирования (OOP), а скорее на функциональном программировании (FP)[1].

Простыми словами, он опирается на чистые функции (детерминированные и без побочных эффектов) и неизменяемые структуры данных (вместо изменения данных на месте, создаются новые структуры данных с нужными изменениями) в качестве основных строительных блоков. В результате FP способствует более функциональному и математическому подходу к программированию, делая его хорошо подходящим для таких задач, как численные вычисления и обучение машин.

Давайте проиллюстрируем различия между этими двумя парадигмами, рассмотрев псевдокод для функции обновления Q-значений:

  • Подход, основанный на объектно-ориентированном программировании, полагается на экземпляр класса, содержащий различные переменные состояния (например, Q-значения). Функция обновления определена как метод класса, который обновляет внутреннее состояние экземпляра.
  • Подход, основанный на функциональном программировании, полагается на чистую функцию. Действительно, данное обновление Q-значений является детерминированным, поскольку Q-значения передаются в качестве аргумента. Таким образом, любой вызов этой функции с одинаковыми входными данными приведет к получению одинаковых выходных данных, в то время как выходные данные метода класса могут зависеть от внутреннего состояния экземпляра. Кроме того, структуры данных, такие как массивы, определяются и изменяются в глобальной области видимости.
Реализация обновления Q-значений в объектно-ориентированном и функциональном программировании (сделано автором)

В этом контексте JAX предлагает различные декораторы функций, которые особенно полезны в контексте RL:

  • vmap (векторизованное отображение): Позволяет применять функцию, действующую на один образец, к набору. Например, если env.step() является функцией, выполняющей шаг в одной среде, то vmap(env.step)() – это функция, выполняющая шаг в нескольких средах. Другими словами, vmap добавляет измерение набора к функции.
Иллюстрация функции шага, векторизованной с использованием vmap (сделано автором)
  • jit (just-in-time компиляция): Позволяет JAX выполнять «Just In Time компиляцию JAX Python функции», делая ее совместимой с XLA. Фактически, использование jit позволяет компилировать функции и обеспечивает значительное увеличение скорости (в обмен на некоторые дополнительные накладные расходы при первоначальной компиляции функции).
  • pmap (параллельное отображение): Похоже на vmap, pmap обеспечивает простую параллелизацию. Однако, вместо добавления размерности набора функции, она копирует функцию и выполняет ее на нескольких устройствах XLA. Примечание: при применении pmap jit также применяется автоматически.
Иллюстрация функции шага, параллельно выполняющейся с использованием pmap (сделано автором)

Теперь, когда мы ознакомились с основами JAX, мы увидим, как достичь огромного ускорения, векторизуя среды.

Векторизованные среды:

Во-первых, что такое векторизованная среда и какие проблемы она решает?

В большинстве случаев, эксперименты RL замедляются из-за передач данных между CPU и GPU. Алгоритмы глубокого обучения RL, такие как Проксимальная оптимизация политики (PPO), используют нейронные сети для аппроксимации политики.

Как всегда в глубоком обучении, нейронные сети используют GPU во время тренировки и выполнения. Однако, в большинстве случаев среды работают на CPU (даже при использовании нескольких сред параллельно).

Это означает, что обычный цикл RL, в котором выбираются действия с помощью политики (нейронные сети) и получаются наблюдения и вознаграждения от среды, требует постоянного обмена данными между GPU и CPU, что снижает производительность.

Кроме того, использование таких фреймворков, как PyTorch без “jit-компиляции”, может вызывать излишнюю нагрузку, так как GPU может ожидать возвращения наблюдений и вознаграждений от CPU.

Обычная настройка пакетной тренировки RL в PyTorch (сделано автором)

С другой стороны, JAX позволяет нам легко выполнять пакетные среды на GPU, устраняя трения, вызванные передачей данных между GPU и CPU.

Более того, поскольку jit компилирует наш код JAX в XLA, выполнение больше не зависит (или по крайней мере в меньшей степени) от неэффективности Python.

Настройка пакетной тренировки RL в JAX (сделано автором)

Для получения дополнительных деталей и захватывающих применений в исследованиях мета-обучения RL я настоятельно рекомендую эту статью блога от Криса Лу.

Реализация среды, агента и политики:

Давайте рассмотрим реализацию разных частей нашего эксперимента по RL. Вот общий обзор основных функций, которые нам понадобятся:

Методы класса для простой настройки RL (сделано автором)

Среда

Эта реализация следует схеме, предложенной Николаем Гуджером в его великой статье о написании среды в JAX.

Написание среды для RL в JAX

Как запустить CartPole на скорости 1.25 миллиарда шагов/сек

VoAGI.com

Давайте начнем с подробного обзора среды и ее методов. Это общий план реализации среды в JAX:

Давайте ближе рассмотрим методы класса (напомню, что функции, начинающиеся с “_”, являются приватными и их не следует вызывать за пределами класса):

  • _get_obs: Этот метод преобразует состояние среды в наблюдение для агента. В случае с частично наблюдаемой или стохастической средой здесь могут быть применены функции обработки состояния.
  • _reset: Поскольку мы будем запускать несколько агентов параллельно, нам нужен метод для индивидуального сброса по завершении эпизода.
  • _reset_if_done: Этот метод будет вызываться на каждом шаге и вызывать _reset, если флаг “done” установлен в True.
  • reset: Этот метод вызывается в начале эксперимента для получения начального состояния каждого агента, а также связанных случайных ключей.
  • step: По заданному состоянию и действию среда возвращает наблюдение (новое состояние), вознаграждение и обновленный флаг “done”.

На практике обычная реализация GridWorld среды будет выглядеть так:

Обратите внимание, как уже упоминалось, все методы класса следуют парадигме функционального программирования. Действительно, мы никогда не обновляем внутреннее состояние экземпляра класса. Кроме того, атрибуты класса – все это константы, которые не будут изменяться после создания.

Давайте рассмотрим ближе:

  • __init__: В контексте нашей GridWorld доступны действия [0, 1, 2, 3]. Эти действия переводятся в двумерный массив с использованием self.movements и добавляются к состоянию в функции step.
  • _get_obs: Наша среда является детерминированной и полностью наблюдаемой, поэтому агент получает состояние напрямую, а не обработанное наблюдение.
  • _reset_if_done: Аргумент env_state соответствует кортежу (state, key), где key является jax.random.PRNGKey. Эта функция просто возвращает начальное состояние, если флаг “завершено” установлен в True, однако мы не можем использовать обычное управление потоком в Python внутри JAX jitted функций. Используя jax.lax.cond, мы фактически получаем выражение эквивалентное:
def cond(condition, true_fun, false_fun, operand):  if condition: # если флаг done == True    return true_fun(operand)  # возвращаем self._reset(key)  else:    return false_fun(operand) # возвращаем env_state
  • step: Мы преобразуем действие в движение и добавляем его к текущему состоянию (jax.numpy.clip гарантирует, что агент остается в пределах сетки). Затем мы обновляем кортеж env_state перед проверкой, нужно ли сбросить среду. Поскольку функция шага часто используется во время обучения, она может значительно улучшить производительность. Декоратор @partial(jit, static_argnums=(0, ) сигнализирует о том, что аргумент “self” метода класса следует рассматривать как статический. Другими словами, свойства класса являются постоянными и не изменятся во время последовательных вызовов функции шага.

Q-обучение Агента

Агент Q-обучение определяется функцией обновления, а также статической скоростью обучения и фактором дисконтирования.

Опять же, при jit-компиляции функции обновления мы передаем аргумент “self” как статический. Также обратите внимание, что матрица q_values изменяется на месте с помощью set() и ее значение не сохраняется в виде атрибута класса.

Epsilon-Greedy Политика

И, наконец, политика, используемая в этом эксперименте, является стандартной epsilon-greedy политикой. Один важный момент – использование случайных связей, что означает, что если максимальное значение Q-функции не является уникальным, действие будет выбрано равновероятно из максимальных значений Q-функции (использование argmax всегда вернет первое действие с максимальным значением Q-функции). Это особенно важно, если значения Q-функций инициализированы как матрица из нулей, так как всегда выбирается действие 0 (вправо).

В противном случае политику можно суммировать с помощью следующего отрывка:

action = lax.cond(            explore, # если p < epsilon            _random_action_fn, # выбрать случайное действие с учетом ключа            _greedy_action_fn, # выбрать жадное действие в соответствии с Q-функциями            operand=subkey, # использовать subkey в качестве аргумента для вышеперечисленных функций        )return action, subkey

Обратите внимание, что при использовании ключа в JAX (например, в этом случае мы сгенерировали случайную дробь и использовали random.choice), обычной практикой является разделение ключа после этого (т.е. “переход к новому случайному состоянию”, более подробно здесь).

Цикл обучения для одного агента:

Теперь, когда у нас есть все необходимые компоненты, давайте обучим одного агента.

Вот Pythonic цикл обучения, как вы можете видеть, мы в основном выбираем действие с помощью политики, выполняем шаг в среде и обновляем значения Q-функций до конца эпизода. Затем мы повторяем процесс для N эпизодов. Как мы увидим через минуту, такой способ обучения агента довольно неэффективен, однако он кратко описывает ключевые шаги алгоритма в понятной форме:

На одном центральном процессоре мы завершаем 10 000 эпизодов за 11 секунд, с частотой в 881 эпизод и 21 680 шагов в секунду.

100%|██████████| 10000/10000 [00:11<00:00, 881.86it/s]Общее количество шагов: 238 488Количество шагов в секунду: 21 680

Теперь продублируем ту же цикл обучения, используя синтаксис JAX. Вот высокоуровневое описание функции rollout:

Функция rollout обучения с использованием синтаксиса JAX (сделано автором)

Вкратце, функция rollout:

  1. Инициализирует массивы наблюдений, наград и флагов завершения значением пустого массива с размерностью, равной количеству шагов, используя jax.numpy.zeros. Значения Q инициализируются как пустая матрица с формой [количество_шагов+1, размерность_сетки_x, размерность_сетки_y, n_действий].
  2. Вызывает функцию env.reset() для получения начального состояния
  3. Использует функцию jax.lax.fori_loop() для вызова функции fori_body() N раз, где N – это параметр количество_шагов
  4. Функция fori_body() ведет себя аналогично предыдущему питоновскому циклу. После выбора действия, совершения шага и вычисления обновления Q, мы обновляем массивы наблюдений, наград, флагов завершения и q-значений в месте (обновление Q-значения ориентируется на временной шаг t+1).

Эта дополнительная сложность приводит к увеличению скорости в 85 раз, теперь мы обучаем нашего агента примерно с 1,83 миллиона шагов в секунду. Обратите внимание, что здесь обучение производится на одном центральном процессоре, поскольку среда является простой.

Однако, полностью векторизованное выполнение масштабируется еще лучше, когда применяется к сложным средам и алгоритмам, в которых выигрыш дается множеством GPU (статья Криса Лу рассказывает о более чем внушительном увеличении скорости в 4000 раз между реализацией PPO на базе CleanRL на PyTorch и реализацией на JAX).

100%|██████████| 1000000/1000000 [00:00<00:00, 1837563.94it/s]Общее количество шагов: 1 000 000Количество шагов в секунду: 1 837 563

После обучения нашего агента, мы строим максимальное значение Q-значения для каждой ячейки (т.е. состояния) GridWorld, и видим, что он действительно научился идти от начального состояния (нижний правый угол) к цели (верхний левый угол).

Тепловая карта максимального значения Q-значения для каждой ячейки GridWorld (сделано автором)

Цикл обучения параллельных агентов:

Как обещано, теперь, когда мы написали функции для обучения одного агента, нам осталось мало работы, чтобы обучить несколько агентов параллельно в пакетных средах!

Благодаря vmap мы можем быстро преобразовать наши предыдущие функции для работы с пакетом данных. Нам нужно только указать ожидаемые формы входных и выходных данных, например, для env.step:

  • in_axes = ((0,0), 0) представляет форму ввода, которая состоит из кортежа env_state (размерность (0, 0)) и наблюдения (размерность 0).
  • out_axes = ((0, 0), 0, 0, 0) представляет форму вывода, где вывод является ((env_state), наблюдение, награда, флаги завершения).
  • Теперь мы можем вызвать v_step на массиве env_states и действиях и получить массив обработанных env_states, наблюдений, наград и флагов завершения.
  • Обратите внимание, что мы также используем jit все функции с пакетной обработкой для повышения производительности (возможно, нативная компиляция не требуется для функции env.reset(), поскольку она вызывается только один раз в нашей функции обучения).

Последнее изменение, которое нам нужно сделать, это добавить размер пакета к нашим массивам, чтобы учесть данные каждого агента.

Делая это, мы получаем функцию, которая позволяет нам обучать несколько агентов параллельно, с минимальными изменениями по сравнению с функцией для одного агента:

Мы получаем аналогичную производительность с этой версией нашей обучающей функции:

100%|██████████| 100000/100000 [00:02<00:00, 49036.11it/s]Общее количество шагов: 100 000 * 30 = 3 000 000Количество шагов в секунду: 49 036 * 30 = 1 471 080

Вот и все! Спасибо, что дочитали до конца, я надеюсь, что этот статья была полезным введением в реализацию векторизованных окружений в JAX.

Если вам понравилось чтение, пожалуйста, подумайте о поделиться этой статьей и оцените мой репозиторий на GitHub. Спасибо за вашу поддержку!

GitHub – RPegoud/jax_rl: Реализация алгоритмов RL и векторизованных окружений в JAX

Реализация алгоритмов RL и векторизованных окружений в JAX – GitHub – RPegoud/jax_rl: Реализация алгоритмов RL…

github.com

Наконец, для тех, кто заинтересован копнуть немного глубже, вот список полезных ресурсов, которые помогли мне начать работу с JAX и написать эту статью:

Кураторский список потрясающих статей и ресурсов по JAX:

[1] Coderized, (функциональное программирование) Самый чистый стиль кодирования, в котором почти невозможны ошибки, YouTube

[2] Aleksa Gordić, JAX From Zero to Hero YouTube Playlist (2022), The AI Epiphany

[3] Nikolaj Goodger, Writing an RL Environment in JAX (2021)

[4] Chris Lu, Achieving 4000x Speedups and Meta-Evolving Discoveries with PureJaxRL (2023), University of Oxford, Foerster Lab for AI Research

[5] Nicholas Vadivelu, Awesome-JAX (2020), список библиотек, проектов и ресурсов по JAX

[6] Официальная документация JAX, Training a Simple Neural Network, with PyTorch Data Loading