Создание анимации градиентного спуска на Python

Изучение анимации градиентного спуска с использованием Python

Как построить траекторию точки на комплексной поверхности

Фото Todd Diemer на Unsplash

Позвольте рассказать вам, как я создал анимацию градиентного спуска, чтобы проиллюстрировать определенную точку в блог-посте. Это стоило того, так как я научился больше Python, делая это, и открыл новое умение: создание анимированных графиков.

Анимация градиентного спуска, созданная на Python. Изображение автора.

Я расскажу вам о шагах процесса, которые я следовал.

Немного о предыстории

Несколько дней назад я опубликовал блог-пост о градиентном спуске как оптимизационном алгоритме, используемом для обучения искусственных нейронных сетей.

Я хотел включить анимированную фигуру, чтобы показать, как выбор разных начальных точек для оптимизации градиентным спуском может давать разные результаты.

Именно тогда я наткнулся на эти потрясающие анимации, созданные Alec Radford несколько лет назад и опубликованные в комментарии Reddit, иллюстрирующем разницу между некоторыми продвинутыми алгоритмами градиентного спуска, такими как Adagrad, Adadelta и RMSprop.

Поскольку я стремлюсь заменить Matlab на Python, я решил попробовать создать подобную анимацию самостоятельно, используя “ванильный” градиентный спуск в качестве отправной точки.

Поехали, пошагово.

Построение поверхности для оптимизации

Первое, что мы делаем, это импортируем необходимые библиотеки и определяем математическую функцию, которую мы хотим представить.

Я хотел использовать поверхность с седловой точкой, поэтому я определил следующее уравнение:

Мы также создаем сетку точек для построения нашей поверхности. np.mgrid отлично подходит для этого. Комплексное число 81j, переданное как шаг длины, указывает, сколько точек нужно создать между начальным и конечным значениями (81 точка).

import numpy as npfrom mpl_toolkits.mplot3d import Axes3Dimport matplotlib.pyplot as plt# Создание функции для вычисления поверхностиdef f(theta):  x = theta[0]  y = theta[1]  return x**2 - y**2# Создание сетки точек для построенияx, y = np.mgrid[-1:1:81j, -1:1:81j]