Meta AI introduziu wav2vec2 XLS-R ("XLS-R") no final de 2021. XLS-R é um modelo de aprendizado de máquina ("ML") para aprendizagem de representações de fala multilíngues; e foi treinado em mais de 400.000 horas de áudio de fala disponível publicamente em 128 idiomas. Após seu lançamento, o modelo representou um salto em relação ao modelo multilíngue XLSR-53 da Meta AI, que foi treinado em aproximadamente 50.000 horas de áudio de fala em 53 idiomas.
Este guia explica as etapas para ajustar o XLS-R para reconhecimento automático de fala ("ASR") usando um Kaggle Notebook . O modelo será ajustado em espanhol chileno, mas os passos gerais podem ser seguidos para ajustar o XLS-R nos diferentes idiomas que você desejar.
A execução da inferência no modelo ajustado será descrita em um tutorial complementar, tornando este guia a primeira de duas partes. Decidi criar um guia específico de inferência separado, pois este guia de ajuste fino ficou um pouco longo.
Presume-se que você tenha experiência em ML e entenda os conceitos básicos de ASR. Os iniciantes podem ter dificuldade em seguir/compreender as etapas de construção.
O modelo wav2vec2 original introduzido em 2020 foi pré-treinado em 960 horas de áudio de fala do conjunto de dados Librispeech e aproximadamente 53.200 horas de áudio de fala do conjunto de dados LibriVox . No seu lançamento, dois tamanhos de modelo estavam disponíveis: o modelo BASE com 95 milhões de parâmetros e o modelo LARGE com 317 milhões de parâmetros.
O XLS-R, por outro lado, foi pré-treinado em áudio de fala multilíngue a partir de 5 conjuntos de dados:
Existem 3 modelos XLS-R: XLS-R (0,3B) com 300 milhões de parâmetros, XLS-R (1B) com 1 bilhão de parâmetros e XLS-R (2B) com 2 bilhões de parâmetros. Este guia usará o modelo XLS-R (0,3B).
Existem alguns ótimos artigos sobre como ajustar os modelos wav2vev2 , talvez este seja uma espécie de "padrão ouro". É claro que a abordagem geral aqui imita o que você encontrará em outros guias. Você irá:
No entanto, existem três diferenças principais entre este guia e outros:
Para completar o guia, você precisará ter:
Antes de começar a construir o notebook, pode ser útil revisar as duas subseções diretamente abaixo. Eles descrevem:
Conforme mencionado na Introdução , o modelo XLS-R será aprimorado no espanhol chileno. O conjunto de dados específico é o conjunto de dados de fala do espanhol chileno desenvolvido por Guevara-Rukoz et al. Ele está disponível para download no OpenSLR . O conjunto de dados consiste em dois subconjuntos de dados: (1) 2.636 gravações de áudio de falantes chilenos do sexo masculino e (2) 1.738 gravações de áudio de falantes chilenos do sexo feminino.
Cada subconjunto de dados inclui um arquivo de índice line_index.tsv
. Cada linha de cada arquivo de índice contém um par de nome de arquivo de áudio e uma transcrição do áudio no arquivo associado, por exemplo:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
Carreguei o conjunto de dados da fala em espanhol chileno no Kaggle por conveniência. Há um conjunto de dados Kaggle para gravações de falantes chilenos do sexo masculino e um conjunto de dados Kaggle para gravações de falantes chilenos do sexo feminino . Esses conjuntos de dados Kaggle serão adicionados ao Kaggle Notebook que você construirá seguindo as etapas deste guia.
WER é uma métrica que pode ser usada para medir o desempenho de modelos automáticos de reconhecimento de fala. O WER fornece um mecanismo para medir o quão próxima uma previsão de texto está de uma referência de texto. O WER faz isso registrando erros de 3 tipos:
substituições ( S
): Um erro de substituição é registrado quando a predição contém uma palavra diferente da palavra análoga na referência. Por exemplo, isso ocorre quando a previsão escreve incorretamente uma palavra na referência.
exclusões ( D
): Um erro de exclusão é registrado quando a predição contém uma palavra que não está presente na referência.
inserções ( I
): Um erro de inserção é registrado quando a predição não contém uma palavra que esteja presente na referência.
Obviamente, o WER funciona no nível da palavra. A fórmula para a métrica WER é a seguinte:
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
Um exemplo simples de WER em espanhol é o seguinte:
prediction: "Él está saliendo." reference: "Él está saltando."
Uma tabela ajuda a visualizar os erros na previsão:
TEXTO | PALAVRA 1 | PALAVRA 2 | PALAVRA 3 |
---|---|---|---|
predição | Él | está | destacando |
referência | Él | está | saltando |
| correto | correto | substituição |
A previsão contém 1 erro de substituição, 0 erros de exclusão e 0 erros de inserção. Portanto, o WER para este exemplo é:
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
Deveria ser óbvio que a taxa de erros de palavras não nos diz necessariamente quais erros específicos existem. No exemplo acima, o WER identifica que a PALAVRA 3 contém um erro no texto previsto, mas não nos diz que os caracteres i e e estão errados na previsão. Outras métricas, como a Taxa de Erros de Caracteres ("CER"), podem ser usadas para análises de erros mais precisas.
Agora você está pronto para começar a construir o notebook de ajuste fino.
Seu Kaggle Notebook deve ser configurado para enviar dados de execução de treinamento para WandB usando sua chave de API WandB. Para fazer isso, você precisa copiá-lo.
www.wandb.com
.www.wandb.ai/authorize
.
xls-r-300m-chilean-spanish-asr
.Um segredo Kaggle será usado para armazenar com segurança sua chave API WandB.
WANDB_API_KEY
no campo Rótulo e insira sua chave de API WandB para o valor.WANDB_API_KEY
esteja marcada.O conjunto de dados de fala do espanhol chileno foi carregado no Kaggle como 2 conjuntos de dados distintos:
Adicione esses dois conjuntos de dados ao seu Kaggle Notebook.
As 32 subetapas a seguir constroem cada uma das 32 células do notebook de ajuste fino em ordem.
A primeira célula do notebook de ajuste fino instala dependências. Defina a primeira célula como:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
torchaudio
para a versão mais recente. torchaudio
será usado para carregar arquivos de áudio e reamostrar dados de áudio.jiwer
que é necessário para usar o método load_metric
da biblioteca HuggingFace Datasets
usado posteriormente.As importações da segunda célula exigiam pacotes Python. Defina a segunda célula como:
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
transformers
HuggingFace e as classes Wav2Vec2*
associadas fornecem a espinha dorsal da funcionalidade usada para ajuste fino.A terceira célula importa a métrica de avaliação HuggingFace WER. Defina a terceira célula como:
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
A quarta célula recupera seu segredo WANDB_API_KEY
que foi definido na Etapa 2.2 . Defina a quarta célula como:
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
A quinta célula define constantes que serão usadas em todo o notebook. Defina a quinta célula como:
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
A sexta célula define métodos utilitários para ler os arquivos de índice do conjunto de dados (consulte a subseção Conjunto de dados de treinamento acima), bem como para limpar o texto da transcrição e criar o vocabulário. Defina a sexta célula como:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
O método read_index_file_data
lê um arquivo de índice do conjunto de dados line_index.tsv
e produz uma lista de listas com nomes de arquivos de áudio e dados de transcrição, por exemplo:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
truncate_training_dataset
trunca os dados de um arquivo de índice de lista usando a constante NUM_LOAD_FROM_EACH_SET
definida na Etapa 3.5 . Especificamente, a constante NUM_LOAD_FROM_EACH_SET
é usada para especificar o número de amostras de áudio que devem ser carregadas de cada conjunto de dados. Para os fins deste guia, o número é definido como 1600
o que significa que um total de 3200
amostras de áudio serão eventualmente carregadas. Para carregar todas as amostras, defina NUM_LOAD_FROM_EACH_SET
como o valor da string all
.clean_text
é usado para retirar cada transcrição de texto dos caracteres especificados pela expressão regular atribuída a SPECIAL_CHARS
na Etapa 3.5 . Esses caracteres, inclusive a pontuação, podem ser eliminados, pois não fornecem nenhum valor semântico ao treinar o modelo para aprender mapeamentos entre recursos de áudio e transcrições de texto.create_vocab
cria um vocabulário a partir de transcrições de texto limpas. Simplesmente, ele extrai todos os caracteres exclusivos do conjunto de transcrições de texto limpas. Você verá um exemplo do vocabulário gerado na Etapa 3.14 . A sétima célula define métodos utilitários usando torchaudio
para carregar e reamostrar dados de áudio. Defina a sétima célula como:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
carrega um arquivo de áudio especificado e retorna uma matriz multidimensional torch.Tensor
dos dados de áudio junto com a taxa de amostragem do áudio. Todos os arquivos de áudio nos dados de treinamento têm uma taxa de amostragem de 48000
Hz. Esta taxa de amostragem "original" é capturada pela constante ORIG_SAMPLING_RATE
na Etapa 3.5 .resample
é usado para reduzir a resolução de dados de áudio de uma taxa de amostragem de 48000
a 16000
. wav2vec2 é pré-treinado em áudio amostrado a 16000
Hz. Conseqüentemente, qualquer áudio usado para ajuste fino deve ter a mesma taxa de amostragem. Neste caso, os exemplos de áudio devem ser reduzidos de 48000
Hz para 16000
Hz. 16000
Hz são capturados pela constante TGT_SAMPLING_RATE
na Etapa 3.5 .A oitava célula define métodos utilitários que processam os dados de áudio e transcrição. Defina a oitava célula como:
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
process_speech_audio
retorna os valores de entrada de uma amostra de treinamento fornecida.process_target_text
codifica cada transcrição de texto como uma lista de rótulos - ou seja, uma lista de índices referentes a caracteres do vocabulário. Você verá um exemplo de codificação na Etapa 3.15 .A nona célula é a célula final do método utilitário e contém o método para calcular a taxa de erro de palavra entre uma transcrição de referência e uma transcrição prevista. Defina a nona célula como:
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
A décima célula lê os arquivos de índice de dados de treinamento para as gravações de falantes do sexo masculino e as gravações de falantes do sexo feminino usando o método read_index_file_data
definido na Etapa 3.6 . Defina a décima célula como:
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
A décima primeira célula trunca as listas de dados de treinamento usando o método truncate_training_dataset
definido na Etapa 3.6 . Defina a décima primeira célula como:
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
NUM_LOAD_FROM_EACH_SET
definida na Etapa 3.5 define a quantidade de amostras a serem mantidas em cada conjunto de dados. A constante é definida como 1600
neste guia para um total de 3200
amostras.A décima segunda célula combina as listas de dados de treinamento truncadas. Defina a décima segunda célula como:
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
A décima terceira célula itera sobre cada amostra de dados de treinamento e limpa o texto de transcrição associado usando o método clean_text
definido na Etapa 3.6 . Defina a décima terceira célula como:
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
A décima quarta célula cria um vocabulário usando as transcrições limpas da etapa anterior e o método create_vocab
definido na Etapa 3.6 . Defina a décima quarta célula como:
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
O vocabulário é armazenado como um dicionário com caracteres como chaves e índices de vocabulário como valores.
Você pode imprimir vocab_dict
que deve produzir a seguinte saída:
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
A décima quinta célula adiciona o caractere delimitador de palavra |
ao vocabulário. Defina a décima quinta célula como:
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
O caractere delimitador de palavra é usado ao tokenizar transcrições de texto como uma lista de rótulos. Especificamente, é usado para definir o final de uma palavra e é usado ao inicializar a classe Wav2Vec2CTCTokenizer
, como será visto na Etapa 3.17 .
Por exemplo, a lista a seguir codifica no te entiendo nada
usando o vocabulário da Etapa 3.14 :
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
A décima sexta célula despeja o vocabulário em um arquivo. Defina a décima sexta célula como:
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
Wav2Vec2CTCTokenizer
. A décima sétima célula inicializa uma instância de Wav2Vec2CTCTokenizer
. Defina a décima sétima célula como:
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
O tokenizer é usado para codificar transcrições de texto e decodificar uma lista de rótulos de volta ao texto.
Observe que o tokenizer
é inicializado com [UNK]
atribuído a unk_token
e [PAD]
atribuído a pad_token
, com o primeiro usado para representar tokens desconhecidos em transcrições de texto e o último usado para preencher transcrições ao criar lotes de transcrições com comprimentos diferentes. Esses dois valores serão adicionados ao vocabulário pelo tokenizer.
A inicialização do tokenizer nesta etapa também adicionará dois tokens adicionais ao vocabulário, nomeadamente <s>
e /</s>
que são usados para demarcar o início e o fim das frases, respectivamente.
|
é atribuído a word_delimiter_token
explicitamente nesta etapa para refletir que o símbolo de barra vertical será usado para demarcar o final das palavras de acordo com nossa adição do caractere ao vocabulário na Etapa 3.15 . O |
símbolo é o valor padrão para word_delimiter_token
. Portanto, não precisou ser definido explicitamente, mas foi feito por uma questão de clareza.
Da mesma forma que com word_delimiter_token
, um único espaço é explicitamente atribuído a replace_word_delimiter_char
refletindo que o símbolo de barra vertical |
será usado para substituir caracteres de espaço em branco nas transcrições de texto. Espaço em branco é o valor padrão para replace_word_delimiter_char
. Portanto, também não precisou ser definido explicitamente, mas foi feito por uma questão de clareza.
Você pode imprimir o vocabulário completo do tokenizer chamando o método get_vocab()
em tokenizer
.
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
A décima oitava célula inicializa uma instância de Wav2Vec2FeatureExtractor
. Defina a décima oitava célula como:
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
Wav2Vec2FeatureExtractor
são todos valores padrão, com exceção de return_attention_mask
cujo padrão é False
. Os valores padrão são mostrados/passados para maior clareza.feature_size
especifica o tamanho da dimensão dos recursos de entrada (ou seja, recursos de dados de áudio). O valor padrão deste parâmetro é 1
.sampling_rate
informa ao extrator de recursos a taxa de amostragem na qual os dados de áudio devem ser digitalizados. Conforme discutido na Etapa 3.7 , wav2vec2 é pré-treinado em áudio amostrado em 16000
Hz e, portanto, 16000
é o valor padrão para este parâmetro.padding_value
especifica o valor usado ao preencher dados de áudio, conforme necessário ao agrupar amostras de áudio de diferentes durações. O valor padrão é 0.0
.do_normalize
é usado para especificar se os dados de entrada devem ser transformados em uma distribuição normal padrão. O valor padrão é True
. A documentação da classe Wav2Vec2FeatureExtractor
observa que "[normalizar] pode ajudar a melhorar significativamente o desempenho de alguns modelos."return_attention_mask
especificam se a máscara de atenção deve ser passada ou não. O valor é definido como True
para este caso de uso. A décima nona célula inicializa uma instância de Wav2Vec2Processor
. Defina a décima nona célula como:
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
A classe Wav2Vec2Processor
combina tokenizer
e feature_extractor
da Etapa 3.17 e Etapa 3.18 respectivamente em um único processador.
Observe que a configuração do processador pode ser salva chamando o método save_pretrained
na instância da classe Wav2Vec2Processor
.
processor.save_pretrained(OUTPUT_DIR_PATH)
A vigésima célula carrega cada arquivo de áudio especificado na lista all_training_samples
. Defina a vigésima célula como:
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
torch.Tensor
e armazenados em all_input_data
como uma lista de dicionários. Cada dicionário contém os dados de áudio de uma amostra específica, juntamente com a transcrição do texto do áudio.read_audio_data
também retorna a taxa de amostragem dos dados de áudio. Como sabemos que a taxa de amostragem é de 48000
Hz para todos os arquivos de áudio neste caso de uso, a taxa de amostragem é ignorada nesta etapa.all_input_data
em um Pandas DataFrame A vigésima primeira célula converte a lista all_input_data
em um Pandas DataFrame para facilitar a manipulação dos dados. Defina a vigésima primeira célula como:
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
A vigésima segunda célula usa o processor
inicializado na Etapa 3.19 para extrair recursos de cada amostra de dados de áudio e codificar cada transcrição de texto como uma lista de rótulos. Defina a vigésima segunda célula para:
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
A vigésima terceira célula divide o DataFrame all_input_data_df
em conjuntos de dados de treinamento e avaliação (validação) usando a constante SPLIT_PCT
da Etapa 3.5 . Defina a vigésima terceira célula como:
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
SPLIT_PCT
é 0.10
neste guia, o que significa que 10% de todos os dados de entrada serão mantidos para avaliação e 90% dos dados serão usados para treinamento/ajuste.Dataset
A vigésima quarta célula converte os DataFrames train_data_df
e valid_data_df
em objetos Dataset
. Defina a vigésima quarta célula para:
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
Os objetos Dataset
são consumidos pelas instâncias da classe HuggingFace Trainer
, como você verá na Etapa 3.30 .
Esses objetos contêm metadados sobre o conjunto de dados, bem como sobre o próprio conjunto de dados.
Você pode imprimir train_data
e valid_data
para visualizar os metadados de ambos os objetos Dataset
.
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
A vigésima quinta célula inicializa o modelo XLS-R (0,3) pré-treinado. Defina a vigésima quinta célula para:
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
from_pretrained
chamado em Wav2Vec2ForCTC
especifica que queremos carregar os pesos pré-treinados para o modelo especificado.MODEL
foi especificada na Etapa 3.5 e definida como facebook/wav2vec2-xls-r-300m
refletindo o modelo XLS-R (0,3).ctc_loss_reduction
especifica o tipo de redução a ser aplicada à saída da função de perda da Classificação Temporal Conexionista ("CTC"). A perda CTC é usada para calcular a perda entre uma entrada contínua, neste caso dados de áudio, e uma sequência alvo, neste caso transcrições de texto. Ao definir o valor como mean
, as perdas de saída para um lote de entradas serão divididas pelos comprimentos alvo. A média do lote é então calculada e a redução é aplicada aos valores de perda.pad_token_id
especifica o token a ser usado para preenchimento durante o lote. Ele é definido como o ID [PAD]
definido ao inicializar o tokenizer na Etapa 3.17 .vocab_size
define o tamanho do vocabulário do modelo. É o tamanho do vocabulário após a inicialização do tokenizer na Etapa 3.17 e reflete o número de nós da camada de saída da parte direta da rede.A vigésima sexta célula congela os pesos pré-treinados do extrator de recursos. Defina a vigésima sexta célula como:
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
A vigésima sétima célula inicializa os argumentos de treinamento que serão passados para uma instância Trainer
. Defina a vigésima sétima célula para:
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
TrainingArguments
aceita mais de 100 parâmetros .save_safetensors
quando False
especifica que o modelo ajustado deve ser salvo em um arquivo pickle
em vez de usar o formato safetensors
.group_by_length
quando True
indica que amostras de aproximadamente o mesmo comprimento devem ser agrupadas. Isso minimiza o preenchimento e melhora a eficiência do treinamento.per_device_train_batch_size
define o número de amostras por minilote de treinamento. Este parâmetro é definido como 18
por meio da constante TRAIN_BATCH_SIZE
atribuída na Etapa 3.5 . Isso implica 160 passos por época.per_device_eval_batch_size
define o número de amostras por minilote de avaliação (holdout). Este parâmetro é definido como 10
por meio da constante EVAL_BATCH_SIZE
atribuída na Etapa 3.5 .num_train_epochs
define o número de épocas de treinamento. Este parâmetro é definido como 30
através da constante TRAIN_EPOCHS
atribuída na Etapa 3.5 . Isso implica um total de 4.800 passos durante o treinamento.gradient_checkpointing
quando True
ajuda a economizar memória ao verificar cálculos de gradiente, mas resulta em passagens reversas mais lentas.evaluation_strategy
quando definido como steps
significa que a avaliação será realizada e registrada durante o treinamento em um intervalo especificado pelo parâmetro eval_steps
.logging_strategy
quando definido como steps
significa que as estatísticas da execução de treinamento serão registradas em um intervalo especificado pelo parâmetro logging_steps
.save_strategy
quando definido como steps
significa que um ponto de verificação do modelo ajustado será salvo em um intervalo especificado pelo parâmetro save_steps
.eval_steps
define o número de etapas entre avaliações de dados de validação. Este parâmetro é definido como 100
através da constante EVAL_STEPS
atribuída na Etapa 3.5 .save_steps
define o número de etapas após as quais um ponto de verificação do modelo ajustado é salvo. Este parâmetro é definido como 3200
através da constante SAVE_STEPS
atribuída na Etapa 3.5 .logging_steps
define o número de etapas entre os logs das estatísticas da execução de treinamento. Este parâmetro é definido como 100
através da constante LOGGING_STEPS
atribuída na Etapa 3.5 .learning_rate
define a taxa de aprendizagem inicial. Este parâmetro é definido como 1e-4
por meio da constante LEARNING_RATE
atribuída na Etapa 3.5 .warmup_steps
define o número de etapas para aquecer linearmente a taxa de aprendizado de 0 até o valor definido por learning_rate
. Este parâmetro é definido como 800
através da constante WARMUP_STEPS
atribuída na Etapa 3.5 .A vigésima oitava célula define a lógica para preencher dinamicamente as sequências de entrada e de destino. Defina a vigésima oitava célula como:
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
Trainer
que será inicializada momentaneamente na Etapa 3.30 . Como as sequências de entrada e as sequências de rótulos variam em comprimento em cada minilote, algumas sequências devem ser preenchidas para que tenham todas o mesmo comprimento.DataCollatorCTCWithPadding
preenche dinamicamente os dados do minilote. O parâmetro padding
, quando definido como True
, especifica que sequências mais curtas de recursos de entrada de áudio e sequências de rótulos devem ter o mesmo comprimento que a sequência mais longa em um minilote.0.0
definido ao inicializar o extrator de recursos na Etapa 3.18 .-100
para que esses rótulos sejam ignorados no cálculo da métrica WER.A vigésima nona célula inicializa uma instância do agrupamento de dados definido na etapa anterior. Defina a vigésima nona célula para:
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
A trigésima célula inicializa uma instância da classe Trainer
. Defina a trigésima célula como:
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
Trainer
é inicializada com:model
pré-treinado foi inicializado na Etapa 3.25 .train_data
Dataset
da Etapa 3.24 .valid_data
Dataset
da Etapa 3.24 .tokenizer
é atribuído a processor.feature_extractor
e funciona com data_collator
para preencher automaticamente as entradas para a entrada de comprimento máximo de cada minilote. A trigésima primeira célula chama o método train
na instância da classe Trainer
para ajustar o modelo. Defina a trigésima primeira célula como:
### CELL 31: Finetune the model ### trainer.train()
A trigésima segunda célula é a última célula do notebook. Ele salva o modelo ajustado chamando o método save_model
na instância Trainer
. Defina a célula do trigésimo segundo como:
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
Agora que todas as células do notebook foram construídas, é hora de começar os ajustes finos.
Configure o Kaggle Notebook para rodar com o acelerador NVIDIA GPU P100 .
Envie o notebook no Kaggle.
Monitore os dados da corrida de treinamento fazendo login em sua conta WandB e localizando a corrida associada.
O treinamento de mais de 30 épocas deve levar cerca de 5 horas usando o acelerador NVIDIA GPU P100. O WER nos dados de validação deve cair para aproximadamente 0,15 no final do treinamento. Não é exatamente um resultado de última geração, mas o modelo ajustado ainda é suficientemente útil para muitas aplicações.
O modelo ajustado será enviado para o diretório Kaggle especificado pela constante OUTPUT_DIR_PATH
especificada na Etapa 3.5 . A saída do modelo deve incluir os seguintes arquivos:
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
Esses arquivos podem ser baixados localmente. Além disso, você pode criar um novo modelo Kaggle usando os arquivos de modelo. O modelo Kaggle será usado com o guia de inferência complementar para executar inferência no modelo ajustado.
Parabéns pelo ajuste fino do wav2vec2 XLS-R! Lembre-se de que você pode usar estas etapas gerais para ajustar o modelo em outros idiomas que desejar. Executar inferência no modelo ajustado gerado neste guia é bastante simples. As etapas de inferência serão descritas em um guia separado deste. Pesquise meu nome de usuário HackerNoon para encontrar o guia complementar.