Introduction
This is a companion guide to Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition (the "Part 1 guide"), which I wrote to explain how to finetune Meta AI's wav2vec2 XLS-R ("XLS-R") model on Chilean Spanish. This guide assumes that you have completed the Part 1 guide and generated your own finetuned XLS-R model. It explains the steps to run inference on that model via a Kaggle Notebook.
Prerequisites and Before You Get Started
To complete the guide, you will need to have:
- A finetuned XLS-R model for the Spanish language.
- An existing Kaggle account.
- Intermediate knowledge of Python.
- Intermediate knowledge of working with Kaggle Notebooks.
- Intermediate knowledge of ML concepts.
- Basic knowledge of ASR concepts.
Building the Inference Notebook
Step 1 - Setting Up Your Kaggle Environment
Step 1.1 - Creating a New Kaggle Notebook
- Log in to Kaggle.
- Create a new Kaggle Notebook.
- The name of the notebook can be changed as desired. This guide uses the notebook name spanish-asr-inference.
Step 1.2 - Adding the Test Datasets
This guide uses the Peruvian Spanish Speech Data Set as its source for test data. Like the Chilean Spanish Speech Data Set, the Peruvian speakers dataset also consists of two sub-datasets: 2,918 recordings of male Peruvian speakers and 2,529 recordings of female Peruvian speakers.
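This dataset has been uploaded to Kaggle as 2 distinct datasets, one with the male recordings and one with the female recordings. Once added, they are available in the notebook at /kaggle/input/google-spanish-speakers-peru-male/ and /kaggle/input/google-spanish-speakers-peru-female/, the paths used in Step 2.4.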
Add both of these datasets to your Kaggle Notebook by clicking on Add Input.
Step 1.3 - Adding the Finetuned Model
In Step 4 of the Part 1 guide, you should have saved your finetuned model as a Kaggle Model.
Add your finetuned model to your Kaggle Notebook by clicking on Add Input.
Step 2 - Building the Inference Notebook
The following 16 sub-steps build each of the inference notebook's 16 cells in order. You will note that many of the same utility methods from the Part 1 guide are used here.
Step 2.1 - CELL 1: Installing Packages
The first cell of the inference notebook installs dependencies. Set the first cell to:
### CELL 1: Install Packages ###
!pip install --upgrade torchaudio
!pip install jiwer
Step 2.2 - CELL 2: Importing Python Packages
The second cell imports required Python packages. Set the second cell to:
### CELL 2: Import Python packages ###
import re
import math
import random
import pandas as pd
import torchaudio
from datasets import load_metric
from transformers import pipeline
Step 2.3 - CELL 3: Loading WER Metric
The third cell loads the HuggingFace WER evaluation metric. Set the third cell to:
### CELL 3: Load WER metric ###
wer_metric = load_metric("wer")
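- WER will be used to measure the performance of the finetuned model on test data. Note that recent releases of the datasets library (3.0 and later) no longer provide load_metric; if the import in Cell 2 fails, the evaluate library (pip install evaluate) offers the same metric via evaluate.load("wer").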
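WER counts the word substitutions (S), deletions (D), and insertions (I) needed to turn a prediction into its reference, divided by the number of words in the reference (N): WER = (S + D + I) / N. The sketch below is a minimal plain-Python illustration using word-level edit distance. It is not the implementation behind load_metric("wer") (which relies on jiwer), though for simple whitespace-separated text like this it should give the same value. The helper name word_error_rate is hypothetical, and the example pair is sample 5 from the WER Analysis section.

```python
def word_error_rate(reference: str, prediction: str) -> float:
    """Word-level WER: (substitutions + deletions + insertions) / reference length."""
    ref = reference.split()
    hyp = prediction.split()
    # dist[i][j] = minimum edits turning the first i reference words
    # into the first j predicted words
    dist = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dist[i][0] = i  # delete all i reference words
    for j in range(len(hyp) + 1):
        dist[0][j] = j  # insert all j predicted words
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dist[i][j] = min(
                dist[i - 1][j] + 1,         # deletion
                dist[i][j - 1] + 1,         # insertion
                dist[i - 1][j - 1] + cost,  # substitution (or match if cost == 0)
            )
    # Assumes a non-empty reference
    return dist[len(ref)][len(hyp)] / len(ref)


reference = "cuántos albergues se abrieron después del terremoto del diecinueve de septiembre"
prediction = "cuántos albergues se abrieron después del terremoto del diecinueve de setiembre"
print(word_error_rate(reference, prediction))  # 1 substitution / 11 words
```

One substitution over an 11-word reference gives 1/11 ≈ 0.0909, matching the per-sample WER reported for that sample.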
Step 2.4 - CELL 4: Setting Constants
The fourth cell sets constants that will be used throughout the notebook. Set the fourth cell to:
### CELL 4: Constants ###
# Testing data
TEST_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-peru-male/"
TEST_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-peru-female/"
EXT = ".wav"
NUM_LOAD_FROM_EACH_SET = 3
# Special characters (should match the SPECIAL_CHARS used for finetuning in the Part 1 guide)
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"
# Sampling rates
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000
Step 2.5 - CELL 5: Utility Methods for Reading Index Files, Cleaning Text, and Generating Random Samples
The fifth cell defines utility methods for reading the dataset index files, as well as for cleaning transcription text and generating a random set of samples from test data. Set the fifth cell to:
### CELL 5: Utility methods for reading index files, cleaning text, random indices generator ###
def read_index_file_data(path: str, filename: str):
    # Each line of the index file is "<file id>\t<transcription>"
    data = []
    with open(path + filename, "r", encoding = "utf8") as f:
        lines = f.readlines()
    for line in lines:
        file_and_text = line.split("\t")
        data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")])
    return data

def clean_text(text: str) -> str:
    # Remove the special characters, then lowercase
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    cleaned_text = cleaned_text.lower()
    return cleaned_text

def get_random_samples(dataset: list, num: int) -> list:
    # Pick num distinct random entries from dataset (requires num <= len(dataset))
    used = []
    samples = []
    for i in range(num):
        a = -1
        while a == -1 or a in used:
            a = math.floor(len(dataset) * random.random())
        samples.append(dataset[a])
        used.append(a)
    return samples
- The read_index_file_data method reads a line_index.tsv dataset index file and produces a list of lists with audio filename and transcription data, e.g.:
[
["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739.wav", "Es un viaje de negocios solamente voy por una noche"],
...
]
- The clean_text method strips each text transcription of the characters specified by the regular expression assigned to SPECIAL_CHARS in Step 2.4, then lowercases it. These characters, including punctuation, are removed because they provide no semantic value for learning mappings between audio features and text transcriptions. Since the finetuned model was trained on text cleaned the same way, the reference transcriptions must be cleaned identically for the WER comparison to be fair.
- The get_random_samples method returns a set of random test samples, with the quantity set by the constant NUM_LOAD_FROM_EACH_SET in Step 2.4.
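To make the index format and the cleaning step concrete, the sketch below parses one line in the line_index.tsv format and cleans its transcription. It uses a hypothetical, simplified stand-in for SPECIAL_CHARS (digits and a few common punctuation marks) and a hypothetical index line with punctuation added for illustration; parse_index_line is likewise an illustrative helper, not part of the notebook.

```python
import re

# Hypothetical, simplified stand-in for SPECIAL_CHARS from Step 2.4:
# digits plus a few common punctuation marks, including ¡ and ¿
SPECIAL_CHARS_SUBSET = r"[\d\,\-\;\!\¡\?\¿\"\'\:\/\.]"
EXT = ".wav"


def parse_index_line(path: str, line: str) -> list:
    # Hypothetical helper: a line_index.tsv line is "<file id>\t<transcription>\n"
    file_id, text = line.rstrip("\n").split("\t")
    return [path + file_id + EXT, text]


def clean_text(text: str) -> str:
    # Remove the special characters, then lowercase (same order as Step 2.5)
    return re.sub(SPECIAL_CHARS_SUBSET, "", text).lower()


# Hypothetical index line; the punctuation is added for illustration
line = "clm_08421_01719502739\t¿Es un viaje de negocios? ¡Solamente voy por una noche!\n"
sample = parse_index_line("/kaggle/input/google-spanish-speakers-chile-male/", line)
sample[1] = clean_text(sample[1])
print(sample)
```

The cleaned text comes out as the lowercase, punctuation-free form shown in the example above. Digits are removed outright rather than spelled out.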
Step 2.6 - CELL 6: Utility Methods for Loading and Resampling Audio Data
The sixth cell defines utility methods using torchaudio to load and resample audio data. Set the sixth cell to:
### CELL 6: Utility methods for loading and resampling audio data ###
def read_audio_data(file):
    # Returns a (channels, frames) tensor and the file's sampling rate
    speech_array, sampling_rate = torchaudio.load(file, normalize = True)
    return speech_array, sampling_rate

def resample(waveform):
    # Downsample from ORIG_SAMPLING_RATE to TGT_SAMPLING_RATE, keeping the first channel
    transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    return waveform[0]
- The read_audio_data method loads a specified audio file and returns a torch.Tensor multi-dimensional matrix of the audio data along with the sampling rate of the audio. As with the training data in the Part 1 guide, all the audio files in the test data are assumed to have a sampling rate of 48000 Hz. This "original" sampling rate is captured by the constant ORIG_SAMPLING_RATE in Step 2.4.
- The resample method is used to downsample audio data from a sampling rate of 48000 Hz to the target sampling rate of 16000 Hz.
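wav2vec2-family models such as XLS-R expect 16 kHz input audio, hence the downsampling. Since 48000 / 16000 = 3, resampling keeps a clip's duration while keeping one-third as many samples. The arithmetic sketch below assumes a hypothetical 5-second clip; resampled_length is an illustrative helper, and torchaudio's exact rounding of the output length may differ slightly for lengths not divisible by 3.

```python
from math import gcd

ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000


def resampled_length(num_samples: int, orig_sr: int, tgt_sr: int) -> int:
    # Hypothetical helper: sample count after resampling, i.e. the clip's
    # duration (num_samples / orig_sr) times the target rate, rounded down
    return num_samples * tgt_sr // orig_sr


# Reduced resampling ratio: 48000:16000 simplifies to 3:1
g = gcd(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
ratio = (ORIG_SAMPLING_RATE // g, TGT_SAMPLING_RATE // g)

n_48k = 5 * ORIG_SAMPLING_RATE  # hypothetical 5-second clip: 240000 samples
n_16k = resampled_length(n_48k, ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)

print(ratio)  # (3, 1)
print(n_48k, n_16k)  # 240000 80000
print(n_48k / ORIG_SAMPLING_RATE, n_16k / TGT_SAMPLING_RATE)  # 5.0 5.0
```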
Step 2.7 - CELL 7: Reading Test Data
The seventh cell reads the test data index files for the recordings of male speakers and the recordings of female speakers using the read_index_file_data method defined in Step 2.5. Set the seventh cell to:
### CELL 7: Read test data ###
test_data_male = read_index_file_data(TEST_DATA_PATH_MALE, "line_index.tsv")
test_data_female = read_index_file_data(TEST_DATA_PATH_FEMALE, "line_index.tsv")
Step 2.8 - CELL 8: Generating Lists of Random Test Samples
The eighth cell generates sets of random test samples using the get_random_samples method defined in Step 2.5. Set the eighth cell to:
### CELL 8: Generate lists of random test samples ###
random_test_samples_male = get_random_samples(test_data_male, NUM_LOAD_FROM_EACH_SET)
random_test_samples_female = get_random_samples(test_data_female, NUM_LOAD_FROM_EACH_SET)
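Functionally, get_random_samples draws num distinct entries uniformly at random, which is what the standard library's random.sample does. The sketch below is an alternative implementation under that reading; unlike the loop in Step 2.5, it raises ValueError rather than looping forever when num exceeds the dataset size. The dataset here is a hypothetical stand-in for test_data_male.

```python
import random

NUM_LOAD_FROM_EACH_SET = 3


def get_random_samples(dataset: list, num: int) -> list:
    # Draw num distinct entries uniformly at random (sampling without
    # replacement); raises ValueError if num > len(dataset)
    return random.sample(dataset, num)


# Hypothetical stand-in for test_data_male: [audio path, transcription] pairs
dataset = [[f"clip_{i}.wav", f"transcription {i}"] for i in range(10)]

random.seed(0)  # optional: fix the seed so the selection is repeatable
samples = get_random_samples(dataset, NUM_LOAD_FROM_EACH_SET)
print(samples)
```

Calling random.seed with a fixed value before Cell 8 would make the selected samples, and therefore the WER output, repeatable across runs with either implementation, since both draw from the global random module.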
Step 2.9 - CELL 9: Combining Test Data
The ninth cell combines the male test samples and female test samples into a single list. Set the ninth cell to:
### CELL 9: Combine test data ###
all_test_samples = random_test_samples_male + random_test_samples_female
Step 2.10 - CELL 10: Cleaning Transcription Text
The tenth cell iterates over each test data sample and cleans the associated transcription text using the clean_text method defined in Step 2.5. Set the tenth cell to:
### CELL 10: Clean text transcriptions ###
for index in range(len(all_test_samples)):
    all_test_samples[index][1] = clean_text(all_test_samples[index][1])
Step 2.11 - CELL 11: Loading Audio Data
The eleventh cell loads each audio file specified in the all_test_samples list. Set the eleventh cell to:
### CELL 11: Load audio data ###
all_test_data = []
for index in range(len(all_test_samples)):
    speech_array, sampling_rate = read_audio_data(all_test_samples[index][0])
    all_test_data.append({
        "raw": speech_array,
        "sampling_rate": sampling_rate,
        "target_text": all_test_samples[index][1]
    })
- Audio data is returned as a torch.Tensor and stored in all_test_data as a list of dictionaries. Each dictionary contains the audio data for a particular sample, the sampling rate, and the text transcription of the audio.
Step 2.12 - CELL 12: Resampling Audio Data
The twelfth cell resamples audio data to the target sampling rate of 16000 Hz and converts each waveform to a NumPy array. Set the twelfth cell to:
### CELL 12: Resample audio data and cast to NumPy arrays ###
all_test_data = [
    {
        "raw": resample(sample["raw"]).numpy(),
        "sampling_rate": TGT_SAMPLING_RATE,
        "target_text": sample["target_text"]
    }
    for sample in all_test_data
]
Step 2.13 - CELL 13: Initializing Instance of Automatic Speech Recognition Pipeline
The thirteenth cell initializes an automatic speech recognition pipeline using the pipeline function from the HuggingFace transformers library. Set the thirteenth cell to:
### CELL 13: Initialize instance of Automatic Speech Recognition Pipeline ###
transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")
- The model parameter must be set to the path to your finetuned model added to the Kaggle Notebook in Step 1.3, e.g.:
transcriber = pipeline("automatic-speech-recognition", model = "/kaggle/input/xls-r-300m-chilean-spanish/transformers/hardy-pine/1")
Step 2.14 - CELL 14: Generating Predictions
The fourteenth cell calls the transcriber initialized in the previous step on the test data to generate text predictions. Set the fourteenth cell to:
### CELL 14: Generate transcriptions ###
transcriptions = transcriber(all_test_data)
Step 2.15 - CELL 15: Calculating WER Metrics
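The fifteenth cell calculates WER scores for each prediction as well as an overall WER score for all predictions. The pipeline returns the extra key from each input dictionary (target_text) alongside the prediction, wrapped in a list, which is why the reference text is read from transcription["target_text"][0]. Set the fifteenth cell to: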
### CELL 15: Calculate WER metrics ###
predictions = [transcription["text"] for transcription in transcriptions]
references = [transcription["target_text"][0] for transcription in transcriptions]
wers = []
for p in range(len(predictions)):
    wer = wer_metric.compute(predictions = [predictions[p]], references = [references[p]])
    wers.append(wer)
zipped = list(zip(predictions, references, wers))
df = pd.DataFrame(zipped, columns=["Prediction", "Reference", "WER"])
wer = wer_metric.compute(predictions = predictions, references = references)
Step 2.16 - CELL 16: Printing WER Metrics
The sixteenth and final cell simply prints the WER calculations from the previous step. Set the sixteenth cell to:
### CELL 16: Output WER metrics ###
pd.set_option("display.max_colwidth", None)
print(f"Overall WER: {wer}")
print(df)
WER Analysis
Since the notebook generates predictions on random samples of test data, the output will vary each time the notebook is run. The following output was generated on a run of the notebook with NUM_LOAD_FROM_EACH_SET set to 3, for a total of 6 test samples:
Overall WER: 0.013888888888888888
Prediction \
0 quiero que me reserves el mejor asiento del teatro
1 el llano en llamas es un clásico de juan rulfo
2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera
3 hay tres cafés que están abiertos hasta las once de la noche
4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras
5 cuántos albergues se abrieron después del terremoto del diecinueve de setiembre
Reference \
0 quiero que me reserves el mejor asiento del teatro
1 el llano en llamas es un clásico de juan rulfo
2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera
3 hay tres cafés que están abiertos hasta las once de la noche
4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras
5 cuántos albergues se abrieron después del terremoto del diecinueve de septiembre
WER
0 0.000000
1 0.000000
2 0.000000
3 0.000000
4 0.000000
5 0.090909
As can be seen, the model did an excellent job! It made only one error, on the sixth sample (index 5), transcribing septiembre as setiembre. (Setiembre is itself an accepted variant spelling in Spanish, but it does not match the reference, so it counts as a substitution.) Of course, running the notebook again with different test samples and, more importantly, a larger number of test samples, will produce different and more informative results. Nonetheless, this limited data suggests the model can perform well on different dialects of Spanish: it was trained on Chilean Spanish, but appears to perform well on Peruvian Spanish.
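The overall WER is not the average of the per-sample WERs. The metric pools the errors across all samples and divides by the total number of reference words, which the numbers above confirm: the six references contain 72 words, there is 1 error, and 1/72 ≈ 0.013889 is the reported overall WER. The sketch below reproduces that arithmetic from the reference texts and per-sample errors in the output; a simple average of the per-sample WERs would instead give 1/66 ≈ 0.0152.

```python
# Reference transcriptions from the run shown above
references = [
    "quiero que me reserves el mejor asiento del teatro",
    "el llano en llamas es un clásico de juan rulfo",
    "el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera",
    "hay tres cafés que están abiertos hasta las once de la noche",
    "quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras",
    "cuántos albergues se abrieron después del terremoto del diecinueve de septiembre",
]
# Word errors per sample: only sample 5 (septiembre vs. setiembre) has one
errors = [0, 0, 0, 0, 0, 1]

word_counts = [len(ref.split()) for ref in references]
per_sample_wer = [e / n for e, n in zip(errors, word_counts)]

# Pooled WER: total errors over total reference words (the overall WER)
pooled_wer = sum(errors) / sum(word_counts)
# Simple average of per-sample WERs, for comparison only
mean_wer = sum(per_sample_wer) / len(per_sample_wer)

print(word_counts)  # [9, 10, 15, 12, 15, 11]
print(round(pooled_wer, 6), round(mean_wer, 6))  # 0.013889 0.015152
```

Pooling weights each sample by its length, so a single error in a short sentence moves the overall score less than it would under a simple average.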
Conclusion
If you are just learning how to work with wav2vec2 models, I hope that the Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition guide and this guide were useful for you. As mentioned in the Part 1 guide, the finetuned model it generates is not quite state-of-the-art, but should still prove useful for many applications. Happy building!