Настройте модели Whisper на Amazon SageMaker с помощью LoRA
Программируйте модели Whisper на Amazon SageMaker с помощью LoRA
Whisper – это модель распознавания речи (ASR), которая была обучена на основе 680,000 часов надзорных данных из Интернета, охватывающих различные языки и задачи. Одним из её ограничений является низкая производительность на языках с небольшими ресурсами, таких как язык Маратхи и дравидийские языки, что может быть исправлено с помощью настройки. Однако, настройка модели Whisper стала значительной проблемой, как с точки зрения вычислительных ресурсов, так и требований к хранению. Пять до десяти запусков полной настройки модели Whisper требуют около 100 часов A100 GPU (40 GB SXM4) (варьируется в зависимости от размеров модели и параметров), и каждый обработанный проверочный точечный файл требует около 7 ГБ места на диске. Это сочетание высоких вычислительных и хранилищных требований может представлять серьезные препятствия, особенно в средах с ограниченными ресурсами, что часто делает достижение значимых результатов чрезвычайно сложным.
Адаптация с низким рангом, также известным как LoRA, использует уникальный подход к настройке модели. Она поддерживает предварительно обученные веса модели в статическом состоянии и вводит обучаемые матрицы декомпозиции ранга в каждый слой структуры Transformer. Этот метод может уменьшить количество обучаемых параметров, необходимых для последующих задач в 10,000 раз и уменьшить требования к памяти GPU в 3 раза. В терминах качества модели, LoRA показывает сопоставимые или даже превосходящие результаты по сравнению с традиционными методами настройки модели, при этом оперируя с меньшим количеством обучаемых параметров (см. результаты из оригинальной статьи о LoRA). Он также обеспечивает преимущество увеличенной скорости обучения. В отличие от методов с использованием адаптеров, LoRA не вводит дополнительную задержку во время вывода, тем самым сохраняя эффективность модели во время развертывания. Настройка Whisper с использованием LoRA показала многообещающие результаты. Например, Whisper-Large-v2: выполнение 3 эпох с набором данных голоса Common Voice на видеокарте с 8 ГБ памяти занимает 6-8 часов, что в 5 раз быстрее, чем полная настройка с сопоставимой производительностью.
Amazon SageMaker – это идеальная платформа для реализации настройки модели Whisper с помощью LoRA. Amazon SageMaker позволяет создавать, обучать и развертывать модели машинного обучения для любых случаев использования с полностью управляемой инфраструктурой, инструментами и рабочими процессами. Дополнительные преимущества обучения модели могут включать более низкие затраты на обучение с помощью Managed Spot Training, распределенные библиотеки обучения для разделения моделей и наборов данных на графических процессорных инстансах AWS, и многое другое. Обученные модели SageMaker могут быть легко развернуты для вывода непосредственно на SageMaker. В этом посте мы представляем пошаговое руководство по реализации настройки модели с использованием LoRA в SageMaker. Исходный код, связанный с этой реализацией, можно найти на GitHub.
Подготовка набора данных для настройки модели
Мы используем язык с ограниченными ресурсами – Маратхи для задачи настройки модели. С использованием библиотеки datasets от Hugging Face, вы можете загрузить и разделить набор данных Common Voice на тренировочный и тестовый наборы данных. Вот пример кода:
- Три пути, которыми генеративный искусственный интеллект может усилить кибербезопасность
- Оптимизация генеративного искусственного интеллекта с усилением восстановления архитектура, алгоритмы и обзор применений
- Эта статья по искусственному интеллекту представляет GLaMM (Grounding Large Multimodal Model) многослойную модель, обученную от начала до конца и оснащенную возможностью визуальной фиксации с гибкостью работы с изображениями и регионами ввода.
from datasets import load_dataset, DatasetDictlanguage = "Marathi"language_abbr = "mr"task = "transcribe"dataset_name = "mozilla-foundation/common_voice_11_0"common_voice = DatasetDict()common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", use_auth_token=True)common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", use_auth_token=True)
Модель распознавания речи Whisper требует, чтобы аудио-файлы были WAV-файлами с частотой дискретизации 16 кГц, моно и с 16-битными знаковыми целыми числами. Поскольку набор данных Common Voice имеет частоту дискретизации 48кГц, вам необходимо сначала выполнить поддискретизацию аудио-файлов. Затем вам нужно применить экстрактор функций Whisper, чтобы извлечь лог-мел-спектрограммные признаки из аудио, и применить токенизатор Whisper, чтобы преобразовать каждое предложение в транскрипте в токенизированный идентификатор. Вот пример кода:
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)
def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch["audio"]
# compute log-Mel input features from input audio array
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# encode target text to label ids
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch
# apply the data preparation function to all of our fine-tuning dataset samples using dataset's .map method
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
common_voice.save_to_disk("marathi-common-voice-processed")
!aws s3 cp --recursive "marathi-common-voice-processed" s3://
После обработки всех образцов обучения загрузите обработанные данные в Amazon S3, чтобы при использовании обработанных данных обучения на этапе тонкой настройки можно было использовать FastFile для прямого монтирования файла S3, а не копирования его на локальный диск:
from sagemaker.inputs import TrainingInput
training_input_path=s3uritraining = TrainingInput(s3_data_type='S3Prefix', # Available Options: S3Prefix | ManifestFile | AugmentedManifestFile
s3_data=training_input_path,
distribution='FullyReplicated', # Available Options: FullyReplicated | ShardedByS3Key
input_mode='FastFile')
Тренировка модели
Для демонстрации мы используем pre-trained модель whisper-large-v2 (сейчас доступен whisper v3), которую можно импортировать через библиотеку Hugging Face transformers. Вы можете использовать 8-битную квантизацию для дальнейшего улучшения эффективности обучения. 8-битная квантизация предлагает оптимизацию памяти, округляя значение с плавающей точкой до 8-битных целых чисел. Это часто используемая техника сжатия модели, которая позволяет сохранить память при инференсе, минимально жертвуя точностью.
Для загрузки pre-trained модели в формате 8-битной квантизации просто добавьте аргумент load_in_8bit=True при создании модели, как показано в следующем коде. Таким образом модель будет загружена с весами, квантизированными до 8 бит, что уменьшит потребление памяти.
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
Мы используем реализацию LoRA из пакета Hugging Face’s peft. Существуют четыре шага для тонкой настройки модели с использованием LoRA:
- Создание базовой модели (как мы делали на предыдущем шаге).
- Создание конфигурации (LoraConfig), в которой определены параметры, специфичные для LoRA.
- Обертывание базовой модели с помощью get_peft_model(), чтобы получить обучаемую модель PeftModel.
- Тренировка модели PeftModel в качестве базовой модели.
См. следующий код:
from peft import LoraConfig, get_peft_model
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config)
training_args = Seq2SeqTrainingArguments(output_dir=args.model_dir,
per_device_train_batch_size=int(args.train_batch_size),
gradient_accumulation_steps=1,
learning_rate=float(args.learning_rate),
warmup_steps=args.warmup_steps,
num_train_epochs=args.num_train_epochs,
evaluation_strategy="epoch",
fp16=True,
per_device_eval_batch_size=args.eval_batch_size,
generation_max_length=128,
logging_steps=25,
remove_unused_columns=False,
label_names=["labels"])
trainer = Seq2SeqTrainer(args=training_args,
model=model,
train_dataset=train_dataset["train"],
eval_dataset=train_dataset.get("test", train_dataset["test"]),
data_collator=data_collator,
tokenizer=processor.feature_extractor)
Чтобы запустить задание обучения SageMaker, мы используем собственный Docker-контейнер. Вы можете скачать Docker-образ с GitHub, где ffmpeg4 и git-lfs упакованы вместе с другими требованиями Python. Чтобы узнать больше о том, как адаптировать свой собственный Docker-контейнер для работы с SageMaker, обратитесь к статье Адаптация вашего собственного обучающего контейнера. Затем вы можете использовать Hugging Face Estimator и запустить задание обучения SageMaker:
OUTPUT_PATH = f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/'
huggingface_estimator = HuggingFace(entry_point='train.sh', source_dir='./src', output_path=OUTPUT_PATH, instance_type=instance_type, instance_count=1, # transformers_version='4.17.0', # pytorch_version='1.10.2', py_version='py310', image_uri=<ECR-PATH>, role=ROLE, metric_definitions=metric_definitions, volume_size=200, distribution=distribution, keep_alive_period_in_seconds=1800, environment=environment,)
huggingface_estimator.fit(job_name=TRAINING_JOB_NAME, wait=False)
Внедрение LoRA позволило нам запустить задачу Whisper по крупнонастраиваемому обучению на одном экземпляре GPU (например, ml.g5.2xlarge). В сравнении с полной крупнонастраиваемой задачей Whisper требуются несколько GPU (например, ml.p4d.24xlarge) и значительно большее время обучения. Более конкретно, наш эксперимент показал, что для выполнения полной крупнонастраиваемой задачи требуется в 24 раза больше часов работы GPU по сравнению с подходом LoRA.
Оценка производительности модели
Для оценки производительности модели Whisper после крупнонастраиваемого обучения мы вычисляем процент ошибок по словам (WER) на тестовом наборе данных. WER измеряет разницу между предсказанной транскрипцией и фактической транскрипцией. Меньшее значение WER указывает на более высокую производительность. Вы можете запустить следующий скрипт на предобученной модели и модели после крупнонастраиваемого обучения и сравнить разницу в значениях WER:
metric = evaluate.load("wer")
eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator)
model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
with torch.cuda.amp.autocast():
with torch.no_grad():
generated_tokens = (model.generate(input_features=batch["input_features"].to("cuda"), decoder_input_ids=batch["labels"][:, :4].to("cuda"), max_new_tokens=255,).cpu().numpy())
labels = batch["labels"].cpu().numpy()
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
metric.add_batch(predictions=decoded_preds, references=decoded_labels,)
del generated_tokens, labels, batch
gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")
Заключение
В этой статье мы продемонстрировали крупнонастраиваемое обучение модели Whisper, передовой модели распознавания речи. В частности, мы использовали LoRA от Hugging Face и включили 8-битную квантизацию для эффективного обучения. Мы также показали, как запустить задачу обучения в SageMaker.
Хотя это важный первый шаг, существует несколько способов улучшить модель Whisper, продолжить работу над этим. В будущем рассмотрите возможность использования распределенного обучения в SageMaker для масштабирования обучения на гораздо большем наборе данных. Это позволит модели обучаться на более разнообразных и всесторонних данных, улучшая точность. Вы также можете оптимизировать задержку при обслуживании модели Whisper для обеспечения распознавания речи в реальном времени. Кроме того, вы можете расширить работу с обработкой более длинных аудио-транскрипций, что потребует изменений в архитектуре модели и схемах обучения.
Признание
Авторы выражают благодарность Парас Мехре, Джону Солу и Эвандро Франко за их глубокие замечания и рецензию на статью.