Meta AI представила wav2vec2 XLS-R («XLS-R») в конце 2021 года. XLS-R — это модель машинного обучения («ML») для изучения межъязыковых речевых представлений; и оно было обучено на более чем 400 000 часов общедоступных речевых аудиозаписей на 128 языках. После своего выпуска модель представляла собой скачок по сравнению с межъязыковой моделью XLSR-53 от Meta AI, которая обучалась примерно на 50 000 часов речевого звука на 53 языках.
В этом руководстве описаны шаги по точной настройке XLS-R для автоматического распознавания речи («ASR») с помощью Kaggle Notebook . Модель будет настроена на чилийский испанский язык, но можно выполнить общие шаги, чтобы настроить XLS-R на другие языки, которые вы желаете.
Выполнение вывода на точно настроенной модели будет описано в сопутствующем руководстве, которое делает это руководство первой из двух частей. Я решил создать отдельное руководство, посвященное выводам, поскольку это руководство по точной настройке стало немного длинным.
Предполагается, что у вас есть опыт работы с машинным обучением и вы понимаете основные концепции ASR. Новичкам может быть трудно следовать/понимать этапы сборки.
Исходная модель wav2vec2, представленная в 2020 году, была предварительно обучена на 960 часах речевого звука набора данных Librispeech и примерно 53 200 часах речевого звука набора данных LibriVox . После выпуска были доступны модели двух размеров: БАЗОВАЯ модель с 95 миллионами параметров и БОЛЬШАЯ модель с 317 миллионами параметров.
XLS-R, с другой стороны, был предварительно обучен на многоязычном речевом звуке из 5 наборов данных:
Существует 3 модели XLS-R: XLS-R (0,3B) с 300 миллионами параметров, XLS-R (1B) с 1 миллиардом параметров и XLS-R (2B) с 2 миллиардами параметров. В этом руководстве будет использоваться модель XLS-R (0.3B).
Есть несколько замечательных статей о том, как точно настроить модели wav2vev2 , и, возможно, эта статья является своего рода «золотым стандартом». Конечно, общий подход здесь имитирует подход, который вы найдете в других руководствах. Вы будете:
Однако есть три ключевых различия между этим руководством и другими:
Для завершения руководства вам потребуется:
Прежде чем приступить к созданию блокнота, возможно, будет полезно просмотреть два подраздела ниже. Они описывают:
Как упоминалось во введении , модель XLS-R будет настроена на чилийский испанский язык. Конкретным набором данных является набор данных чилийской испанской речи, разработанный Геварой-Рукозом и др. Он доступен для скачивания на OpenSLR . Набор данных состоит из двух поднаборов данных: (1) 2636 аудиозаписей говорящих на чилийском языке мужчин и (2) 1738 аудиозаписей говорящих на чилийском языке женщин.
Каждый поднабор данных включает индексный файл line_index.tsv
. Каждая строка каждого индексного файла содержит пару имен аудиофайлов и транскрипцию аудио в связанном файле, например:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
Для удобства я загрузил набор данных чилийской испанской речи в Kaggle. Существует один набор данных Kaggle для записей чилийских мужчин и один набор данных Kaggle для записей чилийских женщин . Эти наборы данных Kaggle будут добавлены в блокнот Kaggle, который вы создадите, следуя инструкциям в этом руководстве.
WER — это один из показателей, который можно использовать для измерения производительности моделей автоматического распознавания речи. WER предоставляет механизм измерения того, насколько близок текстовый прогноз к текстовой ссылке. WER достигает этого, записывая ошибки трех типов:
замены ( S
): ошибка замены фиксируется, когда предсказание содержит слово, отличное от аналогичного слова в ссылке. Например, это происходит, когда прогноз неправильно пишет слово в ссылке.
удаления ( D
): ошибка удаления записывается, когда прогноз содержит слово, которого нет в ссылке.
вставки ( I
): ошибка вставки записывается, когда предсказание не содержит слова, присутствующего в ссылке.
Очевидно, что WER работает на уровне слов. Формула показателя WER выглядит следующим образом:
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
Простой пример WER на испанском языке выглядит следующим образом:
prediction: "Él está saliendo." reference: "Él está saltando."
Визуализировать ошибки прогноза помогает таблица:
ТЕКСТ | СЛОВО 1 | СЛОВО 2 | СЛОВО 3 |
---|---|---|---|
прогноз | Эль | это | сальендо |
ссылка | Эль | это | сальтандо |
| правильный | правильный | замена |
Прогноз содержит 1 ошибку замены, 0 ошибок удаления и 0 ошибок вставки. Итак, WER для этого примера:
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
Должно быть очевидно, что коэффициент ошибок в словах не обязательно говорит нам о том, какие именно ошибки существуют. В приведенном выше примере WER определяет, что СЛОВО 3 содержит ошибку в прогнозируемом тексте, но не сообщает нам, что символы i и e в прогнозе неверны. Другие показатели, такие как частота ошибок символов («CER»), можно использовать для более точного анализа ошибок.
Теперь вы готовы приступить к созданию блокнота для точной настройки.
Ваш блокнот Kaggle должен быть настроен для отправки данных тренировочного запуска в WandB с использованием вашего ключа API WandB. Для этого вам нужно его скопировать.
www.wandb.com
.www.wandb.ai/authorize
.
xls-r-300m-chilean-spanish-asr
.Секрет Kaggle будет использоваться для безопасного хранения вашего ключа API WandB.
WANDB_API_KEY
в поле «Метка» и введите ключ API WandB для этого значения.WANDB_API_KEY
установлен.Набор речевых данных чилийского испанского языка был загружен в Kaggle в виде двух отдельных наборов данных:
Добавьте оба этих набора данных в свой блокнот Kaggle.
Следующие 32 подэтапа строят каждую из 32 ячеек блокнота точной настройки по порядку.
Первая ячейка блокнота тонкой настройки устанавливает зависимости. Установите первую ячейку:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
torchaudio
до последней версии. torchaudio
будет использоваться для загрузки аудиофайлов и повторной выборки аудиоданных.jiwer
, который необходим для использования метода load_metric
библиотеки HuggingFace Datasets
, который будет использоваться позже.Вторая ячейка импортирует необходимые пакеты Python. Установите вторую ячейку:
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
transformers
HuggingFace и связанные с ней классы Wav2Vec2*
обеспечивают основу функциональности, используемой для точной настройки.Третья ячейка импортирует метрику оценки HuggingFace WER. Установите третью ячейку на:
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
Четвертая ячейка извлекает ваш секрет WANDB_API_KEY
, который был установлен на шаге 2.2 . Установите четвертую ячейку:
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
Пятая ячейка задает константы, которые будут использоваться во всей записной книжке. Установите пятую ячейку на:
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
Шестая ячейка определяет служебные методы для чтения индексных файлов набора данных (см. подраздел «Набор обучающих данных» выше), а также для очистки текста транскрипции и создания словаря. Установите шестую ячейку на:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
Метод read_index_file_data
считывает индексный файл набора данных line_index.tsv
и создает список списков с именем аудиофайла и данными транскрипции, например:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
truncate_training_dataset
усекает данные файла индекса списка, используя константу NUM_LOAD_FROM_EACH_SET
, установленную на шаге 3.5 . В частности, константа NUM_LOAD_FROM_EACH_SET
используется для указания количества аудиосэмплов, которые должны быть загружены из каждого набора данных. Для целей данного руководства число установлено равным 1600
что означает, что в конечном итоге будет загружено в общей сложности 3200
аудиосэмплов. Чтобы загрузить все образцы, установите для NUM_LOAD_FROM_EACH_SET
строковое значение all
.clean_text
используется для удаления каждой текстовой транскрипции символов, указанных в регулярном выражении, присвоенном SPECIAL_CHARS
на шаге 3.5 . Эти символы, включая знаки препинания, можно исключить, поскольку они не несут никакой семантической ценности при обучении модели изучению сопоставлений между звуковыми функциями и транскрипцией текста.create_vocab
создает словарь из чистой текстовой транскрипции. Проще говоря, он извлекает все уникальные символы из набора очищенных текстовых транскрипций. Вы увидите пример сгенерированного словаря на шаге 3.14 . Седьмая ячейка определяет служебные методы, использующие torchaudio
для загрузки и повторной выборки аудиоданных. Установите седьмую ячейку:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
загружает указанный аудиофайл и возвращает многомерную матрицу torch.Tensor
аудиоданных вместе с частотой дискретизации аудио. Все аудиофайлы в обучающих данных имеют частоту дискретизации 48000
Гц. Эта «исходная» частота дискретизации фиксируется константой ORIG_SAMPLING_RATE
на шаге 3.5 .resample
используется для понижения частоты дискретизации аудиоданных с частоты дискретизации от 48000
до 16000
. wav2vec2 предварительно обучается на аудио, выбранном с частотой 16000
Гц. Соответственно, любой звук, используемый для точной настройки, должен иметь одинаковую частоту дискретизации. В этом случае примеры аудио необходимо уменьшить с 48000
Гц до 16000
Гц. 16000
Гц фиксируется константой TGT_SAMPLING_RATE
на шаге 3.5 .Восьмая ячейка определяет служебные методы, которые обрабатывают данные аудио и транскрипции. Установите восьмую ячейку:
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
process_speech_audio
возвращает входные значения из предоставленной обучающей выборки.process_target_text
кодирует каждую текстовую транскрипцию как список меток, то есть список индексов, ссылающихся на символы в словаре. Вы увидите образец кодировки на шаге 3.15 .Девятая ячейка является последней ячейкой служебного метода и содержит метод расчета частоты ошибок в словах между эталонной транскрипцией и прогнозируемой транскрипцией. Установите девятую ячейку на:
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
Десятая ячейка считывает индексные файлы обучающих данных для записей выступающих мужчин и записей говорящих женщин с использованием метода read_index_file_data
определенного на шаге 3.6 . Установите десятую ячейку:
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
Одиннадцатая ячейка усекает списки обучающих данных с помощью метода truncate_training_dataset
определенного на шаге 3.6 . Установите одиннадцатую ячейку на:
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
NUM_LOAD_FROM_EACH_SET
, установленная на шаге 3.5, определяет количество сохраняемых выборок из каждого набора данных. В этом руководстве для константы установлено значение 1600
, всего 3200
выборок.Двенадцатая ячейка объединяет усеченные списки обучающих данных. Установите двенадцатую ячейку:
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
Тринадцатая ячейка перебирает каждую выборку обучающих данных и очищает связанный текст транскрипции с помощью метода clean_text
определенного на шаге 3.6 . Установите тринадцатую ячейку на:
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
Четырнадцатая ячейка создает словарь, используя очищенные транскрипции из предыдущего шага и метод create_vocab
определенный в шаге 3.6 . Установите четырнадцатую ячейку на:
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
Словарь хранится в виде словаря с символами в качестве ключей и индексами словаря в качестве значений.
Вы можете напечатать vocab_dict
, который должен выдать следующий результат:
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
Пятнадцатая ячейка добавляет символ-разделитель слов |
к словарю. Установите пятнадцатую ячейку на:
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
Символ-разделитель слов используется при токенизации транскрипции текста в виде списка меток. В частности, он используется для определения конца слова и при инициализации класса Wav2Vec2CTCTokenizer
, как будет показано в шаге 3.17 .
Например, следующий список кодирует no te entiendo nada
с использованием словаря из шага 3.14 :
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
Шестнадцатая ячейка записывает словарь в файл. Установите шестнадцатую ячейку на:
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
Wav2Vec2CTCTokenizer
. Семнадцатая ячейка инициализирует экземпляр Wav2Vec2CTCTokenizer
. Установите семнадцатую ячейку на:
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
Токенизатор используется для кодирования текстовых транскрипций и декодирования списка меток обратно в текст.
Обратите внимание, что tokenizer
инициализируется с помощью [UNK]
, назначенного для unk_token
, и [PAD]
назначенного для pad_token
, причем первый используется для представления неизвестных токенов в текстовых транскрипциях, а второй используется для дополнения транскрипций при создании пакетов транскрипций разной длины. Эти два значения будут добавлены в словарь токенизатором.
Инициализация токенизатора на этом этапе также добавит в словарь два дополнительных токена, а именно <s>
и /</s>
, которые используются для разграничения начала и конца предложений соответственно.
|
на этом этапе явно присваивается word_delimiter_token
, чтобы отразить, что символ вертикальной черты будет использоваться для обозначения конца слов в соответствии с добавлением символа в словарь на шаге 3.15 . |
символ — это значение по умолчанию для word_delimiter_token
. Таким образом, его не нужно было задавать явно, но это было сделано для ясности.
Как и в случае с word_delimiter_token
, для replace_word_delimiter_char
явно назначается один пробел, что отражает тот факт, что символ вертикальной черты |
будет использоваться для замены пробелов в текстовых транскрипциях. Пустое пространство является значением по умолчанию для replace_word_delimiter_char
. Таким образом, его также не нужно было задавать явно, но это было сделано для ясности.
Вы можете распечатать полный словарь токенизатора, вызвав метод get_vocab()
в tokenizer
.
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
Восемнадцатая ячейка инициализирует экземпляр Wav2Vec2FeatureExtractor
. Установите восемнадцатую ячейку:
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
Wav2Vec2FeatureExtractor
, являются значениями по умолчанию, за исключением return_attention_mask
, который по умолчанию имеет значение False
. Значения по умолчанию показаны/передаются для ясности.feature_size
определяет размерность входных объектов (т. е. функций аудиоданных). Значение этого параметра по умолчанию — 1
.sampling_rate
сообщает экстрактору функций частоту дискретизации, с которой аудиоданные должны быть оцифрованы. Как обсуждалось в шаге 3.7 , wav2vec2 предварительно обучается на аудио, выбранном с частотой 16000
Гц, и, следовательно, 16000
является значением по умолчанию для этого параметра.padding_value
указывает значение, которое используется при дополнении аудиоданных, что требуется при пакетной обработке аудиосэмплов различной длины. Значение по умолчанию — 0.0
.do_normalize
используется, чтобы указать, следует ли преобразовать входные данные к стандартному нормальному распределению. Значение по умолчанию True
. В документации класса Wav2Vec2FeatureExtractor
отмечается, что «[нормализация] может помочь значительно улучшить производительность некоторых моделей».return_attention_mask
указывают, следует ли передавать маску внимания или нет. Для этого варианта использования установлено значение True
. Девятнадцатая ячейка инициализирует экземпляр Wav2Vec2Processor
. Установите девятнадцатую ячейку на:
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
Класс Wav2Vec2Processor
объединяет tokenizer
и feature_extractor
из шагов 3.17 и 3.18 соответственно в один процессор.
Обратите внимание, что конфигурацию процессора можно сохранить, вызвав метод save_pretrained
в экземпляре класса Wav2Vec2Processor
.
processor.save_pretrained(OUTPUT_DIR_PATH)
Двадцатая ячейка загружает каждый аудиофайл, указанный в списке all_training_samples
. Установите двадцатую ячейку на:
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
torch.Tensor
и сохраняются в all_input_data
в виде списка словарей. Каждый словарь содержит аудиоданные для определенного образца, а также текстовую транскрипцию аудио.read_audio_data
также возвращает частоту дискретизации аудиоданных. Поскольку мы знаем, что в этом случае частота дискретизации для всех аудиофайлов составляет 48000
Гц, частота дискретизации на этом этапе игнорируется.all_input_data
в кадр данных Pandas Двадцать первая ячейка преобразует список all_input_data
в DataFrame Pandas, чтобы упростить манипулирование данными. Установите двадцать первую ячейку на:
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
Двадцать вторая ячейка использует processor
, инициализированный на шаге 3.19, для извлечения признаков из каждого образца аудиоданных и кодирования каждой текстовой транскрипции в виде списка меток. Установите двадцать вторую ячейку на:
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
Двадцать третья ячейка разбивает DataFrame all_input_data_df
на наборы данных обучения и оценки (проверки), используя константу SPLIT_PCT
из шага 3.5 . Установите двадцать третью ячейку на:
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
SPLIT_PCT
равно 0.10
что означает, что 10 % всех входных данных будут храниться для оценки, а 90 % данных будут использоваться для обучения/тонкой настройки.Dataset
Двадцать четвертая ячейка преобразует кадры данных train_data_df
и valid_data_df
в объекты Dataset
. Установите двадцать четвертую ячейку на:
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
Объекты Dataset
используются экземплярами класса HuggingFace Trainer
, как вы увидите на шаге 3.30 .
Эти объекты содержат метаданные о наборе данных, а также сам набор данных.
Вы можете распечатать train_data
и valid_data
, чтобы просмотреть метаданные для обоих объектов Dataset
.
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
Двадцать пятая ячейка инициализирует предварительно обученную модель XLS-R (0,3). Установите двадцать пятую ячейку на:
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
from_pretrained
, вызываемый в Wav2Vec2ForCTC
указывает, что мы хотим загрузить предварительно обученные веса для указанной модели.MODEL
была указана на шаге 3.5 и имела значение facebook/wav2vec2-xls-r-300m
что соответствует модели XLS-R (0.3).ctc_loss_reduction
указывает тип сокращения, применяемого к выходным данным функции потерь коннекционистской временной классификации («CTC»). Потери CTC используются для расчета потерь между непрерывным вводом, в данном случае аудиоданными, и целевой последовательностью, в данном случае транскрипцией текста. Если установить значение mean
, выходные потери для пакета входов будут разделены на целевые длины. Затем рассчитывается среднее значение по партии, и уменьшение применяется к значениям потерь.pad_token_id
указывает токен, который будет использоваться для заполнения при пакетной обработке. Ему присваивается идентификатор [PAD]
, установленный при инициализации токенизатора на шаге 3.17 .vocab_size
определяет размер словаря модели. Это размер словаря после инициализации токенизатора на шаге 3.17 , который отражает количество узлов выходного уровня прямой части сети.Двадцать шестая ячейка замораживает предварительно обученные веса экстрактора признаков. Установите двадцать шестую ячейку на:
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
Двадцать седьмая ячейка инициализирует аргументы обучения, которые будут переданы экземпляру Trainer
. Установите двадцать седьмую ячейку на:
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
TrainingArguments
принимает более 100 параметров .save_safetensors
, имеющий значение False
указывает, что точно настроенная модель должна быть сохранена в файле pickle
вместо использования формата safetensors
.group_by_length
, когда True
указывает, что образцы примерно одинаковой длины должны быть сгруппированы вместе. Это сводит к минимуму заполнение и повышает эффективность тренировки.per_device_train_batch_size
устанавливает количество выборок на обучающий мини-пакет. Для этого параметра установлено значение 18
с помощью константы TRAIN_BATCH_SIZE
назначенной на шаге 3.5 . Это подразумевает 160 шагов за эпоху.per_device_eval_batch_size
устанавливает количество выборок на мини-пакет оценки (удержания). Для этого параметра установлено значение 10
с помощью константы EVAL_BATCH_SIZE
назначенной на шаге 3.5 .num_train_epochs
устанавливает количество эпох обучения. Для этого параметра установлено значение 30
с помощью константы TRAIN_EPOCHS
назначенной на шаге 3.5 . Это подразумевает 4800 общих шагов во время тренировки.gradient_checkpointing
, когда True
помогает экономить память путем расчета градиента контрольных точек, но приводит к более медленным обратным проходам.evaluation_strategy
установлено значение steps
это означает, что оценка будет выполняться и регистрироваться во время обучения с интервалом, указанным параметром eval_steps
.logging_strategy
установлено значение steps
это означает, что статистика обучающего запуска будет записываться с интервалом, указанным параметром logging_steps
.save_strategy
установлено steps
это означает, что контрольная точка точно настроенной модели будет сохраняться с интервалом, указанным параметром save_steps
.eval_steps
устанавливает количество шагов между оценками контрольных данных. Для этого параметра установлено значение 100
с помощью константы EVAL_STEPS
, назначенной на шаге 3.5 .save_steps
устанавливает количество шагов, после которых сохраняется контрольная точка точно настроенной модели. Для этого параметра установлено значение 3200
с помощью константы SAVE_STEPS
, назначенной на шаге 3.5 .logging_steps
устанавливает количество шагов между журналами статистики обучающего запуска. Для этого параметра установлено значение 100
с помощью константы LOGGING_STEPS
, назначенной на шаге 3.5 .learning_rate
устанавливает начальную скорость обучения. Этот параметр имеет значение 1e-4
с помощью константы LEARNING_RATE
, назначенной на шаге 3.5 .warmup_steps
задает количество шагов для линейного повышения скорости обучения от 0 до значения, установленного learning_rate
. Этот параметр имеет значение 800
с помощью константы WARMUP_STEPS
, назначенной на шаге 3.5 .Двадцать восьмая ячейка определяет логику динамического заполнения входных и целевых последовательностей. Установите двадцать восьмую ячейку на:
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
Trainer
, который будет инициализирован на мгновение на шаге 3.30 . Поскольку входные последовательности и последовательности меток различаются по длине в каждом мини-пакете, некоторые последовательности необходимо дополнить, чтобы все они имели одинаковую длину.DataCollatorCTCWithPadding
динамически дополняет мини-пакетные данные. Параметр padding
, если ему присвоено значение True
указывает, что более короткие последовательности функций аудиовхода и последовательности меток должны иметь ту же длину, что и самая длинная последовательность в мини-пакете.0.0
установленным при инициализации экстрактора функций на шаге 3.18 .-100
, поэтому эти метки игнорируются при вычислении метрики WER.Двадцать девятая ячейка инициализирует экземпляр средства сортировки данных, определенного на предыдущем шаге. Установите двадцать девятую ячейку на:
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
Тридцатая ячейка инициализирует экземпляр класса Trainer
. Установите тридцатую ячейку на:
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
Trainer
инициализируется с помощью:model
, инициализированная на шаге 3.25 .Dataset
train_data
из шага 3.24 .Dataset
valid_data
из шага 3.24 .tokenizer
назначается processor.feature_extractor
и работает с data_collator
для автоматического дополнения входных данных до входных данных максимальной длины каждого мини-пакета. Тридцать первая ячейка вызывает метод train
экземпляра класса Trainer
для точной настройки модели. Установите тридцать первую ячейку на:
### CELL 31: Finetune the model ### trainer.train()
Тридцать вторая ячейка — последняя ячейка блокнота. Он сохраняет настроенную модель, вызывая метод save_model
в экземпляре Trainer
. Установите тридцать вторую ячейку на:
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
Теперь, когда все ячейки блокнота собраны, пришло время приступить к тонкой настройке.
Настройте ноутбук Kaggle для работы с ускорителем NVIDIA GPU P100 .
Зафиксируйте блокнот на Kaggle.
Отслеживайте данные тренировочного прогона, войдя в свою учетную запись WandB и найдя соответствующий прогон.
Обучение в течение 30 эпох должно занять около 5 часов с использованием ускорителя NVIDIA GPU P100. WER для неактивных данных должен упасть до ~0,15 в конце обучения. Это не совсем современный результат, но точно настроенная модель по-прежнему достаточно полезна для многих приложений.
Точно настроенная модель будет выведена в каталог Kaggle, указанный константой OUTPUT_DIR_PATH
указанной в шаге 3.5 . Выходные данные модели должны включать следующие файлы:
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
Эти файлы можно загрузить локально. Кроме того, вы можете создать новую модель Kaggle, используя файлы модели. Модель Kaggle будет использоваться вместе с сопутствующим руководством по выводу для выполнения вывода на основе точно настроенной модели.
Поздравляем с тонкой настройкой wav2vec2 XLS-R! Помните, что вы можете использовать эти общие шаги для точной настройки модели на других языках, которые вам нужны. Сделать вывод на основе точно настроенной модели, созданной в этом руководстве, довольно просто. Этапы вывода будут изложены в отдельном сопутствующем руководстве к этому руководству. Пожалуйста, выполните поиск по моему имени пользователя HackerNoon, чтобы найти сопутствующее руководство.