Неправильна ли реализация метода Nesterov Momentum в PyTorch?

Неправильна ли реализация Nesterov Momentum в PyTorch?

Momentum помогает SGD более эффективно перемещаться по сложным ландшафтам потерь. Фото Максим Берг на Unsplash.

Введение

Если вы внимательно посмотрите на документацию PyTorch по SGD, вы обнаружите, что их реализация имеет несколько отличий от формулировки, найденной в оригинальной статье об Nesterov моменте. В основном, реализация PyTorch оценивает градиент в текущих параметрах, тогда как главная цель Nesterov момента состоит в оценке градиента в сдвинутых параметрах. К сожалению, кажется, что обсуждение этих расхождений в интернете является редким. В этом посте мы рассмотрим и объясним различия между реализацией PyTorch и оригинальной формулировкой Nesterov момента. В конечном итоге, мы увидим, что реализация PyTorch не является неправильной, а скорее приближением, и будем предполагать о пользе их реализации.

Формулировки

В оригинальной статье описывается Nesterov момент с использованием следующих правил обновления:

где v_{t+1} и θ_{t+1} – это вектор скорости и параметры модели соответственно в момент времени t, μ – это фактор момента, а ε – это скорость обучения. Заметка в документации SGD PyTorch гласит, что они используют следующие правила обновления:

где g_{t+1} представляет градиент, используемый для вычисления v_{t+1}. Мы можем расширить правило обновления для θ_{t+1}:

Из этого мы можем заключить, что:

и правила обновления становятся:

Это правила обновления, которые PyTorch использует в теории. Я ранее упомянул, что PyTorch фактически оценивает градиент в текущих параметрах вместо сдвинутых параметров. Это можно увидеть, посмотрев на описание алгоритма в документации PyTorch SGD. Мы рассмотрим это подробнее позже.

Обратите внимание, что для обоих оригинальных (1, 2) и PyTorch (3, 4) формулировок, если v_0 = 0, то первое обновление θ становится:

Хотя заметка в документации PyTorch SGD указывает, что алгоритм инициализирует буфер момента градиентом на первом шаге, мы покажем позже, что это означает v_0 = 0.

Предварительные различия

Есть два немедленных различия при переходе от оригинальных (1, 2) к PyTorch (3, 4) формулировке:

  1. Скорость обучения вынесена за пределы v_{t+1}.
  2. В правиле обновления для v_{t+1} добавляется терм, связанный с градиентом, вместо вычитания, а в правиле обновления для θ_{t+1} вычитается терм, связанный с вектором скорости, вместо сложения. Различие в знаке внутри терма градиента является простым следствием этого, как показано в предыдущем разделе.

Чтобы понять эти различия, давайте сначала расширим правила обновления. Как намекается здесь, эффект первого различия более очевиден, если мы рассмотрим графики скорости обучения. Поэтому мы рассматриваем обобщение правил обновления, где ε больше не является фиксированным, а может меняться со временем, и обозначаем ε_t как скорость обучения на временном шаге t. Для краткости, пусть:

Предполагая, что v_0 = 0, оригинальная формулировка становится:

а формулировка PyTorch становится:

В оригинальной формулировке (6), если скорость обучения изменится в момент времени t, то только величина терма при i = t в сумме будет затронута, и величины всех остальных термов останутся неизменными. В результате, немедленное влияние изменения скорости обучения довольно ограничено, и нам придется ждать, чтобы изменение скорости обучения “протекло” в последующие шаги времени, чтобы оказать более сильное влияние на общий размер шага. В отличие от этого, в формулировке PyTorch (7), если скорость обучения изменится в момент времени t, то величина всего шага будет немедленно затронута.

Для v_0 = 0, из расширенных правил ясно, что второе различие в конечном итоге не имеет никакого эффекта; в любой формулировке шаг вычисляется как дисконтированная сумма градиентов, которая вычитается из текущих параметров.

Основные различия

Игнорируя децимацию весов и затухание, анализируя алгоритм SGD в документации PyTorch, мы видим, что реализованные правила обновления следующие:

где θ’_{t+1} – параметры модели в момент времени t и

Мы будем называть уравнения 3 и 4 “заметкой” PyTorch, а уравнения 8 и 9 – “реализацией” PyTorch. Мы различаем θ и θ’ по причине, которая станет ясной вскоре. Самое существенное отличие от заметочной формулировки заключается в том, что градиент вычисляется на текущих параметрах, а не на смещенных параметрах. Из этого одного может показаться, что правила обновления, реализуемые алгоритмом, не являются правильной реализацией импульса Нестерова.

Теперь мы рассмотрим, как алгоритм PyTorch в конечном итоге аппроксимирует импульс Нестерова. Производные для более старой версии PyTorch можно найти здесь в работе Иво Данихелки, ссылка на которую дана в этой проблеме на GitHub. Производные для текущей версии PyTorch можно найти здесь, это относительно простое изменение предыдущих производных. Мы предоставляем здесь визуализацию этих (переизведенных) производных в LaTeX для удобства читателя. Реализованная формулировка получается простым изменением переменных. Конкретно, мы позволяем:

Сразу становится ясно, что правило обновления заметки для v_{t+1} (3) становится эквивалентным правилу обновления реализации для v_{t+1} (8) после изменения переменных. Мы хотим получить правило обновления для θ’_{t+1} в терминах θ’_t:

Именно это правило обновления мы видим, что реализовано в PyTorch (9). На высоком уровне реализация PyTorch предполагает, что текущие параметры θ’_t уже являются смещенной версией “фактических” параметров θ_t. Таким образом, на каждом шаге времени “фактические” параметры θ_t связаны с текущими параметрами θ’_t следующим образом:

Однако из исходного кода видно, что реализация SGD в PyTorch не делает никаких корректировок в конце алгоритма для получения конечных “фактических” параметров, поэтому конечный результат технически является приближением “фактических” параметров.

Наконец, мы теперь покажем, что v_0 должно быть равно 0:

Более того, мы можем подтвердить, что первое обновление “фактических” параметров совпадает с первым обновлением, сделанным в исходной формулировке, когда v_0 = 0:

Мы видим, что это эквивалентно уравнению 5.

Преимущества реализованной формулировки

Конечно, большой оставшийся вопрос: Почему вообще PyTorch беспокоится переформулировать импульс Нестерова из уравнений 3 и 4 в уравнения 8 и 9? Одно возможное объяснение заключается в том, что реформулировка может привести к сокращению количества арифметических операций, требуемых для выполнения. Чтобы оценить это возможное объяснение, давайте посчитаем количество арифметических операций. Для заметочной формулировки (3, 4) у нас есть:

Здесь всего семь операций. Для реализованной формулировки (8, 9) у нас есть:

Здесь всего шесть операций. Второй градиент в реализации PyTorch просто использует сохраненный результат от вычисления первого градиента, поэтому на каждом шаге выполняется только одно вычисление градиента. Таким образом, одно очевидное преимущество заключается в том, что реализация PyTorch сокращает количество дополнительных операций умножения на каждом шаге.

Вывод

В заключение:

  1. Правила обновления, указанные в заметке документации SGD PyTorch (3, 4), имеют разное расположение для скорости обучения по сравнению с оригинальными правилами обновления импульса Нестерова (1, 2). Это позволяет расписанию скорости обучения немедленно влиять на общий размер шага, в то время как в оригинальной формулировке изменения скорости обучения “протекали” через последующие временные шаги.
  2. Правила обновления, реализованные в алгоритме SGD PyTorch (8, 9), являются приближением правил обновления, указанных в заметке документации (3, 4), после простого изменения переменных. Хотя “фактические” параметры легко восстанавливаются из текущих параметров на каждом шаге времени, реализация PyTorch не делает такой корректировки в конце алгоритма, поэтому конечные параметры технически остаются приближением “фактических” конечных параметров.
  3. Очевидное преимущество реализации PyTorch заключается в том, что она избегает дополнительной операции умножения на каждом шаге времени.

Литература

  1. «SGD». SGD — PyTorch 2.0 Документация, pytorch.org/docs/stable/generated/torch.optim.SGD.html. Доступно 2 сентября 2023 г.
  2. Sutskever, Ilya и др. «О важности инициализации и момента в глубоком обучении». Международная конференция по машинному обучению. PMLR, 2013.
  3. Danihelka, Ivo. «Простое объяснение метода Nesterov’s Momentum». 25 августа 2012 г.
  4. Chintala, Soumith. «nesterov momentum is wrong in sgd · Issue #27 · torch/optim». GitHub, 13 октября 2014 г., github.com/torch/optim/issues/27.
  5. Gross, Sam. «Добавить примечание в документации о формулировке момента, используемой в optim · Issue #1099 · pytorch/pytorch». GitHub, 25 марта 2017 г., github.com/pytorch/pytorch/issues/1099#issuecomment-289190614.
  6. Zhao, Yilong. «Исправить ошибку в методе Nesterov Momentum · Issue #5920 · pytorch/pytorch». GitHub, 21 марта 2018 г., https://github.com/pytorch/pytorch/pull/5920#issuecomment-375181908.