Настройте модели Whisper на Amazon SageMaker с помощью LoRA

Программируйте модели Whisper на Amazon SageMaker с помощью LoRA

Whisper – это модель распознавания речи (ASR), которая была обучена на основе 680,000 часов надзорных данных из Интернета, охватывающих различные языки и задачи. Одним из её ограничений является низкая производительность на языках с небольшими ресурсами, таких как язык Маратхи и дравидийские языки, что может быть исправлено с помощью настройки. Однако, настройка модели Whisper стала значительной проблемой, как с точки зрения вычислительных ресурсов, так и требований к хранению. Пять до десяти запусков полной настройки модели Whisper требуют около 100 часов A100 GPU (40 GB SXM4) (варьируется в зависимости от размеров модели и параметров), и каждый обработанный проверочный точечный файл требует около 7 ГБ места на диске. Это сочетание высоких вычислительных и хранилищных требований может представлять серьезные препятствия, особенно в средах с ограниченными ресурсами, что часто делает достижение значимых результатов чрезвычайно сложным.

Адаптация с низким рангом, также известным как LoRA, использует уникальный подход к настройке модели. Она поддерживает предварительно обученные веса модели в статическом состоянии и вводит обучаемые матрицы декомпозиции ранга в каждый слой структуры Transformer. Этот метод может уменьшить количество обучаемых параметров, необходимых для последующих задач в 10,000 раз и уменьшить требования к памяти GPU в 3 раза. В терминах качества модели, LoRA показывает сопоставимые или даже превосходящие результаты по сравнению с традиционными методами настройки модели, при этом оперируя с меньшим количеством обучаемых параметров (см. результаты из оригинальной статьи о LoRA). Он также обеспечивает преимущество увеличенной скорости обучения. В отличие от методов с использованием адаптеров, LoRA не вводит дополнительную задержку во время вывода, тем самым сохраняя эффективность модели во время развертывания. Настройка Whisper с использованием LoRA показала многообещающие результаты. Например, Whisper-Large-v2: выполнение 3 эпох с набором данных голоса Common Voice на видеокарте с 8 ГБ памяти занимает 6-8 часов, что в 5 раз быстрее, чем полная настройка с сопоставимой производительностью.

Amazon SageMaker – это идеальная платформа для реализации настройки модели Whisper с помощью LoRA. Amazon SageMaker позволяет создавать, обучать и развертывать модели машинного обучения для любых случаев использования с полностью управляемой инфраструктурой, инструментами и рабочими процессами. Дополнительные преимущества обучения модели могут включать более низкие затраты на обучение с помощью Managed Spot Training, распределенные библиотеки обучения для разделения моделей и наборов данных на графических процессорных инстансах AWS, и многое другое. Обученные модели SageMaker могут быть легко развернуты для вывода непосредственно на SageMaker. В этом посте мы представляем пошаговое руководство по реализации настройки модели с использованием LoRA в SageMaker. Исходный код, связанный с этой реализацией, можно найти на GitHub.

Подготовка набора данных для настройки модели

Мы используем язык с ограниченными ресурсами – Маратхи для задачи настройки модели. С использованием библиотеки datasets от Hugging Face, вы можете загрузить и разделить набор данных Common Voice на тренировочный и тестовый наборы данных. Вот пример кода:

from datasets import load_dataset, DatasetDictlanguage = "Marathi"language_abbr = "mr"task = "transcribe"dataset_name = "mozilla-foundation/common_voice_11_0"common_voice = DatasetDict()common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", use_auth_token=True)common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", use_auth_token=True)

Модель распознавания речи Whisper требует, чтобы аудио-файлы были WAV-файлами с частотой дискретизации 16 кГц, моно и с 16-битными знаковыми целыми числами. Поскольку набор данных Common Voice имеет частоту дискретизации 48кГц, вам необходимо сначала выполнить поддискретизацию аудио-файлов. Затем вам нужно применить экстрактор функций Whisper, чтобы извлечь лог-мел-спектрограммные признаки из аудио, и применить токенизатор Whisper, чтобы преобразовать каждое предложение в транскрипте в токенизированный идентификатор. Вот пример кода:

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]
    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

# apply the data preparation function to all of our fine-tuning dataset samples using dataset's .map method
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
common_voice.save_to_disk("marathi-common-voice-processed")
!aws s3 cp --recursive "marathi-common-voice-processed" s3://

После обработки всех образцов обучения загрузите обработанные данные в Amazon S3, чтобы при использовании обработанных данных обучения на этапе тонкой настройки можно было использовать FastFile для прямого монтирования файла S3, а не копирования его на локальный диск:

from sagemaker.inputs import TrainingInput

training_input_path=s3uritraining = TrainingInput(s3_data_type='S3Prefix', # Available Options: S3Prefix | ManifestFile | AugmentedManifestFile
s3_data=training_input_path,
distribution='FullyReplicated', # Available Options: FullyReplicated | ShardedByS3Key
input_mode='FastFile')

Тренировка модели

Для демонстрации мы используем pre-trained модель whisper-large-v2 (сейчас доступен whisper v3), которую можно импортировать через библиотеку Hugging Face transformers. Вы можете использовать 8-битную квантизацию для дальнейшего улучшения эффективности обучения. 8-битная квантизация предлагает оптимизацию памяти, округляя значение с плавающей точкой до 8-битных целых чисел. Это часто используемая техника сжатия модели, которая позволяет сохранить память при инференсе, минимально жертвуя точностью.

Для загрузки pre-trained модели в формате 8-битной квантизации просто добавьте аргумент load_in_8bit=True при создании модели, как показано в следующем коде. Таким образом модель будет загружена с весами, квантизированными до 8 бит, что уменьшит потребление памяти.

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")

Мы используем реализацию LoRA из пакета Hugging Face’s peft. Существуют четыре шага для тонкой настройки модели с использованием LoRA:

  1. Создание базовой модели (как мы делали на предыдущем шаге).
  2. Создание конфигурации (LoraConfig), в которой определены параметры, специфичные для LoRA.
  3. Обертывание базовой модели с помощью get_peft_model(), чтобы получить обучаемую модель PeftModel.
  4. Тренировка модели PeftModel в качестве базовой модели.

См. следующий код:

from peft import LoraConfig, get_peft_model

config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)

training_args = Seq2SeqTrainingArguments(output_dir=args.model_dir,
per_device_train_batch_size=int(args.train_batch_size),
gradient_accumulation_steps=1,
learning_rate=float(args.learning_rate),
warmup_steps=args.warmup_steps,
num_train_epochs=args.num_train_epochs,
evaluation_strategy="epoch",
fp16=True,
per_device_eval_batch_size=args.eval_batch_size,
generation_max_length=128,
logging_steps=25,
remove_unused_columns=False,
label_names=["labels"])

trainer = Seq2SeqTrainer(args=training_args,
model=model,
train_dataset=train_dataset["train"],
eval_dataset=train_dataset.get("test", train_dataset["test"]),
data_collator=data_collator,
tokenizer=processor.feature_extractor)

Чтобы запустить задание обучения SageMaker, мы используем собственный Docker-контейнер. Вы можете скачать Docker-образ с GitHub, где ffmpeg4 и git-lfs упакованы вместе с другими требованиями Python. Чтобы узнать больше о том, как адаптировать свой собственный Docker-контейнер для работы с SageMaker, обратитесь к статье Адаптация вашего собственного обучающего контейнера. Затем вы можете использовать Hugging Face Estimator и запустить задание обучения SageMaker:

OUTPUT_PATH = f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'
huggingface_estimator = HuggingFace(entry_point='train.sh', source_dir='./src', output_path=OUTPUT_PATH, instance_type=instance_type, instance_count=1, # transformers_version='4.17.0', # pytorch_version='1.10.2', py_version='py310', image_uri=<ECR-PATH>, role=ROLE, metric_definitions=metric_definitions, volume_size=200, distribution=distribution, keep_alive_period_in_seconds=1800, environment=environment,)
huggingface_estimator.fit(job_name=TRAINING_JOB_NAME, wait=False)

Внедрение LoRA позволило нам запустить задачу Whisper по крупнонастраиваемому обучению на одном экземпляре GPU (например, ml.g5.2xlarge). В сравнении с полной крупнонастраиваемой задачей Whisper требуются несколько GPU (например, ml.p4d.24xlarge) и значительно большее время обучения. Более конкретно, наш эксперимент показал, что для выполнения полной крупнонастраиваемой задачи требуется в 24 раза больше часов работы GPU по сравнению с подходом LoRA.

Оценка производительности модели

Для оценки производительности модели Whisper после крупнонастраиваемого обучения мы вычисляем процент ошибок по словам (WER) на тестовом наборе данных. WER измеряет разницу между предсказанной транскрипцией и фактической транскрипцией. Меньшее значение WER указывает на более высокую производительность. Вы можете запустить следующий скрипт на предобученной модели и модели после крупнонастраиваемого обучения и сравнить разницу в значениях WER:

metric = evaluate.load("wer")
eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)
model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
  with torch.cuda.amp.autocast():
    with torch.no_grad():
      generated_tokens = (model.generate(input_features=batch["input_features"].to("cuda"), decoder_input_ids=batch["labels"][:, :4].to("cuda"), max_new_tokens=255,).cpu().numpy())
      labels = batch["labels"].cpu().numpy()
      labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
      decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
      decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
      metric.add_batch(predictions=decoded_preds, references=decoded_labels,)
      del generated_tokens, labels, batch
  gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")

Заключение

В этой статье мы продемонстрировали крупнонастраиваемое обучение модели Whisper, передовой модели распознавания речи. В частности, мы использовали LoRA от Hugging Face и включили 8-битную квантизацию для эффективного обучения. Мы также показали, как запустить задачу обучения в SageMaker.

Хотя это важный первый шаг, существует несколько способов улучшить модель Whisper, продолжить работу над этим. В будущем рассмотрите возможность использования распределенного обучения в SageMaker для масштабирования обучения на гораздо большем наборе данных. Это позволит модели обучаться на более разнообразных и всесторонних данных, улучшая точность. Вы также можете оптимизировать задержку при обслуживании модели Whisper для обеспечения распознавания речи в реальном времени. Кроме того, вы можете расширить работу с обработкой более длинных аудио-транскрипций, что потребует изменений в архитектуре модели и схемах обучения.

Признание

Авторы выражают благодарность Парас Мехре, Джону Солу и Эвандро Франко за их глубокие замечания и рецензию на статью.