paint-brush
Trabajar con Wav2vec2 Parte 1: Ajuste de XLS-R para el reconocimiento automático de vozpor@pictureinthenoise
1,799 lecturas
1,799 lecturas

Trabajar con Wav2vec2 Parte 1: Ajuste de XLS-R para el reconocimiento automático de voz

por Picture in the Noise29m2024/05/04
Read on Terminal Reader

Demasiado Largo; Para Leer

Esta guía explica los pasos para ajustar el modelo wav2vec2 XLS-R de Meta AI para el reconocimiento automático de voz ("ASR"). La guía incluye instrucciones paso a paso sobre cómo construir un Kaggle Notebook que se puede utilizar para ajustar el modelo. El modelo está entrenado en un conjunto de datos español chileno.
featured image - Trabajar con Wav2vec2 Parte 1: Ajuste de XLS-R para el reconocimiento automático de voz
Picture in the Noise HackerNoon profile picture
0-item
1-item

Introducción

Meta AI presentó wav2vec2 XLS-R ("XLS-R") a finales de 2021. XLS-R es un modelo de aprendizaje automático ("ML") para el aprendizaje de representaciones de voz en varios idiomas; y fue entrenado en más de 400.000 horas de audio de voz disponible públicamente en 128 idiomas. Tras su lanzamiento, el modelo representó un salto con respecto al modelo multilingüe XLSR-53 de Meta AI, que se entrenó en aproximadamente 50.000 horas de audio de voz en 53 idiomas.


Esta guía explica los pasos para ajustar XLS-R para el reconocimiento automático de voz ("ASR") usando un Kaggle Notebook . El modelo estará afinado en español de Chile, pero se pueden seguir los pasos generales para afinar XLS-R en los diferentes idiomas que desee.


La ejecución de inferencias en el modelo ajustado se describirá en un tutorial complementario, por lo que esta guía será la primera de dos partes. Decidí crear una guía separada específica de inferencia ya que esta guía de ajuste se volvió un poco larga.


Se supone que tiene experiencia en ML y que comprende los conceptos básicos de ASR. Los principiantes pueden tener dificultades para seguir o comprender los pasos de construcción.

Un poco de historia sobre XLS-R

El modelo wav2vec2 original introducido en 2020 se entrenó previamente con 960 horas de audio de voz del conjunto de datos de Librispeech y ~53 200 horas de audio de voz del conjunto de datos de LibriVox . Tras su lanzamiento, había dos tamaños de modelo disponibles: el modelo BASE con 95 millones de parámetros y el modelo GRANDE con 317 millones de parámetros.


XLS-R, por otro lado, fue entrenado previamente en audio de voz multilingüe a partir de 5 conjuntos de datos:


  • VoxPopuli : Un total de ~372.000 horas de audio de discursos en 23 idiomas europeos de discurso parlamentario del Parlamento Europeo.
  • Librispeech multilingüe : un total de ~50 000 horas de audio de voz en ocho idiomas europeos, con la mayoría (~44 000 horas) de datos de audio en inglés.
  • CommonVoice : un total de ~7000 horas de audio de voz en 60 idiomas.
  • VoxLingua107 : un total de ~6600 horas de audio de voz en 107 idiomas según el contenido de YouTube.
  • BABEL : Un total de ~1100 horas de audio de voz en 17 idiomas africanos y asiáticos basado en conversaciones telefónicas.


Hay 3 modelos XLS-R: XLS-R (0.3B) con 300 millones de parámetros, XLS-R (1B) con mil millones de parámetros y XLS-R (2B) con 2 mil millones de parámetros. Esta guía utilizará el modelo XLS-R (0.3B).

Acercarse

Hay algunos artículos excelentes sobre cómo ajustar los modelos wav2vev2 , y quizás este sea una especie de "estándar de oro". Por supuesto, el enfoque general aquí imita lo que encontrará en otras guías. Vas a:


  • Cargue un conjunto de datos de entrenamiento de datos de audio y transcripciones de texto asociadas.
  • Cree un vocabulario a partir de las transcripciones de texto en el conjunto de datos.
  • Inicialice un procesador wav2vec2 que extraerá características de los datos de entrada, así como también convertirá transcripciones de texto en secuencias de etiquetas.
  • Ajuste wav2vec2 XLS-R en los datos de entrada procesados.


Sin embargo, existen tres diferencias clave entre esta guía y otras:


  1. La guía no proporciona tanta discusión "en línea" sobre conceptos relevantes de ML y ASR.
    • Si bien cada subsección sobre celdas individuales del cuaderno incluirá detalles sobre el uso/propósito de la celda en particular, se supone que usted tiene experiencia en ML y que comprende los conceptos básicos de ASR.
  2. El Kaggle Notebook que creará organiza los métodos de utilidad en celdas de nivel superior.
    • Mientras que muchos cuadernos de ajuste tienden a tener una especie de diseño tipo "flujo de conciencia", elegí organizar todos los métodos de utilidad juntos. Si es nuevo en wav2vec2, este enfoque puede resultarle confuso. Sin embargo, para reiterar, hago todo lo posible por ser explícito al explicar el propósito de cada celda en la subsección dedicada a cada celda. Si recién está aprendiendo sobre wav2vec2, podría beneficiarse de echar un vistazo rápido a mi artículo de HackerNoon wav2vec2 para el reconocimiento automático de voz en inglés sencillo .
  3. Esta guía describe los pasos para el ajuste fino únicamente.
    • Como se mencionó en la Introducción , opté por crear una guía complementaria separada sobre cómo ejecutar la inferencia en el modelo XLS-R ajustado que usted generará. Esto se hizo para evitar que esta guía se volviera excesivamente larga.

Requisitos previos y antes de comenzar

Para completar la guía, necesitarás tener:


  • Una cuenta de Kaggle existente. Si no tiene una cuenta de Kaggle existente, debe crear una.
  • Una cuenta existente de Weights and Biases ("WandB") . Si no tiene una cuenta de Weights and Biases, debe crear una.
  • Una clave API de WandB. Si no tiene una clave API de WandB, siga los pasos aquí .
  • Conocimientos intermedios de Python.
  • Conocimiento intermedio de trabajo con Kaggle Notebooks.
  • Conocimiento intermedio de conceptos de ML.
  • Conocimientos básicos de conceptos ASR.


Antes de comenzar a crear el cuaderno, puede resultar útil revisar las dos subsecciones que aparecen directamente a continuación. Ellos describen:


  1. El conjunto de datos de entrenamiento.
  2. La métrica de tasa de error de palabras ("WER") utilizada durante el entrenamiento.

Conjunto de datos de entrenamiento

Como se mencionó en la Introducción , el modelo XLS-R estará afinado en español de Chile. El conjunto de datos específico es el conjunto de datos del habla del español chileno desarrollado por Guevara-Rukoz et al. Está disponible para descargar en OpenSLR . El conjunto de datos consta de dos subconjuntos de datos: (1) 2.636 grabaciones de audio de hablantes chilenos y (2) 1.738 grabaciones de audio de hablantes chilenas.


Cada subconjunto de datos incluye un archivo de índice line_index.tsv . Cada línea de cada archivo de índice contiene un par de nombres de archivo de audio y una transcripción del audio en el archivo asociado, por ejemplo:


 clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente


He subido el conjunto de datos de voz en español de Chile a Kaggle para mayor comodidad. Hay un conjunto de datos de Kaggle para las grabaciones de hablantes chilenos y un conjunto de datos de Kaggle para las grabaciones de hablantes chilenas . Estos conjuntos de datos de Kaggle se agregarán al Kaggle Notebook que creará siguiendo los pasos de esta guía.

Tasa de errores de palabras (WER)

WER es una métrica que se puede utilizar para medir el rendimiento de los modelos de reconocimiento automático de voz. WER proporciona un mecanismo para medir qué tan cerca está una predicción de texto de una referencia de texto. WER logra esto registrando errores de 3 tipos:


  • sustituciones ( S ): Se registra un error de sustitución cuando la predicción contiene una palabra que es diferente de la palabra análoga en la referencia. Por ejemplo, esto ocurre cuando la predicción escribe mal una palabra en la referencia.

  • eliminaciones ( D ): Se registra un error de eliminación cuando la predicción contiene una palabra que no está presente en la referencia.

  • inserciones ( I ): Se registra un error de inserción cuando la predicción no contiene una palabra que esté presente en la referencia.


Obviamente, WER funciona a nivel de palabras. La fórmula para la métrica WER es la siguiente:


 WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference


Un ejemplo sencillo de WER en español es el siguiente:


 prediction: "Él está saliendo." reference: "Él está saltando."


Una tabla ayuda a visualizar los errores en la predicción:

TEXTO

PALABRA 1

PALABRA 2

PALABRA 3

predicción

Él

esta

saliendo

referencia

Él

esta

saltando


correcto

correcto

sustitución

La predicción contiene 1 error de sustitución, 0 errores de eliminación y 0 errores de inserción. Entonces, el WER para este ejemplo es:


 WER = 1 + 0 + 0 / 3 = 1/3 = 0.33


Debería ser obvio que la tasa de errores de palabras no necesariamente nos dice qué errores específicos existen. En el ejemplo anterior, WER identifica que la PALABRA 3 contiene un error en el texto predicho, pero no nos dice que los caracteres i y e están equivocados en la predicción. Se pueden utilizar otras métricas, como la tasa de error de caracteres ("CER"), para un análisis de errores más preciso.

Construyendo el cuaderno de ajuste

Ahora está listo para comenzar a crear el cuaderno de ajuste.


  • Los pasos 1 y 2 lo guiarán a través de la configuración de su entorno Kaggle Notebook.
  • El paso 3 lo guía a través de la construcción del cuaderno. Contiene 32 subpasos que representan las 32 celdas del cuaderno de ajuste.
  • El paso 4 lo guía a través de la ejecución del cuaderno, el seguimiento del entrenamiento y el guardado del modelo.

Paso 1: obtenga su clave API de WandB

Su Kaggle Notebook debe estar configurado para enviar datos de ejecución de entrenamiento a WandB usando su clave API de WandB. Para hacer eso, necesitas copiarlo.


  1. Inicie sesión en WandB en www.wandb.com .
  2. Navegue a www.wandb.ai/authorize .
  3. Copie su clave API para usarla en el siguiente paso.

Paso 2: configurar su entorno Kaggle

Paso 2.1: Crear un nuevo cuaderno Kaggle


  1. Inicie sesión en Kaggle.
  2. Crea un nuevo cuaderno Kaggle.
  3. Por supuesto, el nombre del cuaderno se puede cambiar según se desee. Esta guía utiliza el nombre del cuaderno xls-r-300m-chilean-spanish-asr .

Paso 2.2: Configuración de su clave API de WandB

Se utilizará un Kaggle Secret para almacenar de forma segura su clave API de WandB.


  1. Haga clic en Complementos en el menú principal de Kaggle Notebook.
  2. Seleccione Secreto en el menú emergente.
  3. Ingrese la etiqueta WANDB_API_KEY en el campo Etiqueta e ingrese su clave API de WandB para el valor.
  4. Asegúrese de que la casilla de verificación Adjunto a la izquierda del campo de etiqueta WANDB_API_KEY esté marcada.
  5. Haga clic en Listo .

Paso 2.3: Agregar los conjuntos de datos de entrenamiento

El conjunto de datos del habla en español de Chile se cargó en Kaggle como 2 conjuntos de datos distintos:


Agregue ambos conjuntos de datos a su Kaggle Notebook.

Paso 3: creación del cuaderno de ajuste

Los siguientes 32 subpasos construyen cada una de las 32 celdas del cuaderno de ajuste en orden.

Paso 3.1 - CÉLULA 1: Instalación de paquetes

La primera celda del cuaderno de ajuste instala dependencias. Establezca la primera celda en:


 ### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer


  • La primera línea actualiza el paquete torchaudio a la última versión. torchaudio se utilizará para cargar archivos de audio y volver a muestrear datos de audio.
  • La segunda línea instala el paquete jiwer que se requiere para usar el método load_metric de la biblioteca HuggingFace Datasets que se usará más adelante.

Paso 3.2 - CÉLULA 2: Importación de paquetes de Python

La segunda celda importa los paquetes Python requeridos. Establezca la segunda celda en:


 ### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer


  • Probablemente ya esté familiarizado con la mayoría de estos paquetes. Su uso en el cuaderno se explicará a medida que se construyan las celdas posteriores.
  • Vale la pena mencionar que la biblioteca transformers HuggingFace y las clases Wav2Vec2* asociadas proporcionan la columna vertebral de la funcionalidad utilizada para el ajuste fino.

Paso 3.3 - CELDA 3: Carga de la métrica WER

La tercera celda importa la métrica de evaluación WER de HuggingFace. Establezca la tercera celda en:


 ### CELL 3: Load WER metric ### wer_metric = load_metric("wer")


  • Como se mencionó anteriormente, WER se utilizará para medir el desempeño del modelo en datos de evaluación/reserva.

Paso 3.4 - CÉLULA 4: Iniciar sesión en WandB

La cuarta celda recupera su secreto WANDB_API_KEY que se configuró en el Paso 2.2 . Establezca la cuarta celda en:


 ### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)


  • La clave API se utiliza para configurar Kaggle Notebook de modo que los datos de la ejecución de entrenamiento se envíen a WandB.

Paso 3.5 - CÉLULA 5: Configuración de constantes

La quinta celda establece constantes que se utilizarán en todo el cuaderno. Establezca la quinta celda en:


 ### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800


  • El cuaderno no muestra todas las constantes imaginables en esta celda. Algunos valores que podrían representarse mediante constantes se han dejado en línea.
  • El uso de muchas de las constantes anteriores debería ser evidente. Para aquellos que no lo son, su uso se explicará en los siguientes subpasos.

Paso 3.6 - CÉLULA 6: Métodos de utilidad para leer archivos de índice, limpiar texto y crear vocabulario

La sexta celda define métodos de utilidad para leer los archivos de índice del conjunto de datos (consulte la subsección Conjunto de datos de entrenamiento anterior), así como para limpiar el texto de transcripción y crear el vocabulario. Establezca la sexta celda en:


 ### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list


  • El método read_index_file_data lee un archivo de índice del conjunto de datos line_index.tsv y produce una lista de listas con nombres de archivos de audio y datos de transcripción, por ejemplo:


 [ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]


  • El método truncate_training_dataset trunca los datos de un archivo de índice de lista utilizando la constante NUM_LOAD_FROM_EACH_SET establecida en el Paso 3.5 . Específicamente, la constante NUM_LOAD_FROM_EACH_SET se usa para especificar la cantidad de muestras de audio que deben cargarse desde cada conjunto de datos. Para los fines de esta guía, el número se establece en 1600 lo que significa que eventualmente se cargarán un total de 3200 muestras de audio. Para cargar todas las muestras, establezca NUM_LOAD_FROM_EACH_SET en el valor de cadena all .
  • El método clean_text se utiliza para eliminar de cada transcripción de texto los caracteres especificados por la expresión regular asignada a SPECIAL_CHARS en el Paso 3.5 . Estos caracteres, incluida la puntuación, se pueden eliminar ya que no proporcionan ningún valor semántico al entrenar el modelo para aprender asignaciones entre funciones de audio y transcripciones de texto.
  • El método create_vocab crea un vocabulario a partir de transcripciones de texto limpio. Simplemente, extrae todos los caracteres únicos del conjunto de transcripciones de texto limpias. Verá un ejemplo del vocabulario generado en el Paso 3.14 .

Paso 3.7 - CÉLULA 7: Métodos de utilidad para cargar y remuestrear datos de audio

La séptima celda define métodos de utilidad que utilizan torchaudio para cargar y volver a muestrear datos de audio. Establezca la séptima celda en:


 ### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]


  • El método read_audio_data carga un archivo de audio específico y devuelve una matriz multidimensional torch.Tensor de los datos de audio junto con la frecuencia de muestreo del audio. Todos los archivos de audio de los datos de entrenamiento tienen una frecuencia de muestreo de 48000 Hz. Esta frecuencia de muestreo "original" es capturada por la constante ORIG_SAMPLING_RATE en el Paso 3.5 .
  • El método resample se utiliza para reducir la resolución de datos de audio desde una frecuencia de muestreo de 48000 a 16000 . wav2vec2 está previamente entrenado en audio muestreado a 16000 Hz. En consecuencia, cualquier audio utilizado para el ajuste debe tener la misma frecuencia de muestreo. En este caso, los ejemplos de audio deben reducirse de 48000 Hz a 16000 Hz. 16000 Hz son capturados por la constante TGT_SAMPLING_RATE en el Paso 3.5 .

Paso 3.8 - CÉLULA 8: Métodos de utilidad para preparar datos para la capacitación

La octava celda define métodos de utilidad que procesan los datos de audio y transcripción. Establezca la octava celda en:


 ### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding


  • El método process_speech_audio devuelve los valores de entrada de una muestra de entrenamiento proporcionada.
  • El método process_target_text codifica cada transcripción de texto como una lista de etiquetas, es decir, una lista de índices que hacen referencia a caracteres del vocabulario. Verá una codificación de muestra en el Paso 3.15 .

Paso 3.9 - CÉLULA 9: Método de utilidad para calcular la tasa de error de palabras

La novena celda es la celda del método de utilidad final y contiene el método para calcular la tasa de error de palabras entre una transcripción de referencia y una transcripción prevista. Establezca la novena celda en:


 ### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}

Paso 3.10 - CÉLULA 10: Lectura de datos de entrenamiento

La décima celda lee los archivos de índice de datos de entrenamiento para las grabaciones de hablantes masculinos y las grabaciones de hablantes femeninas utilizando el método read_index_file_data definido en el Paso 3.6 . Establezca la décima celda en:


 ### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")


  • Como se ve, en este punto los datos de entrenamiento se gestionan en dos listas específicas de género. Los datos se combinarán en el Paso 3.12 después del truncamiento.

Paso 3.11 - CÉLULA 11: Truncar datos de entrenamiento

La undécima celda trunca las listas de datos de entrenamiento utilizando el método truncate_training_dataset definido en el Paso 3.6 . Establezca la undécima celda en:


 ### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)


  • Como recordatorio, la constante NUM_LOAD_FROM_EACH_SET establecida en el Paso 3.5 define la cantidad de muestras que se conservarán de cada conjunto de datos. La constante se establece en 1600 en esta guía para un total de 3200 muestras.

Paso 3.12 - CÉLULA 12: Combinación de datos de muestras de entrenamiento

La duodécima celda combina las listas de datos de entrenamiento truncadas. Establezca la duodécima celda en:


 ### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl

Paso 3.13 - CÉLULA 13: Prueba de transcripción de limpieza

La decimotercera celda itera sobre cada muestra de datos de entrenamiento y limpia el texto de transcripción asociado utilizando el método clean_text definido en el Paso 3.6 . Establezca la decimotercera celda en:


 for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])

Paso 3.14 - CÉLULA 14: Crear el vocabulario

La decimocuarta celda crea un vocabulario utilizando las transcripciones limpias del paso anterior y el método create_vocab definido en el Paso 3.6 . Establezca la decimocuarta celda en:


 ### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}


  • El vocabulario se almacena como un diccionario con caracteres como claves e índices de vocabulario como valores.

  • Puede imprimir vocab_dict que debería producir el siguiente resultado:


 {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}

Paso 3.15 - CÉLULA 15: Agregar delimitador de palabras al vocabulario

La decimoquinta celda agrega el carácter delimitador de palabras | al vocabulario. Establezca la decimoquinta celda en:


 ### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)


  • El carácter delimitador de palabras se utiliza al tokenizar transcripciones de texto como una lista de etiquetas. Específicamente, se usa para definir el final de una palabra y se usa al inicializar la clase Wav2Vec2CTCTokenizer , como se verá en el Paso 3.17 .

  • Por ejemplo, la siguiente lista codifica no te entiendo nada usando el vocabulario del Paso 3.14 :


 # Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}


  • Una pregunta que naturalmente podría surgir es: "¿Por qué es necesario definir un carácter delimitador de palabra?" Por ejemplo, el final de las palabras escritas en inglés y español están marcados con espacios en blanco, por lo que debería ser sencillo utilizar el carácter de espacio como delimitador de palabras. Recuerde que el inglés y el español son sólo dos idiomas entre miles; y no todos los idiomas escritos utilizan un espacio para marcar los límites de las palabras.

Paso 3.16 - CÉLULA 16: Exportación de vocabulario

La decimosexta celda vuelca el vocabulario en un archivo. Establezca la decimosexta celda en:


 ### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)


  • El archivo de vocabulario se utilizará en el siguiente paso, Paso 3.17 , para inicializar la clase Wav2Vec2CTCTokenizer .

Paso 3.17 - CÉLULA 17: Inicializar el tokenizador

La decimoséptima celda inicializa una instancia de Wav2Vec2CTCTokenizer . Establezca la decimoséptima celda en:


 ### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )


  • El tokenizador se utiliza para codificar transcripciones de texto y decodificar una lista de etiquetas en texto.

  • Tenga en cuenta que el tokenizer se inicializa con [UNK] asignado a unk_token y [PAD] asignado a pad_token ; el primero se usa para representar tokens desconocidos en transcripciones de texto y el segundo se usa para rellenar transcripciones al crear lotes de transcripciones con diferentes longitudes. El tokenizador agregará estos dos valores al vocabulario.

  • La inicialización del tokenizador en este paso también agregará dos tokens adicionales al vocabulario, a saber, <s> y /</s> , que se utilizan para demarcar el principio y el final de las oraciones, respectivamente.

  • | se asigna explícitamente a word_delimiter_token en este paso para reflejar que el símbolo de canalización se usará para demarcar el final de las palabras de acuerdo con nuestra adición del carácter al vocabulario en el Paso 3.15 . El | El símbolo es el valor predeterminado para word_delimiter_token . Por lo tanto, no era necesario establecerlo explícitamente, pero se hizo en aras de la claridad.

  • De manera similar a word_delimiter_token , se asigna explícitamente un único espacio a replace_word_delimiter_char lo que refleja que el símbolo de canalización | se utilizará para reemplazar caracteres de espacio en blanco en las transcripciones de texto. El espacio en blanco es el valor predeterminado para replace_word_delimiter_char . Por lo tanto, tampoco era necesario establecerlo explícitamente, pero se hizo en aras de la claridad.

  • Puede imprimir el vocabulario completo del tokenizador llamando al método get_vocab() en tokenizer .


 vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}

Paso 3.18 - CÉLULA 18: Inicialización del extractor de funciones

La decimoctava celda inicializa una instancia de Wav2Vec2FeatureExtractor . Establezca la decimoctava celda en:


 ### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )


  • El extractor de funciones se utiliza para extraer funciones de los datos de entrada que, por supuesto, son datos de audio en este caso de uso. Cargará los datos de audio para cada muestra de datos de entrenamiento en el Paso 3.20 .
  • Los valores de los parámetros pasados al inicializador Wav2Vec2FeatureExtractor son todos valores predeterminados, con la excepción de return_attention_mask que por defecto es False . Los valores predeterminados se muestran/aprueban para mayor claridad.
  • El parámetro feature_size especifica el tamaño de dimensión de las funciones de entrada (es decir, funciones de datos de audio). El valor predeterminado de este parámetro es 1 .
  • sampling_rate le dice al extractor de funciones la frecuencia de muestreo a la que se deben digitalizar los datos de audio. Como se analizó en el Paso 3.7 , wav2vec2 está preentrenado en audio muestreado a 16000 Hz y, por lo tanto, 16000 es el valor predeterminado para este parámetro.
  • El parámetro padding_value especifica el valor que se utiliza al rellenar datos de audio, según sea necesario al agrupar muestras de audio de diferentes longitudes. El valor predeterminado es 0.0 .
  • do_normalize se utiliza para especificar si los datos de entrada deben transformarse a una distribución normal estándar. El valor por defecto es True . La documentación de la clase Wav2Vec2FeatureExtractor señala que "[la normalización] puede ayudar a mejorar significativamente el rendimiento de algunos modelos".
  • Los parámetros return_attention_mask especifican si se debe pasar la máscara de atención o no. El valor se establece en True para este caso de uso.

Paso 3.19 - CÉLULA 19: Inicializando el procesador

La decimonovena celda inicializa una instancia de Wav2Vec2Processor . Establezca la decimonovena celda en:


 ### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)


  • La clase Wav2Vec2Processor combina tokenizer y feature_extractor del paso 3.17 y 3.18 respectivamente en un solo procesador.

  • Tenga en cuenta que la configuración del procesador se puede guardar llamando al método save_pretrained en la instancia de la clase Wav2Vec2Processor .


 processor.save_pretrained(OUTPUT_DIR_PATH)

Paso 3.20 - CÉLULA 20: Cargando datos de audio

La vigésima celda carga cada archivo de audio especificado en la lista all_training_samples . Establezca la vigésima celda en:


 ### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })


  • Los datos de audio se devuelven como un torch.Tensor y se almacenan en all_input_data como una lista de diccionarios. Cada diccionario contiene los datos de audio de una muestra particular, junto con la transcripción del texto del audio.
  • Tenga en cuenta que el método read_audio_data también devuelve la frecuencia de muestreo de los datos de audio. Como sabemos que la frecuencia de muestreo es 48000 Hz para todos los archivos de audio en este caso de uso, la frecuencia de muestreo se ignora en este paso.

Paso 3.21 - CÉLULA 21: Convertir all_input_data en un Pandas DataFrame

La vigésima primera celda convierte la lista all_input_data en un Pandas DataFrame para facilitar la manipulación de los datos. Establezca la vigésima primera celda en:


 ### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)

Paso 3.22 - CÉLULA 22: Procesamiento de datos de audio y transcripciones de texto

La vigésima segunda celda utiliza el processor inicializado en el Paso 3.19 para extraer características de cada muestra de datos de audio y codificar cada transcripción de texto como una lista de etiquetas. Establezca la vigésima segunda celda en:


 ### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))

Paso 3.23 - CÉLULA 23: División de datos de entrada en conjuntos de datos de entrenamiento y validación

La vigésima tercera celda divide el marco de datos all_input_data_df en conjuntos de datos de entrenamiento y evaluación (validación) utilizando la constante SPLIT_PCT del paso 3.5 . Establezca la vigésima tercera celda en:


 ### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]


  • El valor SPLIT_PCT es 0.10 en esta guía, lo que significa que el 10 % de todos los datos de entrada se reservarán para evaluación y el 90 % de los datos se utilizarán para capacitación/ajuste.
  • Dado que hay un total de 3200 muestras de entrenamiento, se utilizarán 320 muestras para la evaluación y las 2880 muestras restantes se utilizarán para ajustar el modelo.

Paso 3.24 - CÉLULA 24: Conversión de conjuntos de datos de entrenamiento y validación en objetos Dataset

La vigésima cuarta celda convierte los DataFrames train_data_df y valid_data_df en objetos Dataset . Establezca la vigésima cuarta celda en:


 ### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)


  • Los objetos Dataset son consumidos por instancias de la clase HuggingFace Trainer , como verá en el Paso 3.30 .

  • Estos objetos contienen metadatos sobre el conjunto de datos, así como sobre el conjunto de datos en sí.

  • Puede imprimir train_data y valid_data para ver los metadatos de ambos objetos Dataset .


 print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })

Paso 3.25 - CÉLULA 25: Inicialización del modelo previamente entrenado

La vigésima quinta celda inicializa el modelo XLS-R (0.3) previamente entrenado. Establezca la vigésima quinta celda en:


 ### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )


  • El método from_pretrained llamado en Wav2Vec2ForCTC especifica que queremos cargar los pesos previamente entrenados para el modelo especificado.
  • La constante MODEL se especificó en el Paso 3.5 y se configuró en facebook/wav2vec2-xls-r-300m reflejando el modelo XLS-R (0.3).
  • El parámetro ctc_loss_reduction especifica el tipo de reducción que se aplicará a la salida de la función de pérdida de Clasificación Temporal Conexionista ("CTC"). La pérdida de CTC se utiliza para calcular la pérdida entre una entrada continua, en este caso datos de audio, y una secuencia de destino, en este caso transcripciones de texto. Al establecer el valor en mean , las pérdidas de producción de un lote de insumos se dividirán por las longitudes objetivo. Luego se calcula la media del lote y la reducción se aplica a los valores de pérdida.
  • pad_token_id especifica el token que se utilizará para el relleno al realizar el procesamiento por lotes. Se establece en el ID [PAD] establecido al inicializar el tokenizador en el Paso 3.17 .
  • El parámetro vocab_size define el tamaño del vocabulario del modelo. Es el tamaño del vocabulario después de la inicialización del tokenizador en el Paso 3.17 y refleja el número de nodos de la capa de salida de la parte directa de la red.

Paso 3.26 - CÉLULA 26: Pesos extractores de función de congelación

La vigésima sexta celda congela los pesos previamente entrenados del extractor de características. Establezca la vigésima sexta celda en:


 ### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()

Paso 3.27 - CÉLULA 27: Establecer argumentos de entrenamiento

La vigésima séptima celda inicializa los argumentos de entrenamiento que se pasarán a una instancia Trainer . Establezca la vigésima séptima celda en:


 ### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )


  • La clase TrainingArguments acepta más de 100 parámetros .
  • El parámetro save_safetensors cuando False especifica que el modelo ajustado debe guardarse en un archivo pickle en lugar de usar el formato safetensors .
  • El parámetro group_by_length cuando es True indica que se deben agrupar muestras de aproximadamente la misma longitud. Esto minimiza el acolchado y mejora la eficiencia del entrenamiento.
  • per_device_train_batch_size establece el número de muestras por minilote de entrenamiento. Este parámetro se establece en 18 mediante la constante TRAIN_BATCH_SIZE asignada en el Paso 3.5 . Esto implica 160 pasos por época.
  • per_device_eval_batch_size establece el número de muestras por minilote de evaluación (reserva). Este parámetro se establece en 10 mediante la constante EVAL_BATCH_SIZE asignada en el Paso 3.5 .
  • num_train_epochs establece el número de épocas de entrenamiento. Este parámetro se establece en 30 mediante la constante TRAIN_EPOCHS asignada en el Paso 3.5 . Esto implica 4.800 pasos totales durante el entrenamiento.
  • El parámetro gradient_checkpointing cuando True ayuda a ahorrar memoria al controlar los cálculos de gradiente, pero da como resultado pasos hacia atrás más lentos.
  • El parámetro evaluation_strategy cuando se establece en steps significa que la evaluación se realizará y registrará durante el entrenamiento en un intervalo especificado por el parámetro eval_steps .
  • El parámetro logging_strategy cuando se establece en steps significa que las estadísticas de ejecución de entrenamiento se registrarán en un intervalo especificado por el parámetro logging_steps .
  • El parámetro save_strategy cuando se establece en steps significa que se guardará un punto de control del modelo ajustado en un intervalo especificado por el parámetro save_steps .
  • eval_steps establece el número de pasos entre evaluaciones de datos reservados. Este parámetro se establece en 100 mediante la constante EVAL_STEPS asignada en el Paso 3.5 .
  • save_steps establece el número de pasos después de los cuales se guarda un punto de control del modelo ajustado. Este parámetro se establece en 3200 mediante la constante SAVE_STEPS asignada en el Paso 3.5 .
  • logging_steps establece el número de pasos entre registros de estadísticas de ejecución de entrenamiento. Este parámetro se establece en 100 mediante la constante LOGGING_STEPS asignada en el Paso 3.5 .
  • El parámetro learning_rate establece la tasa de aprendizaje inicial. Este parámetro se establece en 1e-4 mediante la constante LEARNING_RATE asignada en el Paso 3.5 .
  • El parámetro warmup_steps establece el número de pasos para calentar linealmente la tasa de aprendizaje desde 0 hasta el valor establecido por learning_rate . Este parámetro se establece en 800 mediante la constante WARMUP_STEPS asignada en el Paso 3.5 .

Paso 3.28 - CÉLULA 28: Definición de la lógica del recopilador de datos

La vigésima octava celda define la lógica para rellenar dinámicamente las secuencias de entrada y de destino. Establezca la vigésima octava celda en:


 ### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch


  • Los pares de etiquetas de entrada de capacitación y evaluación se pasan en minilotes a la instancia Trainer que se inicializará momentáneamente en el Paso 3.30 . Dado que las secuencias de entrada y las secuencias de etiquetas varían en longitud en cada minilote, algunas secuencias deben rellenarse para que todas tengan la misma longitud.
  • La clase DataCollatorCTCWithPadding rellena dinámicamente datos de mini lotes. El parámetro padding cuando se establece en True especifica que las secuencias de funciones de entrada de audio y las secuencias de etiquetas más cortas deben tener la misma longitud que la secuencia más larga en un mini lote.
  • Las funciones de entrada de audio se rellenan con el valor 0.0 establecido al inicializar el extractor de funciones en el Paso 3.18 .
  • Las entradas de etiquetas primero se rellenan con el valor de relleno establecido al inicializar el tokenizador en el Paso 3.17 . Estos valores se reemplazan por -100 para que estas etiquetas se ignoren al calcular la métrica WER.

Paso 3.29 - CÉLULA 29: Inicialización de la instancia del recopilador de datos

La vigésima novena celda inicializa una instancia del recopilador de datos definido en el paso anterior. Establezca la vigésima novena celda en:


 ### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)

Paso 3.30 - CÉLULA 30: Inicializando el Entrenador

La trigésima celda inicializa una instancia de la clase Trainer . Establezca la trigésima celda en:


 ### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )


  • Como se ve, la clase Trainer se inicializa con:
    • El model previamente entrenado se inicializó en el Paso 3.25 .
    • El recopilador de datos se inicializó en el Paso 3.29 .
    • Los argumentos de entrenamiento se inicializaron en el Paso 3.27 .
    • El método de evaluación WER definido en el Paso 3.9 .
    • El objeto train_data Dataset del paso 3.24 .
    • El objeto valid_data Dataset del paso 3.24 .
  • El parámetro tokenizer se asigna a processor.feature_extractor y funciona con data_collator para rellenar automáticamente las entradas hasta la entrada de longitud máxima de cada mini lote.

Paso 3.31 - CÉLULA 31: Ajuste del modelo

La trigésima primera celda llama al método train en la instancia de la clase Trainer para ajustar el modelo. Establezca la trigésima primera celda en:


 ### CELL 31: Finetune the model ### trainer.train()

Paso 3.32 - CÉLULA 32: Guarde el modelo ajustado

La celda trigésimo segunda es la última celda del cuaderno. Guarda el modelo ajustado llamando al método save_model en la instancia Trainer . Establezca la celda de treinta segundos en:


 ### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)

Paso 4: entrenar y guardar el modelo

Paso 4.1: entrenar el modelo

Ahora que se han construido todas las celdas del portátil, es hora de empezar a realizar ajustes.


  1. Configure el portátil Kaggle para que se ejecute con el acelerador NVIDIA GPU P100 .

  2. Confirme el cuaderno en Kaggle.

  3. Supervise los datos de la carrera de entrenamiento iniciando sesión en su cuenta WandB y localizando la carrera asociada.


El entrenamiento de más de 30 épocas debería llevar aproximadamente 5 horas con el acelerador NVIDIA GPU P100. El WER en los datos de reserva debería caer a ~0,15 al final del entrenamiento. No es un resultado de última generación, pero el modelo ajustado sigue siendo suficientemente útil para muchas aplicaciones.

Paso 4.2 - Guardar el modelo

El modelo ajustado se enviará al directorio de Kaggle especificado por la constante OUTPUT_DIR_PATH especificada en el Paso 3.5 . La salida del modelo debe incluir los siguientes archivos:


 pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin


Estos archivos se pueden descargar localmente. Además, puede crear un nuevo modelo Kaggle utilizando los archivos del modelo. El modelo Kaggle se utilizará con la guía de inferencia complementaria para ejecutar inferencias en el modelo ajustado.


  1. Inicie sesión en su cuenta de Kaggle. Haga clic en Modelos > Nuevo modelo .
  2. Agregue un título para su modelo ajustado en el campo Título del modelo .
  3. Haga clic en Crear modelo .
  4. Haga clic en Ir a la página de detalles del modelo .
  5. Haga clic en Agregar nueva variación en Variaciones del modelo .
  6. Seleccione Transformers en el menú de selección de Framework .
  7. Haga clic en Agregar nueva variación .
  8. Arrastra y suelta tus archivos de modelo ajustados en la ventana Cargar datos . Alternativamente, haga clic en el botón Examinar archivos para abrir una ventana del explorador de archivos y seleccionar los archivos de su modelo ajustado.
  9. Una vez que los archivos se hayan cargado en Kaggle, haga clic en Crear para crear el modelo de Kaggle .

Conclusión

¡Felicitaciones por ajustar wav2vec2 XLS-R! Recuerde que puede utilizar estos pasos generales para ajustar el modelo en otros idiomas que desee. Realizar inferencias sobre el modelo ajustado generado en esta guía es bastante sencillo. Los pasos de inferencia se describirán en una guía complementaria a esta. Busque mi nombre de usuario de HackerNoon para encontrar la guía complementaria.