Working with wav2vec2 Part 2 - Running Inference on Finetuned ASR Models

by Picture in the Noise (@pictureinthenoise), May 7th, 2024

Too Long; Didn't Read

This companion guide explains the steps to run inference on a finetuned wav2vec2 XLS-R model. It complements the guide "Working with wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition". The guide provides step-by-step instructions on creating a Kaggle Notebook that can be used for running inference.

Introduction

This is a companion guide to Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition (the "Part 1 guide"). I wrote the Part 1 guide to explain how to finetune Meta AI's wav2vec2 XLS-R ("XLS-R") model on Chilean Spanish. This guide assumes that you have completed the Part 1 guide and generated your own finetuned XLS-R model, and explains how to run inference on that model in a Kaggle Notebook.

Prerequisites and Before You Get Started

To complete the guide, you will need to have:


  • A finetuned XLS-R model for the Spanish language.
  • An existing Kaggle account.
  • Intermediate knowledge of Python.
  • Intermediate knowledge of working with Kaggle Notebooks.
  • Intermediate knowledge of ML concepts.
  • Basic knowledge of ASR concepts.

Building the Inference Notebook

Step 1 - Setting Up Your Kaggle Environment

Step 1.1 - Creating a New Kaggle Notebook

  1. Log in to Kaggle.
  2. Create a new Kaggle Notebook.
  3. The name of the notebook can be changed as desired. This guide uses the notebook name spanish-asr-inference.

Step 1.2 - Adding the Test Datasets

This guide uses the Peruvian Spanish Speech Data Set as its source of test data. Like the Chilean Spanish Speech Data Set used in the Part 1 guide, the Peruvian dataset consists of two sub-datasets: 2,918 recordings of male Peruvian speakers and 2,529 recordings of female Peruvian speakers.


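This dataset has been uploaded to Kaggle as 2 distinct datasets, one per sub-dataset. The input paths set in Step 2.4 expect them to be mounted as:

  • google-spanish-speakers-peru-male
  • google-spanish-speakers-peru-female

Add both of these datasets to your Kaggle Notebook by clicking on Add Input.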

Step 1.3 - Adding the Finetuned Model

You should have saved your finetuned model in Step 4 of the Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition guide as a Kaggle Model.


Add your finetuned model to your Kaggle Notebook by clicking on Add Input.

Step 2 - Building the Inference Notebook

The following 16 sub-steps build the inference notebook's 16 cells in order. Many of the utility methods from the Part 1 guide are reused here.

Step 2.1 - CELL 1: Installing Packages

The first cell of the inference notebook installs dependencies. Set the first cell to:


### CELL 1: Install Packages ###
!pip install --upgrade torchaudio
!pip install jiwer

Step 2.2 - CELL 2: Importing Python Packages

The second cell imports required Python packages. Set the second cell to:


### CELL 2: Import Python packages ###
import re
import math
import random
import pandas as pd
import torchaudio
from datasets import load_metric
from transformers import pipeline

Step 2.3 - CELL 3: Loading WER Metric

The third cell loads the HuggingFace WER evaluation metric. Set the third cell to:


### CELL 3: Load WER metric ###
wer_metric = load_metric("wer")


  • WER (word error rate) will be used to measure the performance of the finetuned model on the test data.
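For reference, WER is the word-level edit distance between a prediction and its reference transcription, normalized by the length of the reference:

$$
\mathrm{WER} = \frac{S + D + I}{N}
$$

where S, D, and I are the numbers of word substitutions, deletions, and insertions in the minimum-edit alignment of the prediction against the reference, and N is the number of words in the reference. Lower is better, and a WER of 0 means an exact match. When several prediction/reference pairs are passed to the metric at once, the error counts and reference word counts are summed across all pairs before dividing, so the overall WER is not simply the average of the per-sample WERs.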

Step 2.4 - CELL 4: Setting Constants

The fourth cell sets constants that will be used throughout the notebook. Set the fourth cell to:


### CELL 4: Constants ###
# Testing data
TEST_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-peru-male/"
TEST_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-peru-female/"
EXT = ".wav"
NUM_LOAD_FROM_EACH_SET = 3

# Special characters
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"

# Sampling rates
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000

Step 2.5 - CELL 5: Utility Methods for Reading Index Files, Cleaning Text, and Generating Random Samples

The fifth cell defines utility methods for reading the dataset index files, as well as for cleaning transcription text and generating a random set of samples from test data. Set the fifth cell to:


### CELL 5: Utility methods for reading index files, cleaning text, random indices generator ###
def read_index_file_data(path: str, filename: str):
    data = []
    with open(path + filename, "r", encoding = "utf8") as f:
        lines = f.readlines()
        for line in lines:
            file_and_text = line.split("\t")
            data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")])
    return data

def clean_text(text: str) -> str:
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    cleaned_text = cleaned_text.lower()
    return cleaned_text

def get_random_samples(dataset: list, num: int) -> list:
    used = []
    samples = []
    for i in range(num):
        a = -1
        while a == -1 or a in used:
            a = math.floor(len(dataset) * random.random())
        samples.append(dataset[a])
        used.append(a)
    return samples           


  • The read_index_file_data method reads a line_index.tsv dataset index file and produces a list of lists containing the audio file path and transcription for each recording, e.g. (using a file from the Chilean dataset of the Part 1 guide):


[
    ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739.wav", "Es un viaje de negocios solamente voy por una noche"],
    ...
]


  • The clean_text method strips each text transcription of the characters matched by the regular expression assigned to SPECIAL_CHARS in Step 2.4 and converts the result to lowercase. These characters, including punctuation, provide no semantic value for learning mappings between audio features and text transcriptions, which is why they were removed from the training text in the Part 1 guide; removing them from the test references here keeps the references consistent with what the model was trained to output.
  • The get_random_samples method returns a set of random test samples with the quantity set by the constant NUM_LOAD_FROM_EACH_SET in Step 2.4.
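To see what clean_text does to a typical transcription, the sketch below reuses the SPECIAL_CHARS pattern from Step 2.4 and the clean_text logic from Step 2.5 on a made-up sentence (the input string is illustrative, not taken from the dataset). Note that removed characters are not replaced, so the spaces around them remain, which can leave a trailing space.

```python
import re

# Same pattern as SPECIAL_CHARS in CELL 4
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"

def clean_text(text: str) -> str:
    # Remove digits and punctuation, then lowercase (same logic as CELL 5)
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    return cleaned_text.lower()

# Illustrative input, not from the dataset
example = "¿Qué hora es? ¡Son las 3!"
print(repr(clean_text(example)))  # 'qué hora es son las '
```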

Step 2.6 - CELL 6: Utility Methods for Loading and Resampling Audio Data

The sixth cell defines utility methods using torchaudio to load and resample audio data. Set the sixth cell to:


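### CELL 6: Utility methods for loading and resampling audio data ###
def read_audio_data(file):
    # Load the audio file as a float tensor of shape (channels, frames) plus its sampling rate
    speech_array, sampling_rate = torchaudio.load(file, normalize = True)
    return speech_array, sampling_rate

def resample(waveform):
    # Downsample from ORIG_SAMPLING_RATE (48 kHz) to TGT_SAMPLING_RATE (16 kHz)
    transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    # Keep only the first channel as a 1-D tensor
    return waveform[0]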


  • The read_audio_data method loads the specified audio file and returns the audio data as a torch.Tensor of shape (channels, frames), along with the file's sampling rate. The notebook assumes that all the audio files in the test data have a sampling rate of 48000 Hz. This "original" sampling rate is captured by the constant ORIG_SAMPLING_RATE in Step 2.4.
  • The resample method downsamples audio data from 48000 Hz to the target sampling rate of 16000 Hz, the rate that wav2vec2 models expect, and returns only the first channel.
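Because 48000 is exactly three times 16000, resampling preserves the duration of each clip and divides its number of frames by three. The sketch below works through that arithmetic for a hypothetical 5-second clip. The helper resampled_num_frames is hypothetical, and it assumes the output length is rounded up when the input length is not a multiple of three, which is how we understand torchaudio's resampler to behave; treat the non-exact case as an approximation.

```python
import math

ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000

def resampled_num_frames(num_frames: int, orig_sr: int, tgt_sr: int) -> int:
    # Reduce the ratio first, e.g. 48000:16000 -> 3:1
    g = math.gcd(orig_sr, tgt_sr)
    orig, tgt = orig_sr // g, tgt_sr // g
    # Assumption: round up when num_frames * tgt is not divisible by orig
    return -(-num_frames * tgt // orig)  # integer ceiling of num_frames * tgt / orig

num_frames = 5 * ORIG_SAMPLING_RATE  # hypothetical 5-second clip: 240000 frames
out_frames = resampled_num_frames(num_frames, ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
print(out_frames)                      # 80000
print(out_frames / TGT_SAMPLING_RATE)  # 5.0 seconds, duration unchanged
```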

Step 2.7 - CELL 7: Reading Test Data

The seventh cell reads the test data index files for the recordings of male speakers and the recordings of female speakers using the read_index_file_data method defined in Step 2.5. Set the seventh cell to:


### CELL 7: Read test data ###
test_data_male = read_index_file_data(TEST_DATA_PATH_MALE, "line_index.tsv")
test_data_female = read_index_file_data(TEST_DATA_PATH_FEMALE, "line_index.tsv")

Step 2.8 - CELL 8: Generating Lists of Random Test Samples

The eighth cell generates sets of random test samples using the get_random_samples method defined in Step 2.5. Set the eighth cell to:


### CELL 8: Generate lists of random test samples ###
random_test_samples_male = get_random_samples(test_data_male, NUM_LOAD_FROM_EACH_SET)
random_test_samples_female = get_random_samples(test_data_female, NUM_LOAD_FROM_EACH_SET)
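Because get_random_samples draws new samples on every run, the results in the WER Analysis section will differ from run to run. If you want repeatable draws, one option (not part of the original notebook) is a seeded random.Random instance with its sample method, which also samples without replacement. The helper get_random_samples_seeded, the seed value, and the stand-in dataset below are hypothetical placeholders.

```python
import random

NUM_LOAD_FROM_EACH_SET = 3
SEED = 42  # arbitrary placeholder seed

# Stand-in for test_data_male: [audio_path, transcription] pairs (placeholder values)
test_data = [[f"clip_{i}.wav", f"text {i}"] for i in range(10)]

def get_random_samples_seeded(dataset: list, num: int, seed: int) -> list:
    # Hypothetical variant of get_random_samples: a dedicated, seeded RNG
    # draws `num` distinct items (sampling without replacement)
    rng = random.Random(seed)
    return rng.sample(dataset, num)

first = get_random_samples_seeded(test_data, NUM_LOAD_FROM_EACH_SET, SEED)
second = get_random_samples_seeded(test_data, NUM_LOAD_FROM_EACH_SET, SEED)
print(first == second)  # True: same seed, same samples
```

Using a dedicated random.Random instance avoids changing the global random state. Unlike the while loop in get_random_samples, which never terminates if num exceeds the dataset size, random.sample raises a ValueError in that case.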

Step 2.9 - CELL 9: Combining Test Data

The ninth cell combines the male test samples and female test samples into a single list. Set the ninth cell to:


### CELL 9: Combine test data ###
all_test_samples = random_test_samples_male + random_test_samples_female

Step 2.10 - CELL 10: Cleaning Transcription Text

The tenth cell iterates over each test data sample and cleans the associated transcription text using the clean_text method defined in Step 2.5. Set the tenth cell to:


### CELL 10: Clean text transcriptions ###
for index in range(len(all_test_samples)):
    all_test_samples[index][1] = clean_text(all_test_samples[index][1])

Step 2.11 - CELL 11: Loading Audio Data

The eleventh cell loads each audio file specified in the all_test_samples list. Set the eleventh cell to:


### CELL 11: Load audio data ###
all_test_data = []

for index in range(len(all_test_samples)):
    speech_array, sampling_rate = read_audio_data(all_test_samples[index][0])
    all_test_data.append({
        "raw": speech_array,
        "sampling_rate": sampling_rate,
        "target_text": all_test_samples[index][1]
    })


  • Audio data is returned as a torch.Tensor and stored in all_test_data as a list of dictionaries. Each dictionary contains the audio data for a particular sample, the sampling rate, and the text transcription of the audio.

Step 2.12 - CELL 12: Resampling Audio Data

The twelfth cell resamples each audio sample to the target sampling rate of 16000 Hz and converts the resulting tensor to a NumPy array. Set the twelfth cell to:


### CELL 12: Resample audio data and cast to NumPy arrays ###
all_test_data = [
    {
        "raw": resample(sample["raw"]).numpy(),
        "sampling_rate": TGT_SAMPLING_RATE,
        "target_text": sample["target_text"]
    }
    for sample in all_test_data
]

Step 2.13 - CELL 13: Initializing Instance of Automatic Speech Recognition Pipeline

The thirteenth cell initializes an automatic speech recognition pipeline using the pipeline function from the HuggingFace transformers library. Set the thirteenth cell to:


### CELL 13: Initialize instance of Automatic Speech Recognition Pipeline ###
transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")


  • The model parameter must be set to the path to your finetuned model added to the Kaggle Notebook in Step 1.3, e.g.:


transcriber = pipeline("automatic-speech-recognition", model = "/kaggle/input/xls-r-300m-chilean-spanish/transformers/hardy-pine/1")

Step 2.14 - CELL 14: Generating Predictions

The fourteenth cell calls the transcriber initialized in the previous step on the test data to generate text predictions. Set the fourteenth cell to:


### CELL 14: Generate transcriptions ###
transcriptions = transcriber(all_test_data)

Step 2.15 - CELL 15: Calculating WER Metrics

The fifteenth cell calculates WER scores for each prediction as well as an overall WER score for all predictions. Set the fifteenth cell to:


### CELL 15: Calculate WER metrics ###
# Model predictions and reference transcriptions; the pipeline passes the extra
# "target_text" input key through to its output as a list, hence the [0]
predictions = [transcription["text"] for transcription in transcriptions]
references = [transcription["target_text"][0] for transcription in transcriptions]

# Per-sample WER scores
wers = []
for p in range(len(predictions)):
    wer = wer_metric.compute(predictions = [predictions[p]], references = [references[p]])
    wers.append(wer)

# Table of per-sample results, plus the overall WER across all samples
zipped = list(zip(predictions, references, wers))
df = pd.DataFrame(zipped, columns=["Prediction", "Reference", "WER"])
wer = wer_metric.compute(predictions = predictions, references = references)

Step 2.16 - CELL 16: Printing WER Metrics

The sixteenth and final cell simply prints the WER calculations from the previous step. Set the sixteenth cell to:


### CELL 16: Output WER metrics ###
pd.set_option("display.max_colwidth", None)
print(f"Overall WER: {wer}")
print(df)

WER Analysis

Since the notebook generates predictions on random samples of test data, the output will vary each time the notebook is run. The following output was generated on a run of the notebook with NUM_LOAD_FROM_EACH_SET set to 3 for a total of 6 test samples:


Overall WER: 0.013888888888888888
                                                                              Prediction  \
0                                     quiero que me reserves el mejor asiento del teatro   
1                                         el llano en llamas es un clásico de juan rulfo   
2         el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera   
3                           hay tres cafés que están abiertos hasta las once de la noche   
4  quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras   
5        cuántos albergues se abrieron después del terremoto del diecinueve de setiembre   

                                                                               Reference  \
0                                     quiero que me reserves el mejor asiento del teatro   
1                                         el llano en llamas es un clásico de juan rulfo   
2         el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera   
3                           hay tres cafés que están abiertos hasta las once de la noche   
4  quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras   
5       cuántos albergues se abrieron después del terremoto del diecinueve de septiembre   

        WER  
0  0.000000  
1  0.000000  
2  0.000000  
3  0.000000  
4  0.000000  
5  0.090909


As can be seen, the model did an excellent job! It made only one error, on the sixth sample (index 5), where it transcribed septiembre as setiembre. Since setiembre is itself an accepted variant spelling in Spanish, this mismatch may reflect spelling convention rather than a misrecognition. Of course, running the notebook again with different test samples and, more importantly, a larger number of test samples will produce different and more informative results. Nonetheless, this limited data suggests that the model can perform well on other dialects of Spanish: it was trained on Chilean Spanish but appears to perform well on Peruvian Spanish.
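The overall WER in this output is total errors divided by total reference words, not the average of the per-sample values. As a check, the sketch below implements word-level WER with a standard word-level edit-distance (Levenshtein) dynamic program, a stand-in for the HuggingFace metric rather than its actual implementation, and applies it to the six prediction/reference pairs from the output table. One substitution out of 11 reference words gives 1/11 ≈ 0.0909 for sample 5, and one error out of 72 reference words across all six samples gives 1/72 ≈ 0.0139, matching the reported overall WER; averaging the per-sample WERs would instead give about 0.0152.

```python
def word_errors(prediction: str, reference: str) -> int:
    # Minimum number of word substitutions, deletions, and insertions
    # separating the prediction from the reference (Levenshtein distance over words)
    hyp, ref = prediction.split(), reference.split()
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr[j] = min(prev[j] + 1,         # deletion: reference word missing
                          curr[j - 1] + 1,     # insertion: extra predicted word
                          prev[j - 1] + cost)  # match or substitution
        prev = curr
    return prev[-1]

def corpus_wer(predictions: list, references: list) -> float:
    # Sum errors and reference words over all pairs, then divide once
    errors = sum(word_errors(p, r) for p, r in zip(predictions, references))
    words = sum(len(r.split()) for r in references)
    return errors / words

# Pairs copied from the output table above
references = [
    "quiero que me reserves el mejor asiento del teatro",
    "el llano en llamas es un clásico de juan rulfo",
    "el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera",
    "hay tres cafés que están abiertos hasta las once de la noche",
    "quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras",
    "cuántos albergues se abrieron después del terremoto del diecinueve de septiembre",
]
predictions = references[:5] + [
    "cuántos albergues se abrieron después del terremoto del diecinueve de setiembre",
]

per_sample = [word_errors(p, r) / len(r.split()) for p, r in zip(predictions, references)]
print(per_sample[5])                                # 0.09090909090909091
print(corpus_wer(predictions, references))          # 0.013888888888888888
print(round(sum(per_sample) / len(per_sample), 4))  # 0.0152, mean of per-sample WERs
```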

Conclusion

If you are just learning how to work with wav2vec2 models, I hope that the Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition guide and this guide were useful to you. The finetuned model generated by the Part 1 guide is not quite state-of-the-art, but it should still prove useful for many applications. Happy building!