这是使用 wav2vec2 第 1 部分 - 微调 XLS-R 以进行自动语音识别(“第 1 部分指南”)的配套指南。我编写了第 1 部分指南,介绍如何针对智利西班牙语微调 Meta AI 的wav2vec2 XLS-R(“XLS-R”)模型。假设您已完成该指南并生成了您自己的微调 XLS-R 模型。本指南将解释通过Kaggle Notebook对微调后的 XLS-R 模型进行推理的步骤。
要完成本指南,您需要具备:
spanish-asr-inference
。本指南使用秘鲁西班牙语语音数据集作为测试数据来源。与智利西班牙语语音数据集一样,秘鲁语者数据集也由两个子数据集组成:2,918 条秘鲁男性说话者的录音和 2,529 条秘鲁女性说话者的录音。
该数据集已作为 2 个不同的数据集上传至 Kaggle:
单击添加输入,将这两个数据集都添加到您的 Kaggle Notebook 中。
您应该在使用 wav2vec2 第 1 部分 - 对 XLS-R 进行自动语音识别微调指南的第 4 步中将微调后的模型保存为Kaggle 模型。
单击“添加输入” ,将微调后的模型添加到你的 Kaggle Notebook 中。
以下 16 个子步骤按顺序构建推理笔记本的 16 个单元。您会注意到,这里使用了许多与第 1 部分指南相同的实用方法。
推理笔记本的第一个单元格安装依赖项。将第一个单元格设置为:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
第二个单元格导入所需的 Python 包。将第二个单元格设置为:
### CELL 2: Import Python packages ### import re import math import random import pandas as pd import torchaudio from datasets import load_metric from transformers import pipeline
第三个单元格导入 HuggingFace WER 评估指标。将第三个单元格设置为:
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
第四个单元格设置将在整个笔记本中使用的常量。将第四个单元格设置为:
### CELL 4: Constants ### # Testing data TEST_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-peru-male/" TEST_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-peru-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 3 # Special characters SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000
第五个单元格定义了用于读取数据集索引文件以及清理转录文本和从测试数据中生成一组随机样本的实用方法。将第五个单元格设置为:
### CELL 5: Utility methods for reading index files, cleaning text, random indices generator ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def get_random_samples(dataset: list, num: int) -> list: used = [] samples = [] for i in range(num): a = -1 while a == -1 or a in used: a = math.floor(len(dataset) * random.random()) samples.append(dataset[a]) used.append(a) return samples
read_index_file_data
方法读取line_index.tsv
数据集索引文件并生成包含音频文件名和转录数据的列表列表,例如:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
clean_text
方法用于删除每个文本转录中步骤 2.4中分配给SPECIAL_CHARS
的正则表达式所指定的字符。这些字符(包括标点符号)可以被删除,因为它们在训练模型以学习音频特征和文本转录之间的映射时不提供任何语义价值。get_random_samples
方法返回一组随机测试样本,其数量由步骤 2.4中的常量NUM_LOAD_FROM_EACH_SET
设置。第六个单元格定义使用torchaudio
加载和重新采样音频数据的实用方法。将第六个单元格设置为:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
方法加载指定的音频文件并返回音频数据的torch.Tensor
多维矩阵以及音频的采样率。训练数据中的所有音频文件的采样率均为48000
Hz。此“原始”采样率由步骤 2.4中的常量ORIG_SAMPLING_RATE
捕获。resample
方法用于将音频数据从采样率48000
下采样到目标采样率16000
。第七个单元格使用步骤 2.5中定义的read_index_file_data
方法读取男性说话者录音和女性说话者录音的测试数据索引文件。将第七个单元格设置为:
### CELL 7: Read test data ### test_data_male = read_index_file_data(TEST_DATA_PATH_MALE, "line_index.tsv") test_data_female = read_index_file_data(TEST_DATA_PATH_FEMALE, "line_index.tsv")
第八个单元格使用步骤 2.5中定义的get_random_samples
方法生成随机测试样本集。将第八个单元格设置为:
### CELL 8: Generate lists of random test samples ### random_test_samples_male = get_random_samples(test_data_male, NUM_LOAD_FROM_EACH_SET) random_test_samples_female = get_random_samples(test_data_female, NUM_LOAD_FROM_EACH_SET)
第九个单元格将男性测试样本和女性测试样本合并为一个列表。将第九个单元格设置为:
### CELL 9: Combine test data ### all_test_samples = random_test_samples_male + random_test_samples_female
第十个单元格对每个测试数据样本进行迭代,并使用步骤 2.5中定义的clean_text
方法清理相关的转录文本。将第十个单元格设置为:
### CELL 10: Clean text transcriptions ### for index in range(len(all_test_samples)): all_test_samples[index][1] = clean_text(all_test_samples[index][1])
第 11 个单元格加载all_test_samples
列表中指定的每个音频文件。将第 11 个单元格设置为:
### CELL 11: Load audio data ### all_test_data = [] for index in range(len(all_test_samples)): speech_array, sampling_rate = read_audio_data(all_test_samples[index][0]) all_test_data.append({ "raw": speech_array, "sampling_rate": sampling_rate, "target_text": all_test_samples[index][1] })
torch.Tensor
的形式返回,并以字典列表的形式存储在all_test_data
中。每个字典包含特定样本的音频数据、采样率和音频的文本转录。第十二个单元格将音频数据重新采样为目标采样率16000
。将第十二个单元格设置为:
### CELL 12: Resample audio data and cast to NumPy arrays ### all_test_data = [{"raw": resample(sample["raw"]).numpy(), "sampling_rate": TGT_SAMPLING_RATE, "target_text": sample["target_text"]} for sample in all_test_data]
第十三个单元格初始化 HuggingFace transformer
库pipeline
类的实例。将第十三个单元格设置为:
### CELL 13: Initialize instance of Automatic Speech Recognition Pipeline ### transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")
model
参数必须设置为步骤 1.3中添加到 Kaggle Notebook 的微调模型的路径,例如:
transcriber = pipeline("automatic-speech-recognition", model = "/kaggle/input/xls-r-300m-chilean-spanish/transformers/hardy-pine/1")
第 14 个单元格在测试数据上调用上一步初始化的transcriber
来生成文本预测。将第 14 个单元格设置为:
### CELL 14: Generate transcriptions ### transcriptions = transcriber(all_test_data)
第十五个单元格计算每个预测的 WER 分数以及所有预测的总体 WER 分数。将第十五个单元格设置为:
### CELL 15: Calculate WER metrics ### predictions = [transcription["text"] for transcription in transcriptions] references = [transcription["target_text"][0] for transcription in transcriptions] wers = [] for p in range(len(predictions)): wer = wer_metric.compute(predictions = [predictions[p]], references = [references[p]]) wers.append(wer) zipped = list(zip(predictions, references, wers)) df = pd.DataFrame(zipped, columns=["Prediction", "Reference", "WER"]) wer = wer_metric.compute(predictions = predictions, references = references)
第十六个单元格(也是最后一个单元格)仅打印上一步中的 WER 计算结果。将第十六个单元格设置为:
### CELL 16: Output WER metrics ### pd.set_option("display.max_colwidth", None) print(f"Overall WER: {wer}") print(df)
由于笔记本会根据测试数据的随机样本生成预测,因此每次运行笔记本时输出都会有所不同。在运行笔记本时, NUM_LOAD_FROM_EACH_SET
设置为3
,总共 6 个测试样本,生成了以下输出:
Overall WER: 0.013888888888888888 Prediction \ 0 quiero que me reserves el mejor asiento del teatro 1 el llano en llamas es un clásico de juan rulfo 2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera 3 hay tres cafés que están abiertos hasta las once de la noche 4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras 5 cuántos albergues se abrieron después del terremoto del diecinueve de setiembre Reference \ 0 quiero que me reserves el mejor asiento del teatro 1 el llano en llamas es un clásico de juan rulfo 2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera 3 hay tres cafés que están abiertos hasta las once de la noche 4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras 5 cuántos albergues se abrieron después del terremoto del diecinueve de septiembre WER 0 0.000000 1 0.000000 2 0.000000 3 0.000000 4 0.000000 5 0.090909
可以看出,该模型表现非常出色!它只在第六个样本(索引5
)中犯了一个错误,将单词septiembre
拼写为setiembre
。当然,使用不同的测试样本(更重要的是,使用更多的测试样本)再次运行该笔记本,将产生不同且更具信息量的结果。尽管如此,这些有限的数据表明该模型可以在不同的西班牙语方言上表现良好 - 例如,它是用智利西班牙语训练的,但似乎在秘鲁西班牙语上表现良好。
如果您刚刚开始学习如何使用 wav2vec2 模型,我希望《使用 wav2vec2 第 1 部分 - 微调 XLS-R 进行自动语音识别》指南和本指南对您有所帮助。如前所述,第 1 部分指南生成的微调模型并不是最先进的,但仍可用于许多应用。祝您构建愉快!