Существенное ускорение стабильного распространения XL-вывода с помощью JAX на Cloud TPU v5e

Ускорение стабильного распространения XL-вывода с использованием JAX на Cloud TPU v5e

Генеративные модели искусственного интеллекта, такие как Stable Diffusion XL (SDXL), позволяют создавать высококачественный и реалистичный контент с широким спектром применения. Однако использование мощности таких моделей представляет существенные трудности и требует значительных вычислительных затрат. SDXL – это большая модель генерации изображений, компонент UNet которой примерно в три раза больше, чем у предыдущей версии модели. Развертывание такой модели в производственной среде вызывает сложности из-за увеличенных требований к памяти, а также увеличенного времени вывода. Сегодня мы с восторгом объявляем, что Diffusers от Hugging Face теперь поддерживает обслуживание SDXL с использованием JAX на Cloud TPUs, обеспечивая высокую производительность и экономичную оценку.

Облачные TPUs от Google – это специально разработанные ускорители искусственного интеллекта, оптимизированные для обучения и вывода больших моделей искусственного интеллекта, включая современные языковые модели и генеративные модели искусственного интеллекта, такие как SDXL. Новая Cloud TPU v5e создана для обеспечения экономичности и производительности, необходимых для масштабного обучения и вывода AI моделей.

🧨 Интеграция JAX Diffusers предлагает удобный способ запуска SDXL на TPU с помощью XLA, и мы создали демо-версию, чтобы продемонстрировать это. Вы можете попробовать его в этой области или внизу встроенной площадке:

Под капотом это демо-версия работает на нескольких экземплярах TPU v5e-4 (каждый экземпляр имеет 4 чипа TPU) и использует параллелизацию для обслуживания четырех больших изображений размером 1024×1024 примерно за 4 секунды. В это время включены преобразования формата, время коммуникации и обработка интерфейса; фактическое время генерации составляет около 2,3 секунды, как мы увидим ниже!

В этой записи блога:

  1. Мы объясняем, почему JAX + TPU + Diffusers является мощной платформой для запуска SDXL
  2. Мы объясняем, как можно создать простую конвейеризацию генерации изображений с помощью Diffusers и JAX
  3. Мы показываем результаты сравнения разных настроек TPU

Почему JAX + TPU v5e для SDXL?

Обслуживание SDXL с помощью JAX на Cloud TPU v5e с высокой производительностью и экономией средств становится возможным благодаря комбинации специализированного TPU оборудования и оптимизированного для производительности стека программного обеспечения. Ниже мы выделяем два ключевых фактора: компиляция JIT и параллелизм, основанный на компиляторе XLA с использованием JAX pmap.

Компиляция JIT

Замечательной особенностью JAX является его компиляция JIT. Компилятор JIT отслеживает код во время первого выполнения и генерирует высокооптимизированные двоичные файлы TPU, которые могут быть повторно использованы в последующих вызовах. Ограничение этого процесса заключается в том, что все формы ввода, промежуточные и выходные должны быть статическими, то есть известны заранее. Каждый раз, когда мы изменяем формы, запускается новый и дорогостоящий процесс компиляции. JIT-компиляция идеально подходит для сервисов, которые могут быть разработаны вокруг статических форм: компиляция выполняется один раз, а затем мы пользуемся сверхбыстрыми временами вывода.

Генерация изображений хорошо подходит для JIT-компиляции. Если мы всегда генерируем одинаковое количество изображений одного и того же размера, то формы вывода являются постоянными и известны заранее. Текстовые входы также постоянны: по своей природе Stable Diffusion и SDXL используют фиксированные векторы встраивания формы (с заполнением) для представления предложений, вводимых пользователем. Поэтому мы можем написать код JAX, который полагается на фиксированные формы и может быть сильно оптимизирован!

Высокая производительность при обработке больших пакетов

Рабочие нагрузки могут быть масштабированы на несколько устройств с помощью pmap JAX, который представляет одну программу множества данных (SPMD) программы. Применение pmap к функции позволяет скомпилировать функцию с помощью XLA, а затем выполнить ее параллельно на различных устройствах XLA. Для рабочих нагрузок по генерации текста в изображение это означает, что увеличение количества одновременно отрисовываемых изображений легко реализуется и не отрицательно сказывается на производительности. Например, запуск SDXL на TPU с 8 чипами будет генерировать 8 изображений за то же время, что и один чип для создания одного изображения.

TPU-экземпляры v5e поставляются в нескольких формах, включая формы с 1, 4 и 8 чипами, а также до 256 чипов (полный TPU-под v5e), с супербыстрыми связями ICI между чипами. Это позволяет выбрать форму TPU, которая лучше всего подходит для вашего случая использования, и легко воспользоваться параллельностью, которую обеспечивают JAX и TPU.

Как написать процесс генерации изображений в JAX

Мы пошагово рассмотрим код, который вам нужно написать, чтобы запустить вывод гораздо быстрее с помощью JAX! Сначала импортируем необходимые зависимости.

# Показать лучшие практики для SDXL JAXimport jaximport jax.numpy as jnpimport numpy as npfrom flax.jax_utils import replicatefrom diffusers import FlaxStableDiffusionXLPipelineimport time

Теперь мы загрузим базовую модель SDXL и остальные компоненты, необходимые для вывода. Окружение diffusers будет заниматься загрузкой и кэшированием всего для нас. Следуя функциональному подходу JAX, параметры модели возвращаются отдельно и должны быть переданы в трубопровод во время вывода:

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True)

Параметры модели по умолчанию загружаются с точностью 32-бит. Чтобы сэкономить память и ускорить вычисления, мы преобразуем их в эффективное 16-битное представление bfloat16. Однако есть оговорка: для лучших результатов мы должны хранить состояние планировщика в формате float32, иначе накапливаются погрешности точности и получаются изображения низкого качества или даже черные изображения.

scheduler_state = params.pop("scheduler")params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)params["scheduler"] = scheduler_state

Теперь мы готовы настроить нашу подсказку и остальные входные данные в трубопроводе.

default_prompt = "фотография высокого качества дельфина-малыша, играющего в бассейне и носящего праздничную шляпу"default_neg_prompt = "иллюстрация, низкое качество"default_seed = 33default_guidance_scale = 5.0default_num_steps = 25

Подсказки должны быть предоставлены как тензоры в трубопровод, и они всегда должны иметь одинаковые размеры при каждом вызове. Это позволяет скомпилировать вызов вывода. Метод prepare_inputs в трубопроводе выполняет все необходимые шаги за нас, поэтому мы создадим вспомогательную функцию для подготовки нашей подсказки и отрицательной подсказки в виде тензоров. Мы будем использовать ее позже из нашей функции generate:

def tokenize_prompt(prompt, neg_prompt):    prompt_ids = pipeline.prepare_inputs(prompt)    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)    return prompt_ids, neg_prompt_ids

Для использования параллелизации мы будем реплицировать входные данные на устройствах. Cloud TPU v5e-4 имеет 4 чипа, поэтому, реплицируя входные данные, мы получаем каждый чип для генерации разного изображения параллельно. Мы должны осторожно предоставить каждому чипу разное случайное семя, чтобы получить 4 разных изображения:

NUM_DEVICES = jax.device_count()# Параметры модели не меняются во время вывода,# поэтому их нужно реплицировать только один раз.p_params = replicate(params)def replicate_all(prompt_ids, neg_prompt_ids, seed):    p_prompt_ids = replicate(prompt_ids)    p_neg_prompt_ids = replicate(neg_prompt_ids)    rng = jax.random.PRNGKey(seed)    rng = jax.random.split(rng, NUM_DEVICES)    return p_prompt_ids, p_neg_prompt_ids, rng

Теперь мы готовы объединить все это в функции генерации:

def generate(    prompt,    negative_prompt,    seed=default_seed,    guidance_scale=default_guidance_scale,    num_inference_steps=default_num_steps,):    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)    images = pipeline(        prompt_ids,        p_params,        rng,        num_inference_steps=num_inference_steps,        neg_prompt_ids=neg_prompt_ids,        guidance_scale=guidance_scale,        jit=True,    ).images    # преобразование изображений в PIL    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])    return pipeline.numpy_to_pil(np.array(images))

jit=True указывает, что мы хотим, чтобы вызов pipeline был скомпилирован. Это произойдет при первом вызове generate, и это будет очень медленно – JAX должен проследить операции, оптимизировать их и преобразовать их в низкоуровневые примитивы. В первый раз это заняло около трех минут.

start = time.time()print(f"Компилируется...")generate(default_prompt, default_neg_prompt)print(f"Скомпилировано за {time.time() - start}")

Это заняло около 3 минут при первом запуске. Но после того, как код был скомпилирован, инференция будет выполняться очень быстро. Давайте попробуем еще раз!

start = time.time()prompt = "lama в Древней Греции, масло на холсте"neg_prompt = "мультфильм, иллюстрация, анимация"images = generate(prompt, neg_prompt)print(f"Инференция заняла {time.time() - start}")

Теперь генерация 4 изображений заняла около 2 секунд!

Тестовое сравнение

Следующие значения получены при запуске SDXL 1.0 base в течение 20 шагов с использованием расписания Euler Discrete по умолчанию. Мы сравниваем Cloud TPU v5e с TPUv4 для тех же размеров пакета. Обратите внимание, что из-за параллелизма TPU v5e-4, как в нашем демо, сгенерирует 4 изображения при использовании размера пакета 1 (или 8 изображений при размере пакета 2). Аналогично, TPU v5e-8 сгенерирует 8 изображений при использовании размера пакета 1.

Cloud TPU тестировался с использованием Python 3.10 и версией jax 0.4.16. Эти же характеристики использовались в нашем демо Space.

TPU v5e дает до 2,4 раза большую производительность/$ на SDXL по сравнению с TPU v4, что демонстрирует эффективность тактовой частоты самого последнего поколения TPU.

Для измерения производительности инференции мы используем стандартную отраслевую метрику пропускной способности. Сначала мы измеряем задержку на каждое изображение, когда модель была скомпилирована и загружена. Затем мы вычисляем пропускную способность, разделив размер пакета на задержку каждого чипа. Таким образом, пропускная способность измеряет, как модель работает в производственных средах, независимо от количества используемых чипов. Затем мы делим пропускную способность на цену, чтобы получить производительность за доллар.

Как работает демо?

Демо, которое мы показали ранее, было создано с помощью скрипта, который в основном следует за кодом, который мы опубликовали в этом блоге. Он работает на нескольких устройствах Cloud TPU v5e с 4 чипами каждое, и есть простой балансировщик нагрузки, который случайным образом направляет запросы пользователей на бэкэнд-серверы. Когда вы вводите запрос в демо, ваш запрос будет назначен одному из бэкэнд-серверов, и вы получите 4 сгенерированных изображения.

Это простое решение, основанное на нескольких предварительно выделенных TPU-устройствах. В нашей будущей публикации мы рассмотрим, как создавать динамические решения, которые адаптируются к нагрузке с использованием GKE.

Весь код для демо-версии является открытым и доступен в Hugging Face Diffusers сегодня. Мы с нетерпением ждем, что вы создадите с Diffusers + JAX + Cloud TPUs!