
Working with wav2vec2 Part 3 - Using ASR Models for Long Inference

by Picture in the Noise, May 20th, 2024

Too Long; Didn't Read

It is more computationally expensive to process a longer audio file than a shorter one. Chunking is a technique that we can employ to make wav2vec2 finetuned automatic speech recognition ("ASR") models work on long audio files. This guide walks through the steps to build a simple Python application that can run ASR inference on long audio files using chunking.


This is a companion guide to Working With wav2vec2 Part 1 - Finetuning XLS-R for Automatic Speech Recognition (the "Part 1 guide") and Working With wav2vec2 Part 2 - Running Inference on Finetuned ASR Models (the "Part 2 guide"). In those guides, I outlined the steps to generate text transcriptions from audio using a finetuned wav2vec2 ASR model. Readers who opted to train wav2vec2 on the Chilean Spanish dataset that I used for my model likely noticed that the audio samples were (generally) less than 10 seconds long. While the inference notebook in Part 2 can theoretically be used with longer audio inputs, realistically it will "choke" on long audio files (i.e. audio files more than a few seconds in length).


Chunking is a simple technique that we can employ to make wav2vec2 finetuned models work on long audio files. This guide walks through the steps to build a simple Python application that can run inference on long audio files.


You will build this:


(Image: the Python application for running long ASR inference that this guide describes how to build.)



It is assumed that you have completed the Part 1 and Part 2 guides and that you have generated your own finetuned wav2vec2 XLS-R model. This guide continues working with the Spanish language, but the Python application that you will build can be used with models finetuned on other languages.


Chunking

Basic logic tells us that it is more computationally expensive to process a longer audio file than a shorter one. In the case of the transformer architecture, the computational complexity of the attention mechanism is quadratic with respect to the length of the sequence fed into the transformer. So, doubling the length of an input roughly quadruples the cost of attention, and long audio files quickly become expensive in both compute and memory.
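To put rough numbers on this, the sketch below estimates how many frames wav2vec2's convolutional feature encoder produces for a given duration, and how many entries a single frames-by-frames attention score matrix would then hold. It assumes the default wav2vec2/XLS-R feature-encoder settings (kernel sizes 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2 on 16 kHz audio); check conv_kernel and conv_stride in your own model's config.json. The helper names are hypothetical.

```python
# Assumed defaults for the wav2vec2 / XLS-R convolutional feature encoder.
# Check `conv_kernel` and `conv_stride` in your model's config.json.
CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)
SAMPLING_RATE = 16000  # Hz


def num_frames(num_samples: int) -> int:
    """Frames emitted by the feature encoder (conv layers without padding)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = (length - kernel) // stride + 1
    return length


def attention_entries(seconds: float) -> int:
    """Entries in one frames x frames attention score matrix (one head, one layer)."""
    frames = num_frames(int(seconds * SAMPLING_RATE))
    return frames * frames


if __name__ == "__main__":
    for seconds in (10, 60, 600):
        frames = num_frames(int(seconds * SAMPLING_RATE))
        print(f"{seconds:>4} s -> {frames:>6} frames, {attention_entries(seconds):>13,} attention entries")
```

Under these assumptions, 10 seconds of audio yields 499 frames and 60 seconds yields 2,999: the frame count grows about 6 times, but each attention matrix grows about 36 times.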

Chunking is a technique that we can use to get around this limitation. Simply put, we can:


  • Divide a long sequence of audio into chunks of a fixed length, e.g. 15 seconds.
  • Run inference on each chunk - i.e. generate a text transcription for each individual chunk.
  • Concatenate the chunk-specific text transcriptions to create a complete transcription for the long audio file.
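The three steps above can be sketched in plain Python. This is a minimal illustration, not the Hugging Face implementation: samples is any sequence of audio samples, and transcribe_chunk is a hypothetical stand-in for running the ASR model on one chunk. Joining the chunk transcriptions with a single space is also an assumption; a word that straddles a chunk boundary may be split or garbled.

```python
from typing import Callable, List, Sequence


def split_into_chunks(samples: Sequence[float], sampling_rate: int,
                      chunk_length_s: float) -> List[Sequence[float]]:
    """Split audio samples into consecutive, non-overlapping chunks."""
    chunk_size = int(chunk_length_s * sampling_rate)
    return [samples[i:i + chunk_size] for i in range(0, len(samples), chunk_size)]


def transcribe_long(samples: Sequence[float], sampling_rate: int, chunk_length_s: float,
                    transcribe_chunk: Callable[[Sequence[float]], str]) -> str:
    """Transcribe each chunk independently, then concatenate the texts."""
    chunks = split_into_chunks(samples, sampling_rate, chunk_length_s)
    texts = [transcribe_chunk(chunk).strip() for chunk in chunks]
    return " ".join(text for text in texts if text)


if __name__ == "__main__":
    audio = [0.0] * (35 * 16000)  # 35 seconds of silence at 16 kHz
    # Fake "model" that just reports each chunk's duration in seconds.
    print(transcribe_long(audio, 16000, 15, lambda c: f"<{len(c) // 16000}s>"))
    # <15s> <15s> <5s>
```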


While the approach above will work, inference near the beginning and end of each audio chunk will be poor. This is because wav2vec2 performs inference on a given section of audio using the context of that section - i.e. the audio "around" it. Since, by definition, the beginning and end of each audio chunk are missing context on one side (and a word can even be cut in two at a chunk boundary), inference results are expected to be worse in those sections.

A solution to this problem is to add some amount of context, a stride, to the beginning and to the end of each audio chunk solely for the purpose of running inference. After the inference is complete, we drop the inference results for the added context so that we are left only with the inference results for the audio chunk. Just as before, we can concatenate the individual text transcriptions to create a complete transcription, but with the benefit of better inference at the beginning and end of each audio chunk.
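The stride logic amounts to simple index arithmetic, sketched below. The sketch follows the convention used by the Hugging Face ASR pipeline, in which the strides sit inside each fixed-length window: consecutive windows advance by chunk_len - stride_left - stride_right, the model sees the whole window, and only the middle part of each window's output is kept (the first window also keeps its start and the last window also keeps its end). It is a simplified approximation written for illustration, not the library's code, and chunk_windows is a hypothetical name.

```python
from typing import List, Tuple


def chunk_windows(num_samples: int, chunk_len: int, stride_left: int,
                  stride_right: int) -> List[Tuple[int, int, int, int]]:
    """Return (window_start, window_end, keep_start, keep_end) index tuples.

    Hypothetical helper: the model would see samples[window_start:window_end],
    and only the output for samples[keep_start:keep_end] would be kept.
    """
    step = chunk_len - stride_left - stride_right
    if step <= 0:
        raise ValueError("chunk_len must exceed stride_left + stride_right")
    windows = []
    for start in range(0, num_samples, step):
        end = min(start + chunk_len, num_samples)
        is_first = start == 0
        is_last = start + chunk_len >= num_samples
        # Drop the left context, except in the first window (nothing precedes it).
        keep_start = start if is_first else start + stride_left
        # Drop the right context, except in the last window (nothing follows it).
        keep_end = end if is_last else end - stride_right
        windows.append((start, end, keep_start, keep_end))
        if is_last:
            break
    return windows


if __name__ == "__main__":
    # Units of seconds for readability: 25 s of audio, 10 s windows, 3 s strides.
    for window in chunk_windows(25, 10, 3, 3):
        print(window)
```

The kept ranges (0 to 7, 7 to 11, 11 to 15, 15 to 19, and 19 to 25) line up end to end, so concatenating the per-window results covers the file exactly once. In a real application the same arithmetic would apply to sample indices, e.g. chunk_len = 10 * 16000 at 16 kHz.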


As will be seen in the next section, configuring a Hugging Face pipeline to use chunking and a context stride is very simple. For additional discussion on this topic, please see this excellent Hugging Face blog post, which also includes a visual depiction of the chunking approach.


Configuring the pipeline Class for Chunking

As you might remember from the Part 2 guide, we configured an instance of the Hugging Face pipeline class for automatic speech recognition ("ASR"). Specifically, in Step 2.13 of that guide, we initialized a transcriber as follows:


transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")


The pipeline class for ASR accepts two additional initialization arguments that allow us to implement chunking and to add a context stride to each chunk:


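  • chunk_length_s: A number (an integer or a float) specifying the length in seconds of each chunk. This length includes the strides described below.

  • stride_length_s: The stride length in seconds at the beginning and end of each chunk, given either as a tuple of (start, end) values or as a single number applied to both sides. The model runs on the full chunk and the predictions for the stride regions are then dropped, so each chunk contributes chunk_length_s minus both strides of new audio (the first and last chunks also keep their outer edges). If stride_length_s is omitted, it defaults to one sixth of chunk_length_s on each side.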


For example, if we wanted to use a chunk size of 8 seconds, and a stride length of 2 seconds on each side of the chunk, we would initialize the pipeline using:


transcriber = pipeline("automatic-speech-recognition", chunk_length_s = 8, stride_length_s = (2,2), model = "YOUR_FINETUNED_MODEL_PATH")
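A small helper can make the meaning of these two arguments concrete. describe_chunking is hypothetical and does not call transformers; it only approximates how the ASR pipeline reads its chunking arguments: a single number is used for both sides, an omitted stride defaults to one sixth of the chunk length, and the strides together must be shorter than the chunk. The returned step is how far each successive chunk advances, which is also the amount of new, non-context audio that each middle chunk contributes.

```python
from typing import Optional, Sequence, Tuple, Union


def describe_chunking(chunk_length_s: float,
                      stride_length_s: Optional[Union[float, Sequence[float]]] = None
                      ) -> Tuple[float, float, float]:
    """Return (stride_start, stride_end, step) in seconds (hypothetical helper)."""
    if stride_length_s is None:
        # Default: one sixth of the chunk length on each side.
        stride_length_s = chunk_length_s / 6
    if isinstance(stride_length_s, (int, float)):
        # A single number applies to both sides.
        stride_length_s = (stride_length_s, stride_length_s)
    stride_start, stride_end = stride_length_s
    step = chunk_length_s - stride_start - stride_end
    if step <= 0:
        raise ValueError("chunk_length_s must be longer than the two strides combined")
    return stride_start, stride_end, step


if __name__ == "__main__":
    print(describe_chunking(8, (2, 2)))  # (2, 2, 4)
```

With the example above, each 8-second chunk advances by 4 seconds, so 2 seconds of context on each side overlap with the neighboring chunks.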


With these new parameters in mind, you're now ready to build the Python application for long inference.


Prerequisites and Before You Get Started

To complete the guide, you will need to have:


  • A finetuned wav2vec2 model.
  • Intermediate knowledge of Python.
  • Basic knowledge of the Python tkinter package.
  • Intermediate knowledge of ML concepts.
  • Basic knowledge of ASR concepts.


Building the Python Application for Long Inference

Step 1 - Setting Up Your Python Environment

Use pip to install the torchaudio and transformers packages in your Python environment if they are not already installed.


pip install torchaudio
pip install transformers


You might also need to install the soundfile package (pip install soundfile), which torchaudio.load can use as a backend to load audio files.
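If you are unsure what is already installed, a quick check with the standard library's importlib.util.find_spec reports which of the packages used in this guide can be located. The package list is taken from this guide and the helper name is hypothetical. find_spec only locates a package without importing it, so a broken installation could still report "ok". Note that tkinter ships with most Python installers, but some Linux distributions package it separately.

```python
import importlib.util
from typing import Dict, Sequence

# Packages used by the application in this guide.
REQUIRED = ("torch", "torchaudio", "transformers", "soundfile", "tkinter")


def check_packages(names: Sequence[str] = REQUIRED) -> Dict[str, bool]:
    """Map each top-level package name to True if Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages().items():
        print(f"{name:<13} {'ok' if found else 'MISSING'}")
```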

Step 2 - Building the Application

The following sub-steps build the application progressively. Each sub-step includes an explanation that describes the particular section of code. If you prefer to view the finished application, feel free to jump to Step 3 which displays the complete program.

Step 2.1 - Creating a New Python Application and Adding Imports

Create a new Python application and add the following import statements:


import time
import torch
import torchaudio
from transformers import pipeline
import tkinter as tk
import tkinter.scrolledtext as scrolledtext
from tkinter import ttk, filedialog


  • The torchaudio package will be used to load and resample audio data.
  • The transformers package, and specifically the pipeline class, will be used to run inference.
  • tkinter will be used to create the user interface for the application.


It is assumed that you have some background working with tkinter GUIs. However, if you need a brief primer, you can take a look at my HackerNoon article Building Your First Python GUI With Tkinter.

Step 2.2 - Adding Constants

Add the following constants below the module imports.


GUI = {
    "title": "Long Inference with wav2vec2 ASR Models",
    "root_width" : 800,
    "root_height" : 500,
    "pad_x": 10,
    "pad_y": 10,
    "input_field_width": 110,
    "textbox_width": 125,
    "textbox_height": 15,
    "select_model_label": "Please select an ASR model",
    "browse_models_button_label": "Browse Models",
    "select_file_label": "Please select an audio file for inference",
    "browse_files_button_label": "Browse Files",
    "run_inference_button_label": "Run",
    "save_to_file_button_label": "Save To File",
    "reset": "Reset",
    "default_notification": "",
    "notification_select_model": "You need to select an ASR model directory",
    "notification_select_audio_file": "You need to select an audio file",
    "notification_select_model_and_audio_file": "You need to select an ASR model and audio file before you can run inference",
    "notification_running_inference": "Running inference...this might take awhile...",
    "notification_finished_inference": "Finished running inference in {} seconds",
    "notification_select_file_for_saving": "You need to select a file to save the transcription",
    "notification_finished_saving": "Finished saving transcription"
}
TGT_SAMPLING_RATE = 16000
CHUNK_LENGTH = 10
STRIDE_START = 3
STRIDE_END = 3


  • The GUI dictionary captures window dimensions, labels, and other data used to create and update the user interface.
  • TGT_SAMPLING_RATE is the target sampling rate, expressed in Hz, used when resampling audio data.
  • CHUNK_LENGTH is the chunk length expressed in seconds and will be used when initializing the pipeline class for ASR. In this guide, each audio chunk will be 10 seconds in length.
  • STRIDE_START and STRIDE_END represent the starting and ending context strides applied to each chunk. Both values are expressed in seconds and will be used when initializing the pipeline class for ASR. With both values equal to 3, a stride length of 3 seconds will be applied to both the start and end of each audio chunk. Since the pipeline counts the strides as part of the 10-second chunk, only the middle 4 seconds of each chunk (10 - 3 - 3) add new text to the transcription, apart from the first and last chunks, which also keep their outer edges.
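To see what these constants imply for a real file, the sketch below estimates how many chunks the pipeline would run for a given duration. It assumes that each chunk includes both strides and that successive chunks advance by CHUNK_LENGTH - STRIDE_START - STRIDE_END seconds (4 seconds with the values above). count_windows is a hypothetical helper, and its result is an estimate: the pipeline works in samples and rounds chunk boundaries internally, so the exact count may differ slightly.

```python
import math

# Values copied from the application's constants (seconds).
CHUNK_LENGTH = 10
STRIDE_START = 3
STRIDE_END = 3


def count_windows(duration_s: float, chunk_s: float = CHUNK_LENGTH,
                  left_s: float = STRIDE_START, right_s: float = STRIDE_END) -> int:
    """Estimate how many model windows are needed to cover `duration_s` seconds."""
    step = chunk_s - left_s - right_s
    if duration_s <= chunk_s:
        return 1  # A short file fits in a single window.
    return math.ceil((duration_s - chunk_s) / step) + 1


if __name__ == "__main__":
    for minutes in (1, 10, 60):
        print(f"{minutes:>2} min -> {count_windows(minutes * 60)} windows")
```

Under these assumptions, a 10-minute file needs about 149 windows of 10 seconds each, so roughly 2.5 times as much audio passes through the model as the file contains. That overhead is the price of the 3-second context strides.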

Step 2.3 - Adding Globals

Add the following two global variables below the constants.


MODEL = None
AUDIO_FILE = None


  • The MODEL and AUDIO_FILE globals will hold the ASR model directory path and the audio file path, respectively.
  • As you will see, the application enforces a selection order whereby an ASR model must be chosen before an audio file can be selected.
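The selection rules enforced by the application's callbacks can be summarized in a GUI-free sketch. SelectionState is a hypothetical class written only to illustrate the logic (model first, then audio file, then Run). The application itself keeps this state in the MODEL and AUDIO_FILE globals and reports problems through a notification label. The warning messages are copied from the GUI dictionary.

```python
from typing import Optional

# Messages copied from the application's GUI dictionary.
MSG_SELECT_MODEL = "You need to select an ASR model directory"
MSG_SELECT_AUDIO = "You need to select an audio file"


class SelectionState:
    """Hypothetical, GUI-free model of the app's selection order."""

    def __init__(self) -> None:
        self.model: Optional[str] = None       # plays the role of the MODEL global
        self.audio_file: Optional[str] = None  # plays the role of the AUDIO_FILE global

    def select_model(self, path: str) -> str:
        """Return a warning message, or "" on success."""
        if not path:  # the dialog was closed without a choice
            return MSG_SELECT_MODEL
        self.model = path
        return ""

    def select_file(self, path: str) -> str:
        """Return a warning message, or "" on success."""
        if not self.model:  # a model must be chosen first
            return MSG_SELECT_MODEL
        if not path:
            return MSG_SELECT_AUDIO
        self.audio_file = path
        return ""

    def can_run(self) -> bool:
        return bool(self.model and self.audio_file)

    def reset(self) -> None:
        self.model = None
        self.audio_file = None
```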

Step 2.4 - Adding Utility Methods

Add the following utility methods below the MODEL and AUDIO_FILE globals.


def read_audio_data(file: str) -> tuple[torch.Tensor, int]:
    audio, sampling_rate = torchaudio.load(file, normalize = True)
    return audio, sampling_rate

def resample(waveform: torch.Tensor, orig_sampling_rate: int) -> torch.Tensor:
    transform = torchaudio.transforms.Resample(orig_sampling_rate, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    return waveform[0]


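  • The read_audio_data method loads an audio file using the torchaudio.load method and returns the waveform along with its original sampling rate.
  • The resample method resamples audio from its original sampling rate to the target sampling rate of 16,000 Hz specified by TGT_SAMPLING_RATE. It returns only the first channel of the resampled waveform, so multi-channel (e.g. stereo) files are reduced to their first channel.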

Step 2.5 - Adding Callback Methods

There are five callback methods, each of which is bound to one of five user interface buttons:


  • Browse Models button

  • Browse Files button

  • Run button

  • Save To File button

  • Reset button


When a given button is pressed, its respective callback method is called.
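The application enables and disables buttons by binding and unbinding their callbacks rather than by changing widget state, and it passes widgets to each callback through a lambda default argument (args = [...]). The sketch below imitates that pattern without tkinter. FakeButton and show_args are hypothetical stand-ins whose methods only loosely model tkinter's bind and unbind for the "<Button>" event. One consequence of this design: an unbound ttk.Button still looks clickable, because its visual state is unchanged.

```python
class FakeButton:
    """Hypothetical stand-in for a button with one '<Button>' binding.

    Simplification: this class ignores the event sequence and keeps a
    single handler.
    """

    def __init__(self, label: str) -> None:
        self.label = label
        self.handler = None

    def bind(self, sequence: str, handler) -> None:
        # As with tkinter's bind() without add="+", a new handler replaces the old one.
        self.handler = handler

    def unbind(self, sequence: str) -> None:
        self.handler = None

    def press(self):
        # Simulate a mouse press; an unbound button silently does nothing.
        if self.handler is None:
            return None
        return self.handler("fake-event")


def show_args(e, args):
    # Like the app's callbacks, unpack the widgets this callback needs from `args`.
    field, label = args
    return f"{label} <- {field}"


button = FakeButton("Browse Models")
# The default argument packs the widgets into `args` when the lambda is defined.
button.bind("<Button>", lambda e, args=["model field", "notification label"]: show_args(e, args))
print(button.press())  # notification label <- model field
button.unbind("<Button>")
print(button.press())  # None
```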

Step 2.5.1 - Adding callback_select_model

Add the following code for the callback_select_model method:


def callback_select_model(e: object, args: list) -> None:
    global MODEL
    
    # Unpack widgets
    select_model_field = args[0]
    notification_label = args[1]

    # Open file dialog
    dir = filedialog.askdirectory()

    # Add `MODEL` to the `select_model_field` widget
    if dir:
        MODEL = dir
        select_model_field.configure({"text": MODEL})
        select_model_field.update()
    else:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()


  • This method is called when the Browse Models button is pressed on the GUI.
  • A file dialog is opened which the user can use to select an ASR model directory.
  • If the user closes the dialog before selecting a valid directory, a warning notification is displayed.
  • If the user selects a valid directory, the directory path is added to the GUI and assigned to the MODEL global.

Step 2.5.2 - Adding callback_select_file

Add the following code for the callback_select_file method:


def callback_select_file(e: object, args: list) -> None:
    global AUDIO_FILE

    # Unpack widgets
    select_file_field = args[0]
    notification_label = args[1]

    if not MODEL:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()
    else:
        # Open file dialog
        file = filedialog.askopenfilename()

        # Add `AUDIO_FILE` filename to the `select_file_field`
        if file:
            AUDIO_FILE = file
            select_file_field.configure({"text": AUDIO_FILE})
            select_file_field.update()
        else:
            notification_label.configure({"text": GUI["notification_select_audio_file"]})
            notification_label.update()


  • This method is called when the Browse Files button is pressed on the GUI.
  • The logic first checks if an ASR model directory has been selected. If a model directory has not yet been chosen, a warning notification is displayed.
  • If an ASR model directory has been chosen, a file dialog is opened which the user can use to select an audio file.
  • If the user closes the dialog before selecting a valid file, a warning notification is displayed.
  • If the user selects a valid file, the file path is added to the GUI and assigned to the AUDIO_FILE global.

Step 2.5.3 - Adding callback_run_inference

Add the following code for the callback_run_inference method:


def callback_run_inference(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE

    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    save_to_file_button = args[2]
    reset_button = args[3]
    select_model_button = args[4]
    select_file_button = args[5]

    if not MODEL or not AUDIO_FILE:
        notification_label.configure({"text": GUI["notification_select_model_and_audio_file"]})
        notification_label.update()
    else:
        # Disable select model and select file buttons when running inference
        select_model_button.unbind("<Button>")
        select_file_button.unbind("<Button>")

        # Set the input list for the ASR pipeline
        pipeline_input = []
        orig_audio, orig_sampling_rate = read_audio_data(AUDIO_FILE)
        resampled_audio = resample(orig_audio, orig_sampling_rate)
        pipeline_input.append({
            "raw": resampled_audio.numpy(),
            "sampling_rate": TGT_SAMPLING_RATE
        })

        # Initialize instance of ASR pipeline
        transcriber = pipeline("automatic-speech-recognition", chunk_length_s = CHUNK_LENGTH, stride_length_s = (STRIDE_START, STRIDE_END), model = MODEL)
        
        # Update notification
        notification_label.configure({"text": GUI["notification_running_inference"]})
        notification_label.update()

        # Set start time
        start_time = time.time()

        # Run inference
        transcription = transcriber(pipeline_input)

        # Set end time
        end_time = time.time()

        # Add transcription to text box
        textbox.insert("1.0", transcription[0]["text"])
        textbox.update()

        # Update notification
        notification_label.configure({"text": GUI["notification_finished_inference"].format(str(int(end_time - start_time)))})
        notification_label.update()        
        
        # Bind `save_to_file_button`
        save_to_file_button.bind("<Button>", lambda e, args = [textbox, notification_label]: callback_save_file(e, args))


  • This method is called when the Run button is pressed on the GUI and is the heart of the application.
  • The logic first checks that an ASR model directory and audio file path have been selected. If either has not yet been selected, a warning notification is displayed.
  • If a valid ASR model directory and audio file have been selected, the logic will:
    • Disable the Browse Models and Browse Files buttons in preparation for running inference.
    • Load the chosen audio file.
    • Resample the audio data to the target sampling rate of 16,000 Hz.
    • Initialize the pipeline class for automatic speech recognition using the ASR model specified by MODEL, along with the CHUNK_LENGTH, STRIDE_START, and STRIDE_END values set earlier.
    • Run inference on the audio sample.
    • Add the complete text transcription to the GUI textbox for review.
  • If you completed the Part 2 guide, you will recognize that the inference logic mimics the inference logic in Step 2.12 through Step 2.14 of that guide, with the exception of initializing the pipeline class with the chunk_length_s and stride_length_s parameters.
  • A success notification is displayed after inference is complete with the inference duration expressed in seconds.
  • The Save To File button on the GUI is enabled after inference is complete.

Step 2.5.4 - Adding callback_save_file

Add the following code for the callback_save_file method:


def callback_save_file(e: object, args: list) -> None:
    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    
    # Ask user to select a file for saving
    save_file = filedialog.asksaveasfilename()

    if save_file:
        # Write transcription to file
        transcription = textbox.get(1.0, tk.END)
        with open(save_file, "w", encoding = "utf8") as handle:
            handle.write(transcription)

        # Update notification
        notification_label.configure({"text": GUI["notification_finished_saving"]})
        notification_label.update()                 
    else:
        notification_label.configure({"text": GUI["notification_select_file_for_saving"]})
        notification_label.update()


  • This method is called when the Save To File button is pressed on the GUI.
  • A file dialog is opened which the user can use to specify a filename and location for saving the generated text transcription.
  • If the user closes the dialog before specifying a valid filename, a warning notification is displayed.
  • If the user provides a valid filename, the transcription is written out to file using the specified filename. A success notification confirming the save is displayed for the user after the write is complete.

Step 2.5.5 - Adding callback_reset

Add the following code for the callback_reset method:


def callback_reset(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE
    
    # Unpack widgets
    gui_root = args[0]
    select_model_field = args[1]
    select_model_button = args[2]
    select_file_field = args[3]
    select_file_button = args[4]
    textbox = args[5]
    save_to_file_button = args[6]
    notification_label = args[7]

    MODEL = None
    AUDIO_FILE = None

    select_model_field.configure({"text": ""})

    select_file_field.configure({"text": ""})

    textbox.delete("1.0", tk.END)

    notification_label.configure({"text": ""})

    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    save_to_file_button.unbind("<Button>")

    gui_root.update()


  • This method is called when the Reset button is pressed on the GUI. It resets the application for a new inference run. The logic:
    • Resets the MODEL and AUDIO_FILE globals to None.
    • Clears the existing ASR model directory path and audio file path data from the GUI.
    • Clears the transcription textbox.
    • Clears any displayed notification.
    • Re-enables the Browse Models and Browse Files buttons by binding those widgets to their respective callbacks.
    • Disables the Save To File button by unbinding it from its callback.

Step 2.6 - Adding main Method

Add the following code for the main method below the callback methods:


def main():
    gui_root = tk.Tk()
    gui_root.title(GUI["title"])
    window_width = GUI["root_width"]
    window_height = GUI["root_height"]

    # Get the screen dimensions
    screen_width = gui_root.winfo_screenwidth()
    screen_height = gui_root.winfo_screenheight()

    # Find the center point
    center_x = int(screen_width/2 - window_width/2)
    center_y = int(screen_height/2 - window_height/2)

    # Set the position of the window to the center of the screen
    gui_root.geometry(f"{window_width}x{window_height}+{center_x}+{center_y}")

    # Not resizable
    gui_root.resizable(False, False)

    # Configure grid
    gui_root.configure(padx = GUI["pad_x"])
    gui_root.columnconfigure(0, weight = 1)

    # Widgets
    select_model_label = ttk.Label(gui_root, text = GUI["select_model_label"])
    select_model_frame = ttk.Frame(gui_root)
    select_model_field = ttk.Label(select_model_frame, text = "")
    select_model_button = ttk.Button(select_model_frame, text = GUI["browse_models_button_label"])
    select_file_label = ttk.Label(gui_root, text = GUI["select_file_label"])
    select_file_frame = ttk.Frame(gui_root)
    select_file_field = ttk.Label(select_file_frame, text = "")
    select_file_button = ttk.Button(select_file_frame, text = GUI["browse_files_button_label"])
    run_inference_button = ttk.Button(gui_root, text = GUI["run_inference_button_label"])
    textbox = scrolledtext.ScrolledText(gui_root, width = GUI["textbox_width"], height = GUI["textbox_height"])
    textbox_buttons_frame = ttk.Frame(gui_root)
    save_to_file_button = ttk.Button(textbox_buttons_frame, text = GUI["save_to_file_button_label"])
    reset_button = ttk.Button(textbox_buttons_frame, text = GUI["reset"])
    notification_label = ttk.Label(gui_root, text = GUI["default_notification"])

    # Place widgets
    # Row 0
    select_model_label.grid(column = 0, row = 0, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    
    # Row 1
    select_model_frame.grid(column = 0, row = 1, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_model_field.config(background = "white", width = GUI["input_field_width"])
    select_model_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_model_button.pack()
    
    # Row 2
    select_file_label.grid(column = 0, row = 2, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    
    # Row 3
    select_file_frame.grid(column = 0, row = 3, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_file_field.config(background = "white", width = GUI["input_field_width"])
    select_file_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_file_button.pack()
    
    # Row 4
    run_inference_button.grid(column = 0, row = 4, sticky = tk.W, pady = (GUI["pad_y"], 0))

    # Row 5
    textbox.grid(column = 0, row = 5, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    
    # Row 6
    textbox_buttons_frame.grid(column = 0, row = 6, sticky = tk.W, pady = (GUI["pad_y"], 0))
    save_to_file_button.pack(side = "left")
    reset_button.pack()

    # Row 7
    notification_label.config(foreground = "blue")
    notification_label.grid(column = 0, row = 7, columnspan = 1, pady = (GUI["pad_y"], 0))

    # Bind buttons
    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    run_inference_button.bind("<Button>", lambda e, args = [textbox, notification_label, save_to_file_button, reset_button, select_model_button, select_file_button]: callback_run_inference(e, args))
    reset_button.bind("<Button>", lambda e, args = [gui_root, select_model_field, select_model_button, select_file_field, select_file_button, textbox, save_to_file_button, notification_label]: callback_reset(e, args))

    gui_root.mainloop()


  • The main method creates the application GUI by initializing a root window, placing tkinter widgets within the window, and binding button widgets to their callback methods.
  • The GUI layout uses a simple grid with 8 rows (row 0 through row 7).
  • Frame widgets are used to manage the layout of sub-sections of the overall interface.
  • You might notice that the select_model_field and select_file_field widgets are Label widgets with white backgrounds. In other words, they are styled to look like input entry fields but are not actual input entry fields. This was done to force users to use the file dialogs when specifying the ASR model directory and audio file path, and to eliminate the need to parse user inputs.
  • All GUI buttons are bound when the interface is initialized, except for the Save To File button.

Step 2.7 - Calling main Method

Finally, call the main method in the last line of the application to start/run the program:


main()

Step 3 - Reviewing the Complete Application

The complete application should be as follows:


import time
import torch
import torchaudio
from transformers import pipeline
import tkinter as tk
import tkinter.scrolledtext as scrolledtext
from tkinter import ttk, filedialog

### CONSTANTS ###
GUI = {
    "title": "Long Inference with wav2vec2 ASR Models",
    "root_width" : 800,
    "root_height" : 500,
    "pad_x": 10,
    "pad_y": 10,
    "input_field_width": 110,
    "textbox_width": 125,
    "textbox_height": 15,
    "select_model_label": "Please select an ASR model",
    "browse_models_button_label": "Browse Models",
    "select_file_label": "Please select an audio file for inference",
    "browse_files_button_label": "Browse Files",
    "run_inference_button_label": "Run",
    "save_to_file_button_label": "Save To File",
    "reset": "Reset",
    "default_notification": "",
    "notification_select_model": "You need to select an ASR model directory",
    "notification_select_audio_file": "You need to select an audio file",
    "notification_select_model_and_audio_file": "You need to select an ASR model and audio file before you can run inference",
    "notification_running_inference": "Running inference...this might take awhile...",
    "notification_finished_inference": "Finished running inference in {} seconds",
    "notification_select_file_for_saving": "You need to select a file to save the transcription",
    "notification_finished_saving": "Finished saving transcription"
}
TGT_SAMPLING_RATE = 16000
CHUNK_LENGTH = 10
STRIDE_START = 3
STRIDE_END = 3

### GLOBALS ###
MODEL = None
AUDIO_FILE = None

### UTILITY METHODS ###
def read_audio_data(file: str) -> tuple[torch.Tensor, int]:
    audio, sampling_rate = torchaudio.load(file, normalize = True)
    return audio, sampling_rate

def resample(waveform: torch.Tensor, orig_sampling_rate: int) -> torch.Tensor:
    transform = torchaudio.transforms.Resample(orig_sampling_rate, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    return waveform[0]

### CALLBACK METHODS ###
def callback_select_model(e: object, args: list) -> None:
    global MODEL
    
    # Unpack widgets
    select_model_field = args[0]
    notification_label = args[1]

    # Open file dialog
    dir = filedialog.askdirectory()

    # Add `MODEL` to the `select_model_field` widget
    if dir:
        MODEL = dir
        select_model_field.configure({"text": MODEL})
        select_model_field.update()
    else:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()

def callback_select_file(e: object, args: list) -> None:
    global AUDIO_FILE

    # Unpack widgets
    select_file_field = args[0]
    notification_label = args[1]

    if not MODEL:
        notification_label.configure({"text": GUI["notification_select_model"]})
        notification_label.update()
    else:
        # Open file dialog
        file = filedialog.askopenfilename()

        # Add `AUDIO_FILE` filename to the `select_file_field`
        if file:
            AUDIO_FILE = file
            select_file_field.configure({"text": AUDIO_FILE})
            select_file_field.update()
        else:
            notification_label.configure({"text": GUI["notification_select_audio_file"]})
            notification_label.update()            

def callback_run_inference(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE

    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    save_to_file_button = args[2]
    reset_button = args[3]
    select_model_button = args[4]
    select_file_button = args[5]

    if not MODEL or not AUDIO_FILE:
        notification_label.configure({"text": GUI["notification_select_model_and_audio_file"]})
        notification_label.update()
    else:
        # Disable select model and select file buttons when running inference
        select_model_button.unbind("<Button>")
        select_file_button.unbind("<Button>")

        # Set the input list for the ASR pipeline
        pipeline_input = []
        orig_audio, orig_sampling_rate = read_audio_data(AUDIO_FILE)
        resampled_audio = resample(orig_audio, orig_sampling_rate)
        pipeline_input.append({
            "raw": resampled_audio.numpy(),
            "sampling_rate": TGT_SAMPLING_RATE
        })

        # Initialize instance of ASR pipeline
        transcriber = pipeline("automatic-speech-recognition", chunk_length_s = CHUNK_LENGTH, stride_length_s = (STRIDE_START, STRIDE_END), model = MODEL)
        
        # Update notification
        notification_label.configure({"text": GUI["notification_running_inference"]})
        notification_label.update()

        # Set start time
        start_time = time.time()

        # Run inference
        transcription = transcriber(pipeline_input)

        # Set end time
        end_time = time.time()

        # Add transcription to text box
        textbox.insert("1.0", transcription[0]["text"])
        textbox.update()

        # Update notification
        notification_label.configure({"text": GUI["notification_finished_inference"].format(str(int(end_time - start_time)))})
        notification_label.update()        
        
        # Bind `save_to_file_button`
        save_to_file_button.bind("<Button>", lambda e, args = [textbox, notification_label]: callback_save_file(e, args))

def callback_save_file(e: object, args: list) -> None:
    # Unpack widgets
    textbox = args[0]
    notification_label = args[1]
    
    # Ask user to select a file for saving
    save_file = filedialog.asksaveasfilename()

    if save_file:
        # Write transcription to file
        transcription = textbox.get(1.0, tk.END)
        with open(save_file, "w", encoding = "utf8") as handle:
            handle.write(transcription)

        # Update notification
        notification_label.configure({"text": GUI["notification_finished_saving"]})
        notification_label.update()                 
    else:
        notification_label.configure({"text": GUI["notification_select_file_for_saving"]})
        notification_label.update()

def callback_reset(e: object, args: list) -> None:
    global MODEL
    global AUDIO_FILE
    
    # Unpack widgets
    gui_root = args[0]
    select_model_field = args[1]
    select_model_button = args[2]
    select_file_field = args[3]
    select_file_button = args[4]
    textbox = args[5]
    save_to_file_button = args[6]
    notification_label = args[7]

    MODEL = None
    AUDIO_FILE = None

    select_model_field.configure({"text": ""})

    select_file_field.configure({"text": ""})

    textbox.delete("1.0", tk.END)

    notification_label.configure({"text": ""})

    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    save_to_file_button.unbind("<Button>")

    gui_root.update()

def main():
    gui_root = tk.Tk()
    gui_root.title(GUI["title"])
    window_width = GUI["root_width"]
    window_height = GUI["root_height"]

    # Get the screen dimensions
    screen_width = gui_root.winfo_screenwidth()
    screen_height = gui_root.winfo_screenheight()

    # Find the center point
    center_x = int(screen_width/2 - window_width/2)
    center_y = int(screen_height/2 - window_height/2)

    # Set the position of the window to the center of the screen
    gui_root.geometry(f"{window_width}x{window_height}+{center_x}+{center_y}")

    # Not resizable
    gui_root.resizable(False, False)

    # Configure grid
    gui_root.configure(padx = GUI["pad_x"])
    gui_root.columnconfigure(0, weight = 1)

    # Widgets
    select_model_label = ttk.Label(gui_root, text = GUI["select_model_label"])
    select_model_frame = ttk.Frame(gui_root)
    select_model_field = ttk.Label(select_model_frame, text = "")
    select_model_button = ttk.Button(select_model_frame, text = GUI["browse_models_button_label"])
    select_file_label = ttk.Label(gui_root, text = GUI["select_file_label"])
    select_file_frame = ttk.Frame(gui_root)
    select_file_field = ttk.Label(select_file_frame, text = "")
    select_file_button = ttk.Button(select_file_frame, text = GUI["browse_files_button_label"])
    run_inference_button = ttk.Button(gui_root, text = GUI["run_inference_button_label"])
    textbox = scrolledtext.ScrolledText(gui_root, width = GUI["textbox_width"], height = GUI["textbox_height"])
    textbox_buttons_frame = ttk.Frame(gui_root)
    save_to_file_button = ttk.Button(textbox_buttons_frame, text = GUI["save_to_file_button_label"])
    reset_button = ttk.Button(textbox_buttons_frame, text = GUI["reset"])
    notification_label = ttk.Label(gui_root, text = GUI["default_notification"])

    # Place widgets
    # Row 0
    select_model_label.grid(column = 0, row = 0, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    
    # Row 1
    select_model_frame.grid(column = 0, row = 1, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_model_field.config(background = "white", width = GUI["input_field_width"])
    select_model_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_model_button.pack()
    
    # Row 2
    select_file_label.grid(column = 0, row = 2, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    
    # Row 3
    select_file_frame.grid(column = 0, row = 3, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    select_file_field.config(background = "white", width = GUI["input_field_width"])
    select_file_field.pack(side = "left", padx = (0, GUI["pad_x"]))
    select_file_button.pack()
    
    # Row 4
    run_inference_button.grid(column = 0, row = 4, sticky = tk.W, pady = (GUI["pad_y"], 0))

    # Row 5
    textbox.grid(column = 0, row = 5, columnspan = 1, sticky = tk.W, pady = (GUI["pad_y"], 0))
    
    # Row 6
    textbox_buttons_frame.grid(column = 0, row = 6, sticky = tk.W, pady = (GUI["pad_y"], 0))
    save_to_file_button.pack(side = "left")
    reset_button.pack()

    # Row 7
    notification_label.config(foreground = "blue")
    notification_label.grid(column = 0, row = 7, columnspan = 1, pady = (GUI["pad_y"], 0))

    # Bind buttons
    select_model_button.bind("<Button>", lambda e, args = [select_model_field, notification_label]: callback_select_model(e, args))
    select_file_button.bind("<Button>", lambda e, args = [select_file_field, notification_label]: callback_select_file(e, args))
    run_inference_button.bind("<Button>", lambda e, args = [textbox, notification_label, save_to_file_button, reset_button, select_model_button, select_file_button]: callback_run_inference(e, args))
    # Order must match the unpacking in callback_reset (index 2 is the model button)
    reset_button.bind("<Button>", lambda e, args = [gui_root, select_model_field, select_model_button, select_file_field, select_file_button, textbox, save_to_file_button, notification_label]: callback_reset(e, args))

    gui_root.mainloop()

if __name__ == "__main__":
    main()
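Because each callback unpacks its widgets by position (for example, callback_reset treats args[2] as the model-browse button), every args list built in main() must match the unpacking order exactly, and a swapped entry fails silently. One way to make that coupling explicit, offered here as an alternative rather than the approach this guide's script uses, is to group the widgets in a dataclass and bind the callback with functools.partial. The ResetWidgets class and make_handler helper are hypothetical names introduced for illustration.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable


@dataclass
class ResetWidgets:
    """Hypothetical container that names each widget callback_reset needs."""
    gui_root: Any
    select_model_field: Any
    select_model_button: Any
    select_file_field: Any
    select_file_button: Any
    textbox: Any
    save_to_file_button: Any
    notification_label: Any


def make_handler(callback: Callable[..., None], widgets: Any) -> Callable[[object], None]:
    """Hypothetical helper: returns a handler that accepts only the event,
    which is the single argument tkinter passes to a function given to bind()."""
    # The widgets are supplied as the keyword argument `args`, matching the
    # callback(e, args) signature used throughout the script.
    return partial(callback, args=widgets)
```

With this approach, the binding in main() would read `reset_button.bind("<Button>", make_handler(callback_reset, ResetWidgets(gui_root=gui_root, select_model_field=select_model_field, ...)))`, and callback_reset would access `args.select_model_button` instead of `args[2]`. Constructing the dataclass with keyword arguments makes a misplaced widget much easier to spot than a positional list.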


Using the Application

The application workflow is straightforward.


  1. Launch the application.
  2. Click on Browse Models. When the file dialog opens, navigate to the directory containing your ASR model and click Select Folder.
  3. Next, click on Browse Files. When the file dialog opens, navigate to the audio file that you want to run inference on and click Open.
  4. Now that you've selected your model and audio file, click Run. This kicks off the inference workflow, and the GUI displays the notification "Running inference...this might take a while...". Bear in mind that inference might take several minutes, depending on the length of the audio sample and the values chosen for CHUNK_LENGTH, STRIDE_START, and STRIDE_END.
  5. Once inference is complete, click on Save To File if you want to save the generated text transcription.
  6. Click on Reset to reset the application for a new inference run.
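To see why CHUNK_LENGTH, STRIDE_START, and STRIDE_END affect run time, it helps to count how many windows the model has to process. The following is a simplified, hypothetical helper, not the chunking code from earlier in this guide. It assumes all three values are in seconds, that audio is sampled at 16 kHz (the rate wav2vec2 XLS-R models are trained on), and that consecutive chunks overlap by STRIDE_START + STRIDE_END seconds so the overlapping edges can be discarded when the partial transcriptions are stitched together.

```python
from typing import List, Tuple

SAMPLE_RATE = 16_000  # assumption: wav2vec2 XLS-R models expect 16 kHz input


def chunk_boundaries(
    num_samples: int,
    chunk_length_s: float,
    stride_start_s: float,
    stride_end_s: float,
    sample_rate: int = SAMPLE_RATE,
) -> List[Tuple[int, int]]:
    """Return (start, end) sample indices for overlapping chunks.

    Consecutive chunks overlap by stride_start_s + stride_end_s seconds, so the
    window advances by chunk_length_s - stride_start_s - stride_end_s each step.
    """
    chunk = int(chunk_length_s * sample_rate)
    left = int(stride_start_s * sample_rate)
    right = int(stride_end_s * sample_rate)
    step = chunk - left - right
    if step <= 0:
        raise ValueError("CHUNK_LENGTH must be longer than STRIDE_START + STRIDE_END")

    bounds: List[Tuple[int, int]] = []
    start = 0
    while True:
        end = min(start + chunk, num_samples)  # the last chunk may be shorter
        bounds.append((start, end))
        if end >= num_samples:  # this chunk reaches the end of the audio
            break
        start += step
    return bounds


# Hypothetical settings applied to a 221-second clip (the length of the DW sample)
num_samples = 221 * SAMPLE_RATE
for stride in (0, 1, 2):
    chunks = chunk_boundaries(num_samples, chunk_length_s=10, stride_start_s=stride, stride_end_s=stride)
    processed_s = sum(end - start for start, end in chunks) / SAMPLE_RATE
    print(f"strides of {stride} s: {len(chunks)} chunks, {processed_s:.0f} s of audio processed")
```

In this sketch, with an assumed 10-second chunk length, widening both strides from 0 s to 2 s raises the chunk count from 23 to 37 and the total audio the model processes from 221 s to 365 s. That extra overlap is the cost paid for giving the model context at each chunk boundary.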


To reiterate, generating a text transcription can take several minutes. For example, the screenshot in the Introduction shows an inference run that took 208 seconds (~3.5 minutes) to complete. I conducted the run on this Spanish-language audio sample from the news channel DW, which has a duration of 221 seconds; in other words, inference ran only slightly faster than real time. You might consider experimenting with the chunk length and stride values to examine how they affect both the quality of the final transcription and the time required to generate it with your particular ASR model.
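A simple way to run that experiment is to time each configuration and divide by the audio duration to get a real-time factor (RTF); by that measure, the run above has an RTF of roughly 208 / 221 ≈ 0.94. This sketch assumes a hypothetical `transcribe(chunk_length, stride_start, stride_end)` callable that wraps your own chunked inference code; it is not a function from this guide's application.

```python
import time
from typing import Callable, Dict, Iterable, Tuple

Config = Tuple[float, float, float]  # (chunk_length, stride_start, stride_end)


def real_time_factor(inference_s: float, audio_s: float) -> float:
    """Seconds of compute per second of audio; below 1.0 is faster than real time."""
    return inference_s / audio_s


def sweep(
    transcribe: Callable[[float, float, float], str],
    configs: Iterable[Config],
    audio_s: float,
) -> Dict[Config, Tuple[str, float]]:
    """Run the (hypothetical) transcribe callable once per config and
    return each transcription together with its real-time factor."""
    results: Dict[Config, Tuple[str, float]] = {}
    for chunk_length, stride_start, stride_end in configs:
        start = time.perf_counter()
        text = transcribe(chunk_length, stride_start, stride_end)
        elapsed = time.perf_counter() - start
        results[(chunk_length, stride_start, stride_end)] = (text, real_time_factor(elapsed, audio_s))
    return results


# The run described above: 208 s of inference for a 221 s clip
print(f"RTF: {real_time_factor(208, 221):.2f}")
```

Reading the transcriptions side by side with their RTFs shows whether a slower setting actually buys a better result for your model.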

Conclusion and Next Steps

I didn't originally plan on writing this third guide on working with wav2vec2. However, I think it is worthwhile to walk through how a practical ASR application can be built. There are any number of follow-up projects you might undertake from here, such as building a web version of the application or modifying the logic to run live inference instead of waiting to display the complete transcription. As always, I hope you found this guide useful, and happy building!